diff --git a/context.go b/context.go new file mode 100644 index 0000000..c2a81d0 --- /dev/null +++ b/context.go @@ -0,0 +1,211 @@ +package graphvent + +import ( + "github.com/graphql-go/graphql" + badger "github.com/dgraph-io/badger/v3" + "reflect" + "fmt" +) + +// For persistance, each node needs the following functions(* is a placeholder for the node/state type): +// Load*State - StateLoadFunc that returns the NodeState interface to attach to the node +// Load* - NodeLoadFunc that returns the GraphNode restored from it's loaded state + +// For convenience, the following functions are a good idea to define for composability: +// Restore*State - takes in the nodes serialized data to allow for easier nesting of inherited Load*State functions +// Save*State - serialize the node into it's json counterpart to be included as part of a larger json + +type NodeLoadFunc func(*Context, NodeID, []byte, NodeMap)(Node, error) +type NodeDef struct { + Load NodeLoadFunc + Type NodeType +} + +func NewNodeDef(type_name string, load_func NodeLoadFunc) NodeDef { + return NodeDef{ + Type: NodeType(type_name), + Load: load_func, + } +} + +type Context struct { + DB * badger.DB + Log Logger + Types map[uint64]NodeDef + GQL * GQLContext +} + +func (ctx * Context) RegisterNodeType(type_name string, load_func NodeLoadFunc) error { + if load_func == nil { + return fmt.Errorf("Cannot register a node without a load function") + } + + def := NodeDef{ + Type: NodeType(type_name), + Load: load_func, + } + + type_hash := def.Type.Hash() + _, exists := ctx.Types[type_hash] + if exists == true { + return fmt.Errorf("Cannot register node of type %s, type already exists in context", type_name) + } + + ctx.Types[type_hash] = def + return nil +} + +type TypeList []graphql.Type +type ObjTypeMap map[reflect.Type]*graphql.Object +type FieldMap map[string]*graphql.Field + +type GQLContext struct { + Schema graphql.Schema + ValidNodes ObjTypeMap + NodeType reflect.Type + ValidLockables ObjTypeMap + LockableType reflect.Type + ValidThreads ObjTypeMap + ThreadType reflect.Type +} + +/*func NewGQLContext(additional_types TypeList, extended_types ObjTypeMap, extended_queries FieldMap, extended_subscriptions FieldMap, extended_mutations FieldMap) (*GQLContext, error) { + type_list := TypeList{ + GQLTypeSignalInput(), + } + + for _, gql_type := range(additional_types) { + type_list = append(type_list, gql_type) + } + + type_map := ObjTypeMap{} + type_map[reflect.TypeOf((*BaseLockable)(nil))] = GQLTypeBaseLockable() + type_map[reflect.TypeOf((*BaseThread)(nil))] = GQLTypeBaseThread() + type_map[reflect.TypeOf((*GQLThread)(nil))] = GQLTypeGQLThread() + type_map[reflect.TypeOf((*BaseSignal)(nil))] = GQLTypeSignal() + + for go_t, gql_t := range(extended_types) { + type_map[go_t] = gql_t + } + + valid_nodes := ObjTypeMap{} + valid_lockables := ObjTypeMap{} + valid_threads := ObjTypeMap{} + + node_type := reflect.TypeOf((*GraphNode)(nil)).Elem() + lockable_type := reflect.TypeOf((*Lockable)(nil)).Elem() + thread_type := reflect.TypeOf((*Thread)(nil)).Elem() + + for go_t, gql_t := range(type_map) { + if go_t.Implements(node_type) { + valid_nodes[go_t] = gql_t + } + if go_t.Implements(lockable_type) { + valid_lockables[go_t] = gql_t + } + if go_t.Implements(thread_type) { + valid_threads[go_t] = gql_t + } + type_list = append(type_list, gql_t) + } + + queries := graphql.Fields{ + "Self": GQLQuerySelf(), + } + + for key, val := range(extended_queries) { + queries[key] = val + } + + subscriptions := graphql.Fields{ + "Update": GQLSubscriptionUpdate(), + "Self": GQLSubscriptionSelf(), + } + + for key, val := range(extended_subscriptions) { + subscriptions[key] = val + } + + mutations := graphql.Fields{ + "SendUpdate": GQLMutationSendUpdate(), + } + + for key, val := range(extended_mutations) { + mutations[key] = val + } + + schemaConfig := graphql.SchemaConfig{ + Types: type_list, + Query: graphql.NewObject(graphql.ObjectConfig{ + Name: "Query", + Fields: queries, + }), + Mutation: graphql.NewObject(graphql.ObjectConfig{ + Name: "Mutation", + Fields: mutations, + }), + Subscription: graphql.NewObject(graphql.ObjectConfig{ + Name: "Subscription", + Fields: subscriptions, + }), + } + + schema, err := graphql.NewSchema(schemaConfig) + if err != nil{ + return nil, err + } + + ctx := GQLContext{ + Schema: schema, + ValidNodes: valid_nodes, + NodeType: node_type, + ValidThreads: valid_threads, + ThreadType: thread_type, + ValidLockables: valid_lockables, + LockableType: lockable_type, + } + + return &ctx, nil +}*/ + +func NewContext(db * badger.DB, log Logger, extra_nodes map[string]NodeLoadFunc, types TypeList, type_map ObjTypeMap, queries FieldMap, subscriptions FieldMap, mutations FieldMap) * Context { + /*gql, err := NewGQLContext(types, type_map, queries, subscriptions, mutations) + if err != nil { + panic(err) + }*/ + + ctx := &Context{ + GQL: nil, + DB: db, + Log: log, + Types: map[uint64]NodeDef{}, + } + + + + err := ctx.RegisterNodeType("graph_node", LoadGraphNode) + if err != nil { + panic(err) + } + err = ctx.RegisterNodeType("simple_lockable", LoadSimpleLockable) + if err != nil { + panic(err) + } + /*err := ctx.RegisterNodeType("simple_thread", LoadSimpleThread) + if err != nil { + panic(err) + } + err := ctx.RegisterNodeType("gql_thread", LoadGQLThread) + if err != nil { + panic(err) + }*/ + + for name, load_fn := range(extra_nodes) { + err := ctx.RegisterNodeType(name, load_fn) + if err != nil { + panic(err) + } + } + + return ctx +} diff --git a/graph.go b/graph.go deleted file mode 100644 index 3611392..0000000 --- a/graph.go +++ /dev/null @@ -1,735 +0,0 @@ -package graphvent - -import ( - "sync" - "reflect" - "github.com/google/uuid" - "github.com/graphql-go/graphql" - "os" - "github.com/rs/zerolog" - "fmt" - badger "github.com/dgraph-io/badger/v3" - "encoding/json" -) - -// For persistance, each node needs the following functions(* is a placeholder for the node/state type): -// Load*State - StateLoadFunc that returns the NodeState interface to attach to the node -// Load* - NodeLoadFunc that returns the GraphNode restored from it's loaded state - -// For convenience, the following functions are a good idea to define for composability: -// Restore*State - takes in the nodes serialized data to allow for easier nesting of inherited Load*State functions -// Save*State - serialize the node into it's json counterpart to be included as part of a larger json - -type StateLoadFunc func(*GraphContext, []byte, NodeMap)(NodeState, error) -type StateLoadMap map[string]StateLoadFunc -type NodeLoadFunc func(*GraphContext, NodeID)(GraphNode, error) -type NodeLoadMap map[string]NodeLoadFunc -type InfoLoadFunc func(*GraphContext, map[string]interface{})(ThreadInfo, error) -type InfoLoadMap map[string]InfoLoadFunc -type GraphContext struct { - DB * badger.DB - Log Logger - NodeLoadFuncs NodeLoadMap - StateLoadFuncs StateLoadMap - InfoLoadFuncs InfoLoadMap - GQL * GQLContext -} - -type GQLContext struct { - Schema graphql.Schema - ValidNodes ObjTypeMap - NodeType reflect.Type - ValidLockables ObjTypeMap - LockableType reflect.Type - ValidThreads ObjTypeMap - ThreadType reflect.Type -} - -func NewGQLContext(additional_types TypeList, extended_types ObjTypeMap, extended_queries FieldMap, extended_subscriptions FieldMap, extended_mutations FieldMap) (*GQLContext, error) { - type_list := TypeList{ - GQLTypeSignalInput(), - } - - for _, gql_type := range(additional_types) { - type_list = append(type_list, gql_type) - } - - type_map := ObjTypeMap{} - type_map[reflect.TypeOf((*BaseLockable)(nil))] = GQLTypeBaseLockable() - type_map[reflect.TypeOf((*BaseThread)(nil))] = GQLTypeBaseThread() - type_map[reflect.TypeOf((*GQLThread)(nil))] = GQLTypeGQLThread() - type_map[reflect.TypeOf((*BaseSignal)(nil))] = GQLTypeSignal() - - for go_t, gql_t := range(extended_types) { - type_map[go_t] = gql_t - } - - valid_nodes := ObjTypeMap{} - valid_lockables := ObjTypeMap{} - valid_threads := ObjTypeMap{} - - node_type := reflect.TypeOf((*GraphNode)(nil)).Elem() - lockable_type := reflect.TypeOf((*Lockable)(nil)).Elem() - thread_type := reflect.TypeOf((*Thread)(nil)).Elem() - - for go_t, gql_t := range(type_map) { - if go_t.Implements(node_type) { - valid_nodes[go_t] = gql_t - } - if go_t.Implements(lockable_type) { - valid_lockables[go_t] = gql_t - } - if go_t.Implements(thread_type) { - valid_threads[go_t] = gql_t - } - type_list = append(type_list, gql_t) - } - - queries := graphql.Fields{ - "Self": GQLQuerySelf(), - } - - for key, val := range(extended_queries) { - queries[key] = val - } - - subscriptions := graphql.Fields{ - "Update": GQLSubscriptionUpdate(), - "Self": GQLSubscriptionSelf(), - } - - for key, val := range(extended_subscriptions) { - subscriptions[key] = val - } - - mutations := graphql.Fields{ - "SendUpdate": GQLMutationSendUpdate(), - } - - for key, val := range(extended_mutations) { - mutations[key] = val - } - - schemaConfig := graphql.SchemaConfig{ - Types: type_list, - Query: graphql.NewObject(graphql.ObjectConfig{ - Name: "Query", - Fields: queries, - }), - Mutation: graphql.NewObject(graphql.ObjectConfig{ - Name: "Mutation", - Fields: mutations, - }), - Subscription: graphql.NewObject(graphql.ObjectConfig{ - Name: "Subscription", - Fields: subscriptions, - }), - } - - schema, err := graphql.NewSchema(schemaConfig) - if err != nil{ - return nil, err - } - - ctx := GQLContext{ - Schema: schema, - ValidNodes: valid_nodes, - NodeType: node_type, - ValidThreads: valid_threads, - ThreadType: thread_type, - ValidLockables: valid_lockables, - LockableType: lockable_type, - } - - return &ctx, nil -} - -func LoadNode(ctx * GraphContext, id NodeID) (GraphNode, error) { - // Initialize an empty list of loaded nodes, then start loading them from id - loaded_nodes := map[NodeID]GraphNode{} - return LoadNodeRecurse(ctx, id, loaded_nodes) -} - -type DBJSONBase struct { - Type string `json:"type"` -} - -// Check if a node is already loaded, load it's state bytes from the DB and parse the type if it's not already loaded -// Call the node load function related to the type, which will call this parse function recusively as needed -func LoadNodeRecurse(ctx * GraphContext, id NodeID, loaded_nodes map[NodeID]GraphNode) (GraphNode, error) { - node, exists := loaded_nodes[id] - if exists == false { - state_bytes, err := ReadDBState(ctx, id) - if err != nil { - return nil, err - } - - var base DBJSONBase - err = json.Unmarshal(state_bytes, &base) - if err != nil { - return nil, err - } - - ctx.Log.Logf("graph", "GRAPH_DB_LOAD: %s(%s)", base.Type, id) - - node_fn, exists := ctx.NodeLoadFuncs[base.Type] - if exists == false { - return nil, fmt.Errorf("%s is not a known node type: %s", base.Type, state_bytes) - } - - node, err = node_fn(ctx, id) - if err != nil { - return nil, err - } - - loaded_nodes[id] = node - - state_fn, exists := ctx.StateLoadFuncs[base.Type] - if exists == false { - return nil, fmt.Errorf("%s is not a known node state type", base.Type) - } - - state, err := state_fn(ctx, state_bytes, loaded_nodes) - if err != nil { - return nil, err - } - - node.SetState(state) - } - return node, nil -} - -func NewGraphContext(db * badger.DB, log Logger, state_loads StateLoadMap, node_loads NodeLoadMap, info_loads InfoLoadMap, types TypeList, type_map ObjTypeMap, queries FieldMap, subscriptions FieldMap, mutations FieldMap) * GraphContext { - gql, err := NewGQLContext(types, type_map, queries, subscriptions, mutations) - if err != nil { - panic(err) - } - - ctx := GraphContext{ - GQL: gql, - DB: db, - Log: log, - NodeLoadFuncs: NodeLoadMap{ - "simple_lockable": LoadSimpleLockable, - "simple_thread": LoadSimpleThread, - "gql_thread": LoadGQLThread, - }, - StateLoadFuncs: StateLoadMap{ - "simple_lockable": LoadSimpleLockableState, - "simple_thread": LoadSimpleThreadState, - "gql_thread": LoadGQLThreadState, - }, - InfoLoadFuncs: InfoLoadMap{ - "gql_thread": LoadGQLThreadInfo, - }, - } - - for name, fn := range(state_loads) { - ctx.StateLoadFuncs[name] = fn - } - - for name, fn := range(node_loads) { - ctx.NodeLoadFuncs[name] = fn - } - - for name, fn := range(info_loads) { - ctx.InfoLoadFuncs[name] = fn - } - - return &ctx -} - -// A Logger is passed around to record events happening to components enabled by SetComponents -type Logger interface { - SetComponents(components []string) error - // Log a formatted string - Logf(component string, format string, items ... interface{}) - // Log a map of attributes and a format string - Logm(component string, fields map[string]interface{}, format string, items ... interface{}) - // Log a structure to a file by marshalling and unmarshalling the json - Logj(component string, s interface{}, format string, items ... interface{}) -} - -func NewConsoleLogger(components []string) *ConsoleLogger { - logger := &ConsoleLogger{ - loggers: map[string]zerolog.Logger{}, - components: []string{}, - } - - logger.SetComponents(components) - - return logger -} - -// A ConsoleLogger logs to stdout -type ConsoleLogger struct { - loggers map[string]zerolog.Logger - components_lock sync.Mutex - components []string -} - -func (logger * ConsoleLogger) SetComponents(components []string) error { - logger.components_lock.Lock() - defer logger.components_lock.Unlock() - - component_enabled := func (component string) bool { - for _, c := range(components) { - if c == component { - return true - } - } - return false - } - - for c, _ := range(logger.loggers) { - if component_enabled(c) == false { - delete(logger.loggers, c) - } - } - - for _, c := range(components) { - _, exists := logger.loggers[c] - if component_enabled(c) == true && exists == false { - logger.loggers[c] = zerolog.New(os.Stdout).With().Timestamp().Str("component", c).Logger() - } - } - return nil -} - -func (logger * ConsoleLogger) Logm(component string, fields map[string]interface{}, format string, items ... interface{}) { - l, exists := logger.loggers[component] - if exists == true { - log := l.Log() - for key, value := range(fields) { - log = log.Str(key, fmt.Sprintf("%+v", value)) - } - log.Msg(fmt.Sprintf(format, items...)) - } -} - -func (logger * ConsoleLogger) Logf(component string, format string, items ... interface{}) { - l, exists := logger.loggers[component] - if exists == true { - l.Log().Msg(fmt.Sprintf(format, items...)) - } -} - -func (logger * ConsoleLogger) Logj(component string, s interface{}, format string, items ... interface{}) { - m := map[string]interface{}{} - ser, err := json.Marshal(s) - if err != nil { - panic("LOG_MARSHAL_ERR") - } - err = json.Unmarshal(ser, &m) - if err != nil { - panic("LOG_UNMARSHAL_ERR") - } - logger.Logm(component, m, format, items...) -} - -type NodeID string -// Generate a random id -func RandID() NodeID { - uuid_str := uuid.New().String() - return NodeID(uuid_str) -} - -type SignalDirection int -const ( - Up SignalDirection = iota - Down - Direct -) - -// GraphSignals are passed around the event tree/resource DAG and cast by Type() -type GraphSignal interface { - // How to propogate the signal - Direction() SignalDirection - Source() NodeID - Type() string - String() string -} - -// BaseSignal is the most basic type of signal, it has no additional data -type BaseSignal struct { - FDirection SignalDirection `json:"direction"` - FSource NodeID `json:"source"` - FType string `json:"type"` -} - -func (state BaseSignal) String() string { - ser, err := json.Marshal(state) - if err != nil { - return "STATE_SER_ERR" - } - return string(ser) -} - -func (signal BaseSignal) Direction() SignalDirection { - return signal.FDirection -} - -func (signal BaseSignal) Source() NodeID { - return signal.FSource -} - -func (signal BaseSignal) Type() string { - return signal.FType -} - -func NewBaseSignal(source GraphNode, _type string, direction SignalDirection) BaseSignal { - var source_id NodeID = "nil" - if source != nil { - source_id = source.ID() - } - - signal := BaseSignal{ - FDirection: direction, - FSource: source_id, - FType: _type, - } - return signal -} - -func NewDownSignal(source GraphNode, _type string) BaseSignal { - return NewBaseSignal(source, _type, Down) -} - -func NewSignal(source GraphNode, _type string) BaseSignal { - return NewBaseSignal(source, _type, Up) -} - -func NewDirectSignal(source GraphNode, _type string) BaseSignal { - return NewBaseSignal(source, _type, Direct) -} - -func AbortSignal(source GraphNode) BaseSignal { - return NewBaseSignal(source, "abort", Down) -} - -func CancelSignal(source GraphNode) BaseSignal { - return NewBaseSignal(source, "cancel", Down) -} - -type NodeState interface { - // Human-readable name of the node, not guaranteed to be unique - Name() string - // Type of the node this state is attached to. Used to deserialize the state to a node from the database - Type() string -} - -// GraphNode is the interface common to both DAG nodes and Event tree nodes -// They have a NodeState interface which is saved to the database every update -type GraphNode interface { - ID() NodeID - - State() NodeState - StateLock() *sync.RWMutex - - SetState(new_state NodeState) - - // Signal propagation function for listener channels - UpdateListeners(ctx * GraphContext, update GraphSignal) - // Signal propagation function for connected nodes(defined in state) - PropagateUpdate(ctx * GraphContext, update GraphSignal, states NodeStateMap) - - // Get an update channel for the node to be notified of signals - UpdateChannel(buffer int) chan GraphSignal - - // Register and unregister a channel to propogate updates to - RegisterChannel(listener chan GraphSignal) - UnregisterChannel(listener chan GraphSignal) - // Get a handle to the nodes internal signal channel - SignalChannel() chan GraphSignal -} - -const NODE_SIGNAL_BUFFER = 256 - -func RestoreNode(ctx * GraphContext, id NodeID) BaseNode { - node := BaseNode{ - id: id, - signal: make(chan GraphSignal, NODE_SIGNAL_BUFFER), - listeners: map[chan GraphSignal]chan GraphSignal{}, - listeners_lock: &sync.Mutex{}, - state: nil, - state_lock: &sync.RWMutex{}, - } - - ctx.Log.Logf("graph", "RESTORE_NODE: %s", node.id) - return node -} - -func WriteDBState(ctx * GraphContext, id NodeID, state NodeState) error { - ctx.Log.Logf("db", "DB_WRITE: %s - %+v", id, state) - - var serialized_state []byte = nil - if state != nil { - ser, err := json.Marshal(state) - if err != nil { - return fmt.Errorf("DB_MARSHAL_ERROR: %e", err) - } - serialized_state = ser - } else { - serialized_state = []byte{} - } - - err := ctx.DB.Update(func(txn *badger.Txn) error { - err := txn.Set([]byte(id), serialized_state) - return err - }) - - return err -} - -// Create a new base node with a new ID -func NewNode(ctx * GraphContext, state NodeState) (BaseNode, error) { - node := BaseNode{ - id: RandID(), - signal: make(chan GraphSignal, NODE_SIGNAL_BUFFER), - listeners: map[chan GraphSignal]chan GraphSignal{}, - listeners_lock: &sync.Mutex{}, - state: state, - state_lock: &sync.RWMutex{}, - } - - err := WriteDBState(ctx, node.id, state) - if err != nil { - return node, fmt.Errorf("DB_NEW_WRITE_ERROR: %e", err) - } - - ctx.Log.Logf("graph", "NEW_NODE: %s - %+v", node.id, state) - return node, nil -} - -// BaseNode is the minimum set of fields needed to implement a GraphNode, -// and provides a template for more complicated Nodes -type BaseNode struct { - id NodeID - - state NodeState - state_lock *sync.RWMutex - - signal chan GraphSignal - - listeners_lock *sync.Mutex - listeners map[chan GraphSignal]chan GraphSignal -} - -func (node * BaseNode) ID() NodeID { - return node.id -} - -func (node * BaseNode) State() NodeState { - return node.state -} - -func (node * BaseNode) StateLock() * sync.RWMutex { - return node.state_lock -} - -func ReadDBState(ctx * GraphContext, id NodeID) ([]byte, error) { - var bytes []byte - err := ctx.DB.View(func(txn *badger.Txn) error { - item, err := txn.Get([]byte(id)) - if err != nil { - return err - } - - return item.Value(func(val []byte) error { - bytes = append([]byte{}, val...) - return nil - }) - }) - - if err != nil { - ctx.Log.Logf("db", "DB_READ_ERR: %s - %e", id, err) - return nil, err - } - - ctx.Log.Logf("db", "DB_READ: %s - %s", id, string(bytes)) - - return bytes, nil -} - -func WriteDBStates(ctx * GraphContext, nodes NodeMap) error{ - ctx.Log.Logf("db", "DB_WRITES: %d", len(nodes)) - serialized_states := map[NodeID][]byte{} - for _, node := range(nodes) { - ser, err := json.Marshal(node.State()) - if err != nil { - return fmt.Errorf("DB_MARSHAL_ERROR: %e", err) - } - serialized_states[node.ID()] = ser - } - - err := ctx.DB.Update(func(txn *badger.Txn) error { - i := 0 - for id, _ := range(nodes) { - ctx.Log.Logf("db", "DB_WRITE: %s - %s", id, string(serialized_states[id])) - err := txn.Set([]byte(id), serialized_states[id]) - if err != nil { - return fmt.Errorf("DB_MARSHAL_ERROR: %e", err) - } - i++ - } - return nil - }) - return err -} - -func (node * BaseNode) SetState(new_state NodeState) { - node.state = new_state -} - -func checkForDuplicate(nodes []GraphNode) error { - found := map[NodeID]bool{} - for _, node := range(nodes) { - if node == nil { - return fmt.Errorf("Cannot get state of nil node") - } - - _, exists := found[node.ID()] - if exists == true { - return fmt.Errorf("Attempted to get state of %s twice", node.ID()) - } - found[node.ID()] = true - } - return nil -} - -func NodeList[K GraphNode](list []K) []GraphNode { - nodes := make([]GraphNode, len(list)) - for i, node := range(list) { - nodes[i] = node - } - return nodes -} - -type NodeStateMap map[NodeID]NodeState -type NodeMap map[NodeID]GraphNode -type StatesFn func(states NodeStateMap) error -type NodesFn func(nodes NodeMap) error -func UseStates(ctx * GraphContext, nodes []GraphNode, states_fn StatesFn) error { - states := NodeStateMap{} - return UseMoreStates(ctx, nodes, states, states_fn) -} -func UseMoreStates(ctx * GraphContext, nodes []GraphNode, states NodeStateMap, states_fn StatesFn) error { - err := checkForDuplicate(nodes) - if err != nil { - return err - } - - locked_nodes := []GraphNode{} - for _, node := range(nodes) { - _, locked := states[node.ID()] - if locked == false { - node.StateLock().RLock() - states[node.ID()] = node.State() - locked_nodes = append(locked_nodes, node) - } - } - - err = states_fn(states) - - for _, node := range(locked_nodes) { - delete(states, node.ID()) - node.StateLock().RUnlock() - } - - return err -} - -func UpdateStates(ctx * GraphContext, nodes []GraphNode, nodes_fn NodesFn) error { - locked_nodes := NodeMap{} - err := UpdateMoreStates(ctx, nodes, locked_nodes, nodes_fn) - if err == nil { - err = WriteDBStates(ctx, locked_nodes) - } - - for _, node := range(locked_nodes) { - node.StateLock().Unlock() - } - return err -} -func UpdateMoreStates(ctx * GraphContext, nodes []GraphNode, locked_nodes NodeMap, nodes_fn NodesFn) error { - for _, node := range(nodes) { - _, locked := locked_nodes[node.ID()] - if locked == false { - node.StateLock().Lock() - locked_nodes[node.ID()] = node - } - } - - return nodes_fn(locked_nodes) -} - -func (node * BaseNode) UpdateListeners(ctx * GraphContext, update GraphSignal) { - node.listeners_lock.Lock() - defer node.listeners_lock.Unlock() - closed := []chan GraphSignal{} - - for _, listener := range node.listeners { - ctx.Log.Logf("listeners", "UPDATE_LISTENER %s: %p", node.ID(), listener) - select { - case listener <- update: - default: - ctx.Log.Logf("listeners", "CLOSED_LISTENER %s: %p", node.ID(), listener) - go func(node GraphNode, listener chan GraphSignal) { - listener <- NewSignal(node, "listener_closed") - close(listener) - }(node, listener) - closed = append(closed, listener) - } - } - - for _, listener := range(closed) { - delete(node.listeners, listener) - } -} - -func (node * BaseNode) PropagateUpdate(ctx * GraphContext, update GraphSignal, states NodeStateMap) { -} - -func (node * BaseNode) RegisterChannel(listener chan GraphSignal) { - node.listeners_lock.Lock() - _, exists := node.listeners[listener] - if exists == false { - node.listeners[listener] = listener - } - node.listeners_lock.Unlock() -} - -func (node * BaseNode) UnregisterChannel(listener chan GraphSignal) { - node.listeners_lock.Lock() - _, exists := node.listeners[listener] - if exists == false { - panic("Attempting to unregister non-registered listener") - } else { - delete(node.listeners, listener) - } - node.listeners_lock.Unlock() -} - -func (node * BaseNode) SignalChannel() chan GraphSignal { - return node.signal -} - -// Create a new GraphSinal channel with a buffer of size buffer and register it to a node -func (node * BaseNode) UpdateChannel(buffer int) chan GraphSignal { - new_listener := make(chan GraphSignal, buffer) - node.RegisterChannel(new_listener) - return new_listener -} - -// Propogate a signal starting at a node -func SendUpdate(ctx * GraphContext, node GraphNode, signal GraphSignal, states NodeStateMap) { - if node == nil { - panic("Cannot start an update from no node") - } - - ctx.Log.Logf("update", "UPDATE %s <- %s: %+v", node.ID(), signal.Source(), signal) - - node.UpdateListeners(ctx, signal) - node.PropagateUpdate(ctx, signal, states) -} - diff --git a/graph_test.go b/graph_test.go index 4be2b3c..558aef0 100644 --- a/graph_test.go +++ b/graph_test.go @@ -12,7 +12,7 @@ import ( type GraphTester testing.T const listner_timeout = 50 * time.Millisecond -func (t * GraphTester) WaitForValue(ctx * GraphContext, listener chan GraphSignal, signal_type string, source GraphNode, timeout time.Duration, str string) GraphSignal { +func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, signal_type string, source Node, timeout time.Duration, str string) GraphSignal { timeout_channel := time.After(timeout) for true { select { @@ -52,22 +52,22 @@ func (t * GraphTester) CheckForNone(listener chan GraphSignal, str string) { } } -func logTestContext(t * testing.T, components []string) * GraphContext { +func logTestContext(t * testing.T, components []string) * Context { db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true)) if err != nil { t.Fatal(err) } - return NewGraphContext(db, NewConsoleLogger(components), StateLoadMap{}, NodeLoadMap{}, InfoLoadMap{}, TypeList{}, ObjTypeMap{}, FieldMap{}, FieldMap{}, FieldMap{}) + return NewContext(db, NewConsoleLogger(components), map[string]NodeLoadFunc{}, TypeList{}, ObjTypeMap{}, FieldMap{}, FieldMap{}, FieldMap{}) } -func testContext(t * testing.T) * GraphContext { +func testContext(t * testing.T) * Context { db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true)) if err != nil { t.Fatal(err) } - return NewGraphContext(db, NewConsoleLogger([]string{}), StateLoadMap{}, NodeLoadMap{}, InfoLoadMap{}, TypeList{}, ObjTypeMap{}, FieldMap{}, FieldMap{}, FieldMap{}) + return NewContext(db, NewConsoleLogger([]string{}), map[string]NodeLoadFunc{}, TypeList{}, ObjTypeMap{}, FieldMap{}, FieldMap{}, FieldMap{}) } func fatalErr(t * testing.T, err error) { diff --git a/lockable.go b/lockable.go index 94f87a3..8697459 100644 --- a/lockable.go +++ b/lockable.go @@ -5,34 +5,48 @@ import ( "encoding/json" ) -// LockableState is the interface that any node that wants to posses locks must implement -// -// ReturnLock returns the node that held the lockable pointed to by ID before this node and -// removes the mapping from it's state, or nil if the lockable was unlocked previously -// -// AllowedToTakeLock returns true if the node pointed to by ID is allowed to take a lock from this node -// -type LockableState interface { - NodeState - - ReturnLock(lockable_id NodeID) Lockable - AllowedToTakeLock(node_id NodeID, lockable_id NodeID) bool - RecordLockHolder(lockable_id NodeID, lock_holder Lockable) - - Requirements() []Lockable +// A Lockable represents a Node that can be locked and hold other Nodes locks +type Lockable interface { + // All Lockable's are nodes + Node + //// State Modification Function + // Record that lockable was returned to it's owner and is no longer held by this Node + // Returns the previous owner of the lockable + RecordUnlock(lockable Lockable) Lockable + // Record that lockable was locked by this node, and that it should be returned to last_owner + RecordLock(lockable Lockable, last_owner Lockable) + // Link a requirement to this Node AddRequirement(requirement Lockable) + // Remove a requirement linked to this Node RemoveRequirement(requirement Lockable) - Dependencies() []Lockable + // Link a dependency to this Node AddDependency(dependency Lockable) + // Remove a dependency linked to this Node RemoveDependency(dependency Lockable) + // + SetOwner(new_owner Lockable) + + //// State Reading Functions + // Called when new_owner wants to take lockable's lock but it's owned by this node + // A true return value means that the lock can be passed + AllowedToTakeLock(new_owner Lockable, lockable Lockable) bool + // Get all the linked requirements to this node + Requirements() []Lockable + // Get all the linked dependencies to this node + Dependencies() []Lockable + // Get the node's Owner Owner() Lockable - SetOwner(owner Lockable) + // Called during the lock process after locking the state and before updating the Node's state + // a non-nil return value will abort the lock attempt + CanLock(new_owner Lockable) error + // Called during the unlock process after locking the state and before updating the Node's state + // a non-nil return value will abort the unlock attempt + CanUnlock(old_owner Lockable) error } -// BaseLockableStates are a minimum collection of variables for a basic implementation of a LockHolder -// Include in any state structs that should be lockable -type BaseLockableState struct { - _type string +// SimpleLockable is a simple Lockable implementation that can be embedded into more complex structures +type SimpleLockable struct { + GraphNode name string owner Lockable requirements []Lockable @@ -40,8 +54,11 @@ type BaseLockableState struct { locks_held map[NodeID]Lockable } -type BaseLockableStateJSON struct { - Type string `json:"type"` +func (state * SimpleLockable) Type() NodeType { + return NodeType("simple_lockable") +} + +type SimpleLockableJSON struct { Name string `json:"name"` Owner *NodeID `json:"owner"` Dependencies []NodeID `json:"dependencies"` @@ -49,29 +66,31 @@ type BaseLockableStateJSON struct { LocksHeld map[NodeID]*NodeID `json:"locks_held"` } -func (state * BaseLockableState) Type() string { - return state._type +func (lockable * SimpleLockable) Serialize() ([]byte, error) { + lockable_json := NewSimpleLockableJSON(lockable) + return json.MarshalIndent(&lockable_json, "", " ") + } -func SaveBaseLockableState(state * BaseLockableState) BaseLockableStateJSON { - requirement_ids := make([]NodeID, len(state.requirements)) - for i, requirement := range(state.requirements) { +func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON { + requirement_ids := make([]NodeID, len(lockable.requirements)) + for i, requirement := range(lockable.requirements) { requirement_ids[i] = requirement.ID() } - dependency_ids := make([]NodeID, len(state.dependencies)) - for i, dependency := range(state.dependencies) { + dependency_ids := make([]NodeID, len(lockable.dependencies)) + for i, dependency := range(lockable.dependencies) { dependency_ids[i] = dependency.ID() } var owner_id *NodeID = nil - if state.owner != nil { - new_str := state.owner.ID() + if lockable.owner != nil { + new_str := lockable.owner.ID() owner_id = &new_str } locks_held := map[NodeID]*NodeID{} - for lockable_id, node := range(state.locks_held) { + for lockable_id, node := range(lockable.locks_held) { if node == nil { locks_held[lockable_id] = nil } else { @@ -79,9 +98,8 @@ func SaveBaseLockableState(state * BaseLockableState) BaseLockableStateJSON { locks_held[lockable_id] = &str } } - return BaseLockableStateJSON{ - Type: state._type, - Name: state.name, + return SimpleLockableJSON{ + Name: lockable.name, Owner: owner_id, Dependencies: dependency_ids, Requirements: requirement_ids, @@ -89,79 +107,70 @@ func SaveBaseLockableState(state * BaseLockableState) BaseLockableStateJSON { } } -func (state * BaseLockableState) MarshalJSON() ([]byte, error) { - lockable_state := SaveBaseLockableState(state) - return json.Marshal(&lockable_state) +func (lockable * SimpleLockable) Name() string { + return lockable.name } -func (state * BaseLockableState) Name() string { - return state.name -} - -// Locks cannot be passed between base lockables, so the answer to -// "who used to own this lock held by a base lockable" is always "nobody" -func (state * BaseLockableState) ReturnLock(lockable_id NodeID) Lockable { - node, exists := state.locks_held[lockable_id] +func (lockable * SimpleLockable) RecordUnlock(l Lockable) Lockable { + lockable_id := l.ID() + last_owner, exists := lockable.locks_held[lockable_id] if exists == false { panic("Attempted to take a get the original lock holder of a lockable we don't own") } - delete(state.locks_held, lockable_id) - return node -} - -// Nothing can take a lock from a base lockable either -func (state * BaseLockableState) AllowedToTakeLock(node_id NodeID, lockable_id NodeID) bool { -// _, exists := state.locks_held[lockable_id] -// if exists == false { -// panic (fmt.Sprintf("%s tried to give away lock to %s but doesn't own it: %+v", node_id, lockable_id, state)) -// } - return false + delete(lockable.locks_held, lockable_id) + return last_owner } -func (state * BaseLockableState) RecordLockHolder(lockable_id NodeID, lock_holder Lockable) { - _, exists := state.locks_held[lockable_id] +func (lockable * SimpleLockable) RecordLock(l Lockable, last_owner Lockable) { + lockable_id := l.ID() + _, exists := lockable.locks_held[lockable_id] if exists == true { panic("Attempted to lock a lockable we're already holding(lock cycle)") } - state.locks_held[lockable_id] = lock_holder + lockable.locks_held[lockable_id] = last_owner } -func (state * BaseLockableState) Owner() Lockable { - return state.owner +// Nothing can take a lock from a simple lockable +func (lockable * SimpleLockable) AllowedToTakeLock(l Lockable, new_owner Lockable) bool { + return false +} + +func (lockable * SimpleLockable) Owner() Lockable { + return lockable.owner } -func (state * BaseLockableState) SetOwner(owner Lockable) { - state.owner = owner +func (lockable * SimpleLockable) SetOwner(owner Lockable) { + lockable.owner = owner } -func (state * BaseLockableState) Requirements() []Lockable { - return state.requirements +func (lockable * SimpleLockable) Requirements() []Lockable { + return lockable.requirements } -func (state * BaseLockableState) AddRequirement(requirement Lockable) { +func (lockable * SimpleLockable) AddRequirement(requirement Lockable) { if requirement == nil { panic("Will not connect nil to the DAG") } - state.requirements = append(state.requirements, requirement) + lockable.requirements = append(lockable.requirements, requirement) } -func (state * BaseLockableState) Dependencies() []Lockable { - return state.dependencies +func (lockable * SimpleLockable) Dependencies() []Lockable { + return lockable.dependencies } -func (state * BaseLockableState) AddDependency(dependency Lockable) { +func (lockable * SimpleLockable) AddDependency(dependency Lockable) { if dependency == nil { panic("Will not connect nil to the DAG") } - state.dependencies = append(state.dependencies, dependency) + lockable.dependencies = append(lockable.dependencies, dependency) } -func (state * BaseLockableState) RemoveDependency(dependency Lockable) { +func (lockable * SimpleLockable) RemoveDependency(dependency Lockable) { idx := -1 - for i, dep := range(state.dependencies) { + for i, dep := range(lockable.dependencies) { if dep.ID() == dependency.ID() { idx = i break @@ -169,17 +178,17 @@ func (state * BaseLockableState) RemoveDependency(dependency Lockable) { } if idx == -1 { - panic(fmt.Sprintf("%s is not a dependency of %s", dependency.ID(), state.Name())) + panic(fmt.Sprintf("%s is not a dependency of %s", dependency.ID(), lockable.Name())) } - dep_len := len(state.dependencies) - state.dependencies[idx] = state.dependencies[dep_len-1] - state.dependencies = state.dependencies[0:(dep_len-1)] + dep_len := len(lockable.dependencies) + lockable.dependencies[idx] = lockable.dependencies[dep_len-1] + lockable.dependencies = lockable.dependencies[0:(dep_len-1)] } -func (state * BaseLockableState) RemoveRequirement(requirement Lockable) { +func (lockable * SimpleLockable) RemoveRequirement(requirement Lockable) { idx := -1 - for i, req := range(state.requirements) { + for i, req := range(lockable.requirements) { if req.ID() == requirement.ID() { idx = i break @@ -187,19 +196,71 @@ func (state * BaseLockableState) RemoveRequirement(requirement Lockable) { } if idx == -1 { - panic(fmt.Sprintf("%s is not a requirement of %s", requirement.ID(), state.Name())) + panic(fmt.Sprintf("%s is not a requirement of %s", requirement.ID(), lockable.Name())) } - req_len := len(state.requirements) - state.requirements[idx] = state.requirements[req_len-1] - state.requirements = state.requirements[0:(req_len-1)] + req_len := len(lockable.requirements) + lockable.requirements[idx] = lockable.requirements[req_len-1] + lockable.requirements = lockable.requirements[0:(req_len-1)] +} + +func (lockable * SimpleLockable) CanLock(new_owner Lockable) error { + return nil +} + +func (lockable * SimpleLockable) CanUnlock(new_owner Lockable) error { + return nil +} + +// lockable must already be locked for read +func (lockable * SimpleLockable) Signal(ctx *Context, signal GraphSignal, nodes NodeMap) error { + if signal.Direction() == Up { + // Child->Parent, lockable updates dependency lockables + owner_sent := false + UseMoreStates(ctx, NodeList(lockable.dependencies), nodes, func(nodes NodeMap) error { + for _, dependency := range(lockable.dependencies){ + ctx.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID()) + dependency.Signal(ctx, signal, nodes) + if lockable.owner != nil { + if dependency.ID() == lockable.owner.ID() { + owner_sent = true + } + } + } + return nil + }) + if lockable.owner != nil && owner_sent == false { + if lockable.owner.ID() != lockable.ID() { + ctx.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID()) + UseMoreStates(ctx, []Node{lockable.owner}, nodes, func(nodes NodeMap) error { + return lockable.owner.Signal(ctx, signal, nodes) + }) + } + } + } else if signal.Direction() == Down { + // Parent->Child, lockable updates lock holder + UseMoreStates(ctx, NodeList(lockable.requirements), nodes, func(nodes NodeMap) error { + for _, requirement := range(lockable.requirements) { + err := requirement.Signal(ctx, signal, nodes) + if err != nil { + return err + } + } + return nil + }) + + } else if signal.Direction() == Direct { + } else { + panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) + } + // Run the base update function, and return + return lockable.GraphNode.Signal(ctx, signal, nodes) } // Requires lockable and requirement's states to be locked for write -func UnlinkLockables(ctx * GraphContext, lockable Lockable, requirement Lockable) error { - state := lockable.State().(LockableState) - var found GraphNode = nil - for _, req := range(state.Requirements()) { +func UnlinkLockables(ctx * Context, lockable Lockable, requirement Lockable) error { + var found Node = nil + for _, req := range(lockable.Requirements()) { if requirement.ID() == req.ID() { found = req break @@ -210,15 +271,14 @@ func UnlinkLockables(ctx * GraphContext, lockable Lockable, requirement Lockable return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID(), lockable.ID()) } - req_state := found.State().(LockableState) - req_state.RemoveDependency(lockable) - state.RemoveRequirement(requirement) + requirement.RemoveDependency(lockable) + lockable.RemoveRequirement(requirement) return nil } // Requires lockable and requirements to be locked for write, nodes passed because requirement check recursively locks -func LinkLockables(ctx * GraphContext, lockable Lockable, requirements []Lockable, nodes NodeMap) error { +func LinkLockables(ctx * Context, lockable Lockable, requirements []Lockable, nodes NodeMap) error { if lockable == nil { return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements") } @@ -245,42 +305,39 @@ func LinkLockables(ctx * GraphContext, lockable Lockable, requirements []Lockabl } // Check that all the requirements can be added - lockable_state := lockable.State().(LockableState) // If the lockable is already locked, need to lock this resource as well before we can add it for _, requirement := range(requirements) { - requirement_state := requirement.State().(LockableState) for _, req := range(requirements) { if req.ID() == requirement.ID() { continue } - if checkIfRequirement(ctx, req.ID(), requirement_state, requirement.ID(), nodes) == true { + if checkIfRequirement(ctx, req, requirement, nodes) == true { return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependenyc of %s so cannot add the same dependency", req.ID(), requirement.ID()) } } - if checkIfRequirement(ctx, lockable.ID(), requirement_state, requirement.ID(), nodes) == true { + if checkIfRequirement(ctx, lockable, requirement, nodes) == true { return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID(), lockable.ID()) } - if checkIfRequirement(ctx, requirement.ID(), lockable_state, lockable.ID(), nodes) == true { + if checkIfRequirement(ctx, requirement, lockable, nodes) == true { return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID()) } - if lockable_state.Owner() == nil { + if lockable.Owner() == nil { // If the new owner isn't locked, we can add the requirement - } else if requirement_state.Owner() == nil { + } else if requirement.Owner() == nil { // if the new requirement isn't already locked but the owner is, the requirement needs to be locked first return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID(), requirement.ID()) } else { // If the new requirement is already locked and the owner is already locked, their owners need to match - if requirement_state.Owner().ID() != lockable_state.Owner().ID() { + if requirement.Owner().ID() != lockable.Owner().ID() { return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID(), lockable.ID()) } } } // Update the states of the requirements for _, requirement := range(requirements) { - requirement_state := requirement.State().(LockableState) - requirement_state.AddDependency(lockable) - lockable_state.AddRequirement(requirement) + requirement.AddDependency(lockable) + lockable.AddRequirement(requirement) ctx.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID(), lockable.ID()) } @@ -288,78 +345,15 @@ func LinkLockables(ctx * GraphContext, lockable Lockable, requirements []Lockabl return nil } -func NewBaseLockableState(name string, _type string) BaseLockableState { - state := BaseLockableState{ - locks_held: map[NodeID]Lockable{}, - _type: _type, - name: name, - owner: nil, - requirements: []Lockable{}, - dependencies: []Lockable{}, - } - - return state -} - -type Lockable interface { - GraphNode - // Called when locking the node to allow for custom lock behaviour - Lock(node GraphNode, state LockableState) - // Called to check if the node can lock - CanLock(node GraphNode, state LockableState) error - // Called when unlocking the node to allow for custom lock behaviour - Unlock(node GraphNode, state LockableState) - // Called to check if the node can unlock - CanUnlock(node GraphNode, state LockableState) error -} - -// lockable's state must already be locked for read -func (lockable * BaseLockable) PropagateUpdate(ctx * GraphContext, signal GraphSignal, states NodeStateMap) { - lockable_state := states[lockable.ID()].(LockableState) - if signal.Direction() == Up { - // Child->Parent, lockable updates dependency lockables - owner_sent := false - UseMoreStates(ctx, NodeList(lockable_state.Dependencies()), states, func(states NodeStateMap) error { - for _, dependency := range(lockable_state.Dependencies()){ - SendUpdate(ctx, dependency, signal, states) - if lockable_state.Owner() != nil { - if dependency.ID() != lockable_state.Owner().ID() { - owner_sent = true - } - } - } - return nil - }) - if lockable_state.Owner() != nil && owner_sent == false { - UseMoreStates(ctx, []GraphNode{lockable_state.Owner()}, states, func(states NodeStateMap) error { - SendUpdate(ctx, lockable_state.Owner(), signal, states) - return nil - }) - } - } else if signal.Direction() == Down { - // Parent->Child, lockable updates lock holder - UseMoreStates(ctx, NodeList(lockable_state.Requirements()), states, func(states NodeStateMap) error { - for _, requirement := range(lockable_state.Requirements()) { - SendUpdate(ctx, requirement, signal, states) - } - return nil - }) - - } else if signal.Direction() == Direct { - } else { - panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) - } -} - -func checkIfRequirement(ctx * GraphContext, r_id NodeID, cur LockableState, cur_id NodeID, nodes NodeMap) bool { +// Must be called withing update context +func checkIfRequirement(ctx * Context, r Lockable, cur Lockable, nodes NodeMap) bool { for _, c := range(cur.Requirements()) { - if c.ID() == r_id { + if c.ID() == r.ID() { return true } is_requirement := false - UpdateMoreStates(ctx, []GraphNode{c}, nodes, func(nodes NodeMap) (error) { - requirement_state := c.State().(LockableState) - is_requirement = checkIfRequirement(ctx, cur_id, requirement_state, c.ID(), nodes) + UpdateMoreStates(ctx, []Node{c}, nodes, func(nodes NodeMap) (error) { + is_requirement = checkIfRequirement(ctx, cur, c, nodes) return nil }) @@ -371,82 +365,57 @@ func checkIfRequirement(ctx * GraphContext, r_id NodeID, cur LockableState, cur_ return false } -func LockLockables(ctx * GraphContext, to_lock []Lockable, holder Lockable , nodes NodeMap) error { +func LockLockables(ctx * Context, to_lock []Lockable, new_owner Lockable, nodes NodeMap) error { if to_lock == nil { return fmt.Errorf("LOCKABLE_LOCK_ERR: no list provided") } - for _, l := range(to_lock) { + + node_list := make([]Node, len(to_lock)) + for i, l := range(to_lock) { if l == nil { return fmt.Errorf("LOCKABLE_LOCK_ERR: Can not lock nil") } + node_list[i] = l } - if holder == nil { + if new_owner == nil { return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks") } - holder_state := holder.State().(LockableState) // Called with no requirements to lock, success if len(to_lock) == 0 { return nil } - if holder_state == nil { - if len(to_lock) != 1 { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: if holder_state is nil, can only self-lock") - } else if holder.ID() != to_lock[0].ID() { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: if holder_state is nil, can only self-lock") - } - } - - - node_list := make([]GraphNode, len(to_lock)) - for i, l := range(to_lock) { - node_list[i] = l - } - err := UpdateMoreStates(ctx, node_list, nodes, func(nodes NodeMap) error { // First loop is to check that the states can be locked, and locks all requirements for _, req := range(to_lock) { - req_state := req.State().(LockableState) - ctx.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), holder.ID()) + ctx.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID()) // Check custom lock conditions - err := req.CanLock(holder, req_state) + err := req.CanLock(new_owner) if err != nil { return err } // If req is alreay locked, check that we can pass the lock - if req_state.Owner() != nil { - owner := req_state.Owner() - // Check if reqs owner will let holder take the lock from it - // The owner is either the same node, a node higher up in the dependency tree, or node outside the dependency tree(must be enforeced when linking dependencies) - // If the owner is the same node, we already have all the states we need to check lock passing - // If the owner is higher up in the dependency tree, we've either already got it's state getting to this node, or we won't try to get it's state as a dependency to lock this node, so we can grab the state and add it to a map - // If the owner is outside the dependency tree, then we won't try to grab it's lock trying to lock this node recursively - // So if the owner is the same node we don't need a new state, but if the owner is a different node then we need to grab it's state and add it to the list - if owner.ID() == holder.ID() { - return fmt.Errorf("LOCKABLE_LOCK_ERR: %s already owns %s, cannot lock again", holder.ID(), req.ID()) + if req.Owner() != nil { + owner := req.Owner() + if owner.ID() == new_owner.ID() { + return fmt.Errorf("LOCKABLE_LOCK_ERR: %s already owns %s, cannot lock again", new_owner.ID(), req.ID()) } else if owner.ID() == req.ID() { - if req_state.AllowedToTakeLock(holder.ID(), req.ID()) == false { - return fmt.Errorf("LOCKABLE_LOCK_ERR: %s is not allowed to take %s's lock from %s", holder.ID(), req.ID(), owner.ID()) + if req.AllowedToTakeLock(new_owner, req) == false { + return fmt.Errorf("LOCKABLE_LOCK_ERR: %s is not allowed to take %s's lock from %s", new_owner.ID(), req.ID(), owner.ID()) } - // RECURSE: At this point either: - // 1) req has no children and the next LockLockables will return instantly - // a) in this case, we're holding every state mutex up to the resource being locked - // and all the owners passing a lock, so we can start to change state - // 2) req has children, and we will recurse(checking that locking is allowed) until we reach a leaf and can release the locks as we change state. The call will either return nil if state has changed, on an error if no state has changed - err := LockLockables(ctx, req_state.Requirements(), req, nodes) + err := LockLockables(ctx, req.Requirements(), req, nodes) if err != nil { return err } } else { - err := UpdateMoreStates(ctx, []GraphNode{owner}, nodes, func(nodes NodeMap)(error){ - owner_state := owner.State().(LockableState) - if owner_state.AllowedToTakeLock(holder.ID(), req.ID()) == false { - return fmt.Errorf("LOCKABLE_LOCK_ERR: %s is not allowed to take %s's lock from %s", holder.ID(), req.ID(), owner.ID()) + err := UpdateMoreStates(ctx, []Node{owner}, nodes, func(nodes NodeMap)(error){ + if owner.AllowedToTakeLock(new_owner, req) == false { + return fmt.Errorf("LOCKABLE_LOCK_ERR: %s is not allowed to take %s's lock from %s", new_owner.ID(), req.ID(), owner.ID()) } - err := LockLockables(ctx, req_state.Requirements(), req, nodes) + err := LockLockables(ctx, req.Requirements(), req, nodes) return err }) if err != nil { @@ -454,7 +423,7 @@ func LockLockables(ctx * GraphContext, to_lock []Lockable, holder Lockable , nod } } } else { - err := LockLockables(ctx, req_state.Requirements(), req, nodes) + err := LockLockables(ctx, req.Requirements(), req, nodes) if err != nil { return err } @@ -463,19 +432,13 @@ func LockLockables(ctx * GraphContext, to_lock []Lockable, holder Lockable , nod // At this point state modification will be started, so no errors can be returned for _, req := range(to_lock) { - req_state := req.State().(LockableState) - old_owner := req_state.Owner() - req_state.SetOwner(holder) - if req.ID() == holder.ID() { - req_state.RecordLockHolder(req.ID(), old_owner) - } else { - holder_state.RecordLockHolder(req.ID(), old_owner) - } - req.Lock(holder, req_state) + old_owner := req.Owner() + req.SetOwner(new_owner) + new_owner.RecordLock(req, old_owner) if old_owner == nil { - ctx.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", holder.ID(), req.ID()) + ctx.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID(), req.ID()) } else { - ctx.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", holder.ID(), req.ID(), old_owner.ID()) + ctx.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID(), req.ID(), old_owner.ID()) } } return nil @@ -483,7 +446,7 @@ func LockLockables(ctx * GraphContext, to_lock []Lockable, holder Lockable , nod return err } -func UnlockLockables(ctx * GraphContext, to_unlock []Lockable, holder Lockable, nodes NodeMap) error { +func UnlockLockables(ctx * Context, to_unlock []Lockable, old_owner Lockable, nodes NodeMap) error { if to_unlock == nil { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided") } @@ -492,25 +455,16 @@ func UnlockLockables(ctx * GraphContext, to_unlock []Lockable, holder Lockable, return fmt.Errorf("LOCKABLE_UNLOCK_ERR: Can not lock nil") } } - if holder == nil { + if old_owner == nil { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks") } - holder_state := holder.State().(LockableState) // Called with no requirements to lock, success if len(to_unlock) == 0 { return nil } - if holder_state == nil { - if len(to_unlock) != 1 { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: if holder_state is nil, can only self-lock") - } else if holder.ID() != to_unlock[0].ID() { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: if holder_state is nil, can only self-lock") - } - } - - node_list := make([]GraphNode, len(to_unlock)) + node_list := make([]Node, len(to_unlock)) for i, l := range(to_unlock) { node_list[i] = l } @@ -518,25 +472,24 @@ func UnlockLockables(ctx * GraphContext, to_unlock []Lockable, holder Lockable, err := UpdateMoreStates(ctx, node_list, nodes, func(nodes NodeMap) error { // First loop is to check that the states can be locked, and locks all requirements for _, req := range(to_unlock) { - req_state := req.State().(LockableState) - ctx.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), holder.ID()) + ctx.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID()) // Check if the owner is correct - if req_state.Owner() != nil { - if req_state.Owner().ID() != holder.ID() { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID(), holder.ID()) + if req.Owner() != nil { + if req.Owner().ID() != old_owner.ID() { + return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID(), old_owner.ID()) } } else { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID()) } // Check custom unlock conditions - err := req.CanUnlock(holder, req_state) + err := req.CanUnlock(old_owner) if err != nil { return err } - err = UnlockLockables(ctx, req_state.Requirements(), req, nodes) + err = UnlockLockables(ctx, req.Requirements(), req, nodes) if err != nil { return err } @@ -544,19 +497,12 @@ func UnlockLockables(ctx * GraphContext, to_unlock []Lockable, holder Lockable, // At this point state modification will be started, so no errors can be returned for _, req := range(to_unlock) { - req_state := req.State().(LockableState) - var new_owner Lockable = nil - if holder_state == nil { - new_owner = req_state.ReturnLock(req.ID()) - } else { - new_owner = holder_state.ReturnLock(req.ID()) - } - req_state.SetOwner(new_owner) - req.Unlock(holder, req_state) + new_owner := old_owner.RecordUnlock(req) + req.SetOwner(new_owner) if new_owner == nil { - ctx.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", holder.ID(), req.ID()) + ctx.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID(), req.ID()) } else { - ctx.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", holder.ID(), req.ID(), new_owner.ID()) + ctx.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID(), req.ID(), new_owner.ID()) } } return nil @@ -564,160 +510,98 @@ func UnlockLockables(ctx * GraphContext, to_unlock []Lockable, holder Lockable, return err } -// BaseLockables represent simple lockables in the DAG that can be used to create a hierarchy of locks that store names -type BaseLockable struct { - BaseNode -} - -//BaseLockables don't check anything special when locking/unlocking -func (lockable * BaseLockable) CanLock(node GraphNode, state LockableState) error { - return nil -} -func (lockable * BaseLockable) CanUnlock(node GraphNode, state LockableState) error { - return nil -} - -//BaseLockables don't check anything special when locking/unlocking -func (lockable * BaseLockable) Lock(node GraphNode, state LockableState) { - return -} - -func (lockable * BaseLockable) Unlock(node GraphNode, state LockableState) { - return -} - -func NewBaseLockable(ctx * GraphContext, state LockableState) (BaseLockable, error) { - base_node, err := NewNode(ctx, state) +func LoadSimpleLockable(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { + var j SimpleLockableJSON + err := json.Unmarshal(data, &j) if err != nil { - return BaseLockable{}, err - } - - lockable := BaseLockable{ - BaseNode: base_node, + return nil, err } - return lockable, nil -} + lockable := NewSimpleLockable(id, j.Name) + nodes[id] = &lockable -func RestoreBaseLockable(ctx * GraphContext, id NodeID) BaseLockable { - base_node := RestoreNode(ctx, id) - return BaseLockable{ - BaseNode: base_node, + err = RestoreSimpleLockable(ctx, &lockable, j, nodes) + if err != nil { + return nil, err } -} -func LoadSimpleLockable(ctx * GraphContext, id NodeID) (GraphNode, error) { - // call LoadNodeRecurse on any connected nodes to ensure they're loaded and return the id - lockable := RestoreBaseLockable(ctx, id) return &lockable, nil } -func RestoreBaseLockableState(ctx * GraphContext, j BaseLockableStateJSON, loaded_nodes NodeMap) (*BaseLockableState, error) { - state := BaseLockableState{ - _type: j.Type, - name: j.Name, + +func NewSimpleLockable(id NodeID, name string) SimpleLockable { + return SimpleLockable{ + GraphNode: NewGraphNode(id), + name: name, owner: nil, - dependencies: make([]Lockable, len(j.Dependencies)), - requirements: make([]Lockable, len(j.Requirements)), + requirements: []Lockable{}, + dependencies: []Lockable{}, locks_held: map[NodeID]Lockable{}, } +} +func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSON, nodes NodeMap) error { if j.Owner != nil { - o, err := LoadNodeRecurse(ctx, *j.Owner, loaded_nodes) + o, err := LoadNodeRecurse(ctx, *j.Owner, nodes) if err != nil { - return nil, err + return err } o_l, ok := o.(Lockable) if ok == false { - return nil, err + return fmt.Errorf("%s is not a Lockable", o.ID()) } - state.owner = o_l + lockable.SetOwner(o_l) } - for i, dep := range(j.Dependencies) { - dep_node, err := LoadNodeRecurse(ctx, dep, loaded_nodes) + for _, dep := range(j.Dependencies) { + dep_node, err := LoadNodeRecurse(ctx, dep, nodes) if err != nil { - return nil, err + return err } dep_l, ok := dep_node.(Lockable) if ok == false { - return nil, fmt.Errorf("%+v is not a Lockable as expected", dep_node) + return fmt.Errorf("%+v is not a Lockable as expected", dep_node) } - state.dependencies[i] = dep_l + lockable.AddDependency(dep_l) } - for i, req := range(j.Requirements) { - req_node, err := LoadNodeRecurse(ctx, req, loaded_nodes) + for _, req := range(j.Requirements) { + req_node, err := LoadNodeRecurse(ctx, req, nodes) if err != nil { - return nil, err + return err } req_l, ok := req_node.(Lockable) if ok == false { - return nil, fmt.Errorf("%+v is not a Lockable as expected", req_node) + return fmt.Errorf("%+v is not a Lockable as expected", req_node) } - state.requirements[i] = req_l + lockable.AddRequirement(req_l) } for l_id, h_id := range(j.LocksHeld) { - _, err := LoadNodeRecurse(ctx, l_id, loaded_nodes) + l, err := LoadNodeRecurse(ctx, l_id, nodes) if err != nil { - return nil, err + return err + } + l_l, ok := l.(Lockable) + if ok == false { + return fmt.Errorf("%s is not a Lockable", l.ID()) } + var h_l Lockable = nil if h_id != nil { - h_node, err := LoadNodeRecurse(ctx, *h_id, loaded_nodes) + h_node, err := LoadNodeRecurse(ctx, *h_id, nodes) if err != nil { - return nil, err + return err } h, ok := h_node.(Lockable) if ok == false { - return nil, err + return err } h_l = h } - state.locks_held[l_id] = h_l + lockable.RecordLock(l_l, h_l) } - return &state, nil -} - -func LoadSimpleLockableState(ctx * GraphContext, data []byte, loaded_nodes NodeMap)(NodeState, error){ - var j BaseLockableStateJSON - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } - - state, err := RestoreBaseLockableState(ctx, j, loaded_nodes) - if err != nil { - return nil, err - } - - return state, nil -} - -func NewSimpleLockable(ctx * GraphContext, name string, requirements []Lockable) (*BaseLockable, error) { - state := NewBaseLockableState(name, "simple_lockable") - lockable, err := NewBaseLockable(ctx, &state) - if err != nil { - return nil, err - } - - lockable_ptr := &lockable - - if len(requirements) > 0 { - req_nodes := make([]GraphNode, len(requirements)) - for i, req := range(requirements) { - req_nodes[i] = req - } - err = UpdateStates(ctx, req_nodes, func(nodes NodeMap) error { - return LinkLockables(ctx, lockable_ptr, requirements, nodes) - }) - if err != nil { - return nil, err - } - } - - return lockable_ptr, nil + return nil } diff --git a/lockable_test.go b/lockable_test.go index b1e15d8..27ffef5 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -4,38 +4,38 @@ import ( "testing" "fmt" "time" - "encoding/json" ) func TestNewSimpleLockable(t * testing.T) { ctx := testContext(t) - l1, err := NewSimpleLockable(ctx, "Test lockable 1", []Lockable{}) - fatalErr(t, err) + l1_r := NewSimpleLockable(RandID(), "Test lockable 1") + l1 := &l1_r + l2_r := NewSimpleLockable(RandID(), "Test lockable 2") + l2 := &l2_r - l2, err := NewSimpleLockable(ctx, "Test lockable 2", []Lockable{l1}) + err := UpdateStates(ctx, []Node{l1, l2}, func(nodes NodeMap) error { + return LinkLockables(ctx, l2, []Lockable{l1}, nodes) + }) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{l1, l2}, func(states NodeStateMap) error { - l1_state := states[l1.ID()].(LockableState) - l2_state := states[l2.ID()].(LockableState) - - l1_deps := len(l1_state.Dependencies()) + err = UseStates(ctx, []Node{l1, l2}, func(nodes NodeMap) error { + l1_deps := len(l1.Dependencies()) if l1_deps != 1 { return fmt.Errorf("l1 has wront amount of dependencies %d/1", l1_deps) } - l1_dep1 := l1_state.Dependencies()[0] + l1_dep1 := l1.Dependencies()[0] if l1_dep1.ID() != l2.ID() { return fmt.Errorf("Wrong dependency for l1, %s instead of %s", l1_dep1.ID(), l2.ID()) } - l2_reqs := len(l2_state.Requirements()) + l2_reqs := len(l2.Requirements()) if l2_reqs != 1 { return fmt.Errorf("l2 has wrong amount of requirements %d/1", l2_reqs) } - l2_req1 := l2_state.Requirements()[0] + l2_req1 := l2.Requirements()[0] if l2_req1.ID() != l1.ID() { return fmt.Errorf("Wrong requirement for l2, %s instead of %s", l2_req1.ID(), l1.ID()) } @@ -47,10 +47,16 @@ func TestNewSimpleLockable(t * testing.T) { func TestRepeatedChildLockable(t * testing.T) { ctx := testContext(t) - r1, err := NewSimpleLockable(ctx, "Test lockable 1", []Lockable{}) - fatalErr(t, err) + l1_r := NewSimpleLockable(RandID(), "Test lockable 1") + l1 := &l1_r + + l2_r := NewSimpleLockable(RandID(), "Test lockable 2") + l2 := &l2_r + + err := UpdateStates(ctx, []Node{l1, l2}, func(nodes NodeMap) error { + return LinkLockables(ctx, l2, []Lockable{l1, l1}, nodes) + }) - _, err = NewSimpleLockable(ctx, "Test lockable 2", []Lockable{r1, r1}) if err == nil { t.Fatal("Added the same lockable as a requirement twice to the same lockable") } @@ -59,32 +65,34 @@ func TestRepeatedChildLockable(t * testing.T) { func TestLockableSelfLock(t * testing.T) { ctx := testContext(t) - r1, err := NewSimpleLockable(ctx, "Test lockable 1", []Lockable{}) - fatalErr(t, err) + l1_r := NewSimpleLockable(RandID(), "Test lockable 1") + l1 := &l1_r - err = UpdateStates(ctx, []GraphNode{r1}, func(nodes NodeMap) error { - return LockLockables(ctx, []Lockable{r1}, r1, nodes) + err := UpdateStates(ctx, []Node{l1}, func(nodes NodeMap) error { + return LockLockables(ctx, []Lockable{l1}, l1, nodes) }) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1}, func(states NodeStateMap) (error) { - owner_id := states[r1.ID()].(LockableState).Owner().ID() - if owner_id != r1.ID() { - return fmt.Errorf("r1 is owned by %s instead of self", owner_id) + err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { + owner_id := NodeID("") + if l1.owner != nil { + owner_id = l1.owner.ID() + } + if owner_id != l1.ID() { + return fmt.Errorf("l1 is owned by %s instead of self", owner_id) } return nil }) fatalErr(t, err) - err = UpdateStates(ctx, []GraphNode{r1}, func(nodes NodeMap) error { - return UnlockLockables(ctx, []Lockable{r1}, r1, nodes) + err = UpdateStates(ctx, []Node{l1}, func(nodes NodeMap) error { + return UnlockLockables(ctx, []Lockable{l1}, l1, nodes) }) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1}, func(states NodeStateMap) (error) { - owner := states[r1.ID()].(LockableState).Owner() - if owner != nil { - return fmt.Errorf("r1 is not unowned after unlock: %s", owner.ID()) + err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { + if l1.owner != nil { + return fmt.Errorf("l1 is not unowned after unlock: %s", l1.owner.ID()) } return nil }) @@ -95,53 +103,61 @@ func TestLockableSelfLock(t * testing.T) { func TestLockableSelfLockTiered(t * testing.T) { ctx := testContext(t) - r1, err := NewSimpleLockable(ctx, "Test lockable 1", []Lockable{}) - fatalErr(t, err) - - r2, err := NewSimpleLockable(ctx, "Test lockable 2", []Lockable{}) - fatalErr(t, err) - - r3, err := NewSimpleLockable(ctx, "Test lockable 3", []Lockable{r1, r2}) - fatalErr(t, err) - - err = UpdateStates(ctx, []GraphNode{r3}, func(nodes NodeMap) error { - return LockLockables(ctx, []Lockable{r3}, r3, nodes) + l1_r := NewSimpleLockable(RandID(), "Test lockable 1") + l1 := &l1_r + l2_r := NewSimpleLockable(RandID(), "Test lockable 2") + l2 := &l2_r + l3_r := NewSimpleLockable(RandID(), "Test lockable 3") + l3 := &l3_r + + err := UpdateStates(ctx, []Node{l3}, func(nodes NodeMap) error { + err := LinkLockables(ctx, l3, []Lockable{l1, l2}, nodes) + if err != nil { + return err + } + return LockLockables(ctx, []Lockable{l3}, l3, nodes) }) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1, r2, r3}, func(states NodeStateMap) (error) { - owner_1_id := states[r1.ID()].(LockableState).Owner().ID() - if owner_1_id != r3.ID() { - return fmt.Errorf("r1 is owned by %s instead of r3", owner_1_id) + err = UseStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) (error) { + owner_1 := NodeID("") + if l1.owner != nil { + owner_1 = l1.owner.ID() + } + if owner_1 != l3.ID() { + return fmt.Errorf("l1 is owned by %s instead of l3", owner_1) } - owner_2_id := states[r2.ID()].(LockableState).Owner().ID() - if owner_2_id != r3.ID() { - return fmt.Errorf("r2 is owned by %s instead of r3", owner_2_id) + owner_2 := NodeID("") + if l2.owner != nil { + owner_2 = l2.owner.ID() + } + if owner_2 != l3.ID() { + return fmt.Errorf("l2 is owned by %s instead of l3", owner_2) } return nil }) fatalErr(t, err) - err = UpdateStates(ctx, []GraphNode{r3}, func(nodes NodeMap) error { - return UnlockLockables(ctx, []Lockable{r3}, r3, nodes) + err = UpdateStates(ctx, []Node{l3}, func(nodes NodeMap) error { + return UnlockLockables(ctx, []Lockable{l3}, l3, nodes) }) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1, r2, r3}, func(states NodeStateMap) (error) { - owner_1 := states[r1.ID()].(LockableState).Owner() + err = UseStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) (error) { + owner_1 := l1.owner if owner_1 != nil { - return fmt.Errorf("r1 is not unowned after unlocking: %s", owner_1.ID()) + return fmt.Errorf("l1 is not unowned after unlocking: %s", owner_1.ID()) } - owner_2 := states[r2.ID()].(LockableState).Owner() + owner_2 := l2.owner if owner_2 != nil { - return fmt.Errorf("r2 is not unowned after unlocking: %s", owner_2.ID()) + return fmt.Errorf("l2 is not unowned after unlocking: %s", owner_2.ID()) } - owner_3 := states[r3.ID()].(LockableState).Owner() + owner_3 := l3.owner if owner_3 != nil { - return fmt.Errorf("r3 is not unowned after unlocking: %s", owner_3.ID()) + return fmt.Errorf("l3 is not unowned after unlocking: %s", owner_3.ID()) } return nil }) @@ -152,40 +168,42 @@ func TestLockableSelfLockTiered(t * testing.T) { func TestLockableLockOther(t * testing.T) { ctx := testContext(t) - r1, err := NewSimpleLockable(ctx, "Test lockable 1", []Lockable{}) - fatalErr(t, err) - - r2, err := NewSimpleLockable(ctx, "Test lockable 2", []Lockable{}) - fatalErr(t, err) + l1_r := NewSimpleLockable(RandID(), "Test lockable 1") + l1 := &l1_r + l2_r := NewSimpleLockable(RandID(), "Test lockable 2") + l2 := &l2_r - err = UpdateStates(ctx, []GraphNode{r1, r2}, func(nodes NodeMap) (error) { - err := LockLockables(ctx, []Lockable{r1}, r2, nodes) + err := UpdateStates(ctx, []Node{l1, l2}, func(nodes NodeMap) (error) { + err := LockLockables(ctx, []Lockable{l1}, l2, nodes) fatalErr(t, err) return nil }) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1}, func(states NodeStateMap) (error) { - owner_id := states[r1.ID()].(LockableState).Owner().ID() - if owner_id != r2.ID() { - return fmt.Errorf("r1 is owned by %s instead of r2", owner_id) + err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { + owner_id := NodeID("") + if l1.owner != nil { + owner_id = l1.owner.ID() + } + if owner_id != l2.ID() { + return fmt.Errorf("l1 is owned by %s instead of l2", owner_id) } return nil }) fatalErr(t, err) - err = UpdateStates(ctx, []GraphNode{r2}, func(nodes NodeMap) (error) { - err := UnlockLockables(ctx, []Lockable{r1}, r2, nodes) + err = UpdateStates(ctx, []Node{l2}, func(nodes NodeMap) (error) { + err := UnlockLockables(ctx, []Lockable{l1}, l2, nodes) fatalErr(t, err) return nil }) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1}, func(states NodeStateMap) (error) { - owner := states[r1.ID()].(LockableState).Owner() + err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { + owner := l1.owner if owner != nil { - return fmt.Errorf("r1 is owned by %s instead of r2", owner.ID()) + return fmt.Errorf("l1 is owned by %s instead of l2", owner.ID()) } return nil @@ -197,46 +215,48 @@ func TestLockableLockOther(t * testing.T) { func TestLockableLockSimpleConflict(t * testing.T) { ctx := testContext(t) - r1, err := NewSimpleLockable(ctx, "Test lockable 1", []Lockable{}) - fatalErr(t, err) + l1_r := NewSimpleLockable(RandID(), "Test lockable 1") + l1 := &l1_r + l2_r := NewSimpleLockable(RandID(), "Test lockable 2") + l2 := &l2_r - r2, err := NewSimpleLockable(ctx, "Test lockable 2", []Lockable{}) - fatalErr(t, err) - - err = UpdateStates(ctx, []GraphNode{r1}, func(nodes NodeMap) error { - return LockLockables(ctx, []Lockable{r1}, r1, nodes) + err := UpdateStates(ctx, []Node{l1}, func(nodes NodeMap) error { + return LockLockables(ctx, []Lockable{l1}, l1, nodes) }) fatalErr(t, err) - err = UpdateStates(ctx, []GraphNode{r2}, func(nodes NodeMap) (error) { - err := LockLockables(ctx, []Lockable{r1}, r2, nodes) + err = UpdateStates(ctx, []Node{l2}, func(nodes NodeMap) (error) { + err := LockLockables(ctx, []Lockable{l1}, l2, nodes) if err == nil { - t.Fatal("r2 took r1's lock from itself") + t.Fatal("l2 took l1's lock from itself") } return nil }) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1}, func(states NodeStateMap) (error) { - owner_id := states[r1.ID()].(LockableState).Owner().ID() - if owner_id != r1.ID() { - return fmt.Errorf("r1 is owned by %s instead of r1", owner_id) + err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { + owner_id := NodeID("") + if l1.owner != nil { + owner_id = l1.owner.ID() + } + if owner_id != l1.ID() { + return fmt.Errorf("l1 is owned by %s instead of l1", owner_id) } return nil }) fatalErr(t, err) - err = UpdateStates(ctx, []GraphNode{r1}, func(nodes NodeMap) error { - return UnlockLockables(ctx, []Lockable{r1}, r1, nodes) + err = UpdateStates(ctx, []Node{l1}, func(nodes NodeMap) error { + return UnlockLockables(ctx, []Lockable{l1}, l1, nodes) }) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1}, func(states NodeStateMap) (error) { - owner := states[r1.ID()].(LockableState).Owner() + err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { + owner := l1.owner if owner != nil { - return fmt.Errorf("r1 is owned by %s instead of r1", owner.ID()) + return fmt.Errorf("l1 is owned by %s instead of l1", owner.ID()) } return nil @@ -248,40 +268,47 @@ func TestLockableLockSimpleConflict(t * testing.T) { func TestLockableLockTieredConflict(t * testing.T) { ctx := testContext(t) - r1, err := NewSimpleLockable(ctx, "Test lockable 1", []Lockable{}) - fatalErr(t, err) - - r2, err := NewSimpleLockable(ctx, "Test lockable 2", []Lockable{r1}) - fatalErr(t, err) - - r3, err := NewSimpleLockable(ctx, "Test lockable 3", []Lockable{r1}) + l1_r := NewSimpleLockable(RandID(), "Test lockable 1") + l1 := &l1_r + l2_r := NewSimpleLockable(RandID(), "Test lockable 2") + l2 := &l2_r + l3_r := NewSimpleLockable(RandID(), "Test lockable 3") + l3 := &l3_r + + err := UpdateStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) error { + err := LinkLockables(ctx, l2, []Lockable{l1}, nodes) + if err != nil { + return err + } + return LinkLockables(ctx, l3, []Lockable{l1}, nodes) + }) fatalErr(t, err) - err = UpdateStates(ctx, []GraphNode{r2}, func(nodes NodeMap) error { - return LockLockables(ctx, []Lockable{r2}, r2, nodes) + err = UpdateStates(ctx, []Node{l2}, func(nodes NodeMap) error { + return LockLockables(ctx, []Lockable{l2}, l2, nodes) }) fatalErr(t, err) - err = UpdateStates(ctx, []GraphNode{r3}, func(nodes NodeMap) error { - return LockLockables(ctx, []Lockable{r3}, r3, nodes) + err = UpdateStates(ctx, []Node{l3}, func(nodes NodeMap) error { + return LockLockables(ctx, []Lockable{l3}, l3, nodes) }) if err == nil { - t.Fatal("Locked r3 which depends on r1 while r2 which depends on r1 is already locked") + t.Fatal("Locked l3 which depends on l1 while l2 which depends on l1 is already locked") } } func TestLockableSimpleUpdate(t * testing.T) { ctx := logTestContext(t, []string{"test", "update", "lockable"}) - l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) - fatalErr(t, err) + l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") + l1 := &l1_r + - update_channel := l1.UpdateChannel(1) + update_channel := UpdateChannel(l1, 1, "test") go func() { - UseStates(ctx, []GraphNode{l1}, func(states NodeStateMap) error { - SendUpdate(ctx, l1, NewDirectSignal(l1, "test_update"), states) - return nil + UseStates(ctx, []Node{l1}, func(nodes NodeMap) error { + return l1.Signal(ctx, NewDirectSignal(l1, "test_update"), nodes) }) }() @@ -291,21 +318,26 @@ func TestLockableSimpleUpdate(t * testing.T) { func TestLockableDownUpdate(t * testing.T) { ctx := logTestContext(t, []string{"test", "update", "lockable"}) - l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) - fatalErr(t, err) - - l2, err := NewSimpleLockable(ctx, "Test Lockable 2", []Lockable{l1}) - fatalErr(t, err) - - _, err = NewSimpleLockable(ctx, "Test Lockable 3", []Lockable{l2}) + l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") + l1 := &l1_r + l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") + l2 := &l2_r + l3_r := NewSimpleLockable(RandID(), "Test Lockable 3") + l3 := &l3_r + err := UpdateStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) error { + err := LinkLockables(ctx, l2, []Lockable{l1}, nodes) + if err != nil { + return err + } + return LinkLockables(ctx, l3, []Lockable{l2}, nodes) + }) fatalErr(t, err) - update_channel := l1.UpdateChannel(1) + update_channel := UpdateChannel(l1, 1, "test") go func() { - UseStates(ctx, []GraphNode{l2}, func(states NodeStateMap) error { - SendUpdate(ctx, l2, NewDownSignal(l2, "test_update"), states) - return nil + UseStates(ctx, []Node{l2}, func(nodes NodeMap) error { + return l2.Signal(ctx, NewDownSignal(l2, "test_update"), nodes) }) }() @@ -315,21 +347,26 @@ func TestLockableDownUpdate(t * testing.T) { func TestLockableUpUpdate(t * testing.T) { ctx := logTestContext(t, []string{"test", "update", "lockable"}) - l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) - fatalErr(t, err) - - l2, err := NewSimpleLockable(ctx, "Test Lockable 2", []Lockable{l1}) - fatalErr(t, err) - - l3, err := NewSimpleLockable(ctx, "Test Lockable 3", []Lockable{l2}) + l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") + l1 := &l1_r + l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") + l2 := &l2_r + l3_r := NewSimpleLockable(RandID(), "Test Lockable 3") + l3 := &l3_r + err := UpdateStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) error { + err := LinkLockables(ctx, l2, []Lockable{l1}, nodes) + if err != nil { + return err + } + return LinkLockables(ctx, l3, []Lockable{l2}, nodes) + }) fatalErr(t, err) - update_channel := l3.UpdateChannel(1) + update_channel := UpdateChannel(l3, 1, "test") go func() { - UseStates(ctx, []GraphNode{l2}, func(states NodeStateMap) error { - SendUpdate(ctx, l2, NewSignal(l2, "test_update"), states) - return nil + UseStates(ctx, []Node{l2}, func(nodes NodeMap) error { + return l2.Signal(ctx, NewSignal(l2, "test_update"), nodes) }) }() @@ -337,21 +374,29 @@ func TestLockableUpUpdate(t * testing.T) { } func TestOwnerNotUpdatedTwice(t * testing.T) { - ctx := logTestContext(t, []string{"test", "update", "lockable"}) + ctx := logTestContext(t, []string{"test", "signal", "lockable", "listeners"}) - l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) - fatalErr(t, err) + l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") + l1 := &l1_r + l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") + l2 := &l2_r - l2, err := NewSimpleLockable(ctx, "Test Lockable 2", []Lockable{l1}) + err := UpdateStates(ctx, []Node{l1, l2}, func(nodes NodeMap) error { + err := LinkLockables(ctx, l2, []Lockable{l1}, nodes) + if err != nil { + return err + } + return LockLockables(ctx, []Lockable{l2}, l2, nodes) + }) fatalErr(t, err) - update_channel := l2.UpdateChannel(1) + update_channel := UpdateChannel(l2, 1, "test") go func() { - UseStates(ctx, []GraphNode{l1}, func(states NodeStateMap) error { - SendUpdate(ctx, l1, NewSignal(l1, "test_update"), states) - return nil + err := UseStates(ctx, []Node{l1}, func(nodes NodeMap) error { + return l1.Signal(ctx, NewSignal(l1, "test_update"), nodes) }) + fatalErr(t, err) }() (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Dicn't received test_update on l2 from l1") @@ -360,11 +405,22 @@ func TestOwnerNotUpdatedTwice(t * testing.T) { func TestLockableDependencyOverlap(t * testing.T) { ctx := testContext(t) - l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) - fatalErr(t, err) - l2, err := NewSimpleLockable(ctx, "Test Lockable 2", []Lockable{l1}) - fatalErr(t, err) - _, err = NewSimpleLockable(ctx, "Test Lockable 3", []Lockable{l1, l2}) + + l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") + l1 := &l1_r + l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") + l2 := &l2_r + l3_r := NewSimpleLockable(RandID(), "Test Lockable 3") + l3 := &l3_r + + err := UpdateStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) error { + err := LinkLockables(ctx, l2, []Lockable{l1}, nodes) + if err != nil { + return err + } + + return LinkLockables(ctx, l3, []Lockable{l2, l1}, nodes) + }) if err == nil { t.Fatal("Should have thrown an error because of dependency overlap") } @@ -372,45 +428,66 @@ func TestLockableDependencyOverlap(t * testing.T) { func TestLockableDBLoad(t * testing.T){ ctx := logTestContext(t, []string{}) - l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) - fatalErr(t, err) - l2, err := NewSimpleLockable(ctx, "Test Lockable 2", []Lockable{}) - fatalErr(t, err) - l3, err := NewSimpleLockable(ctx, "Test Lockable 3", []Lockable{l1, l2}) - fatalErr(t, err) - l4, err := NewSimpleLockable(ctx, "Test Lockable 4", []Lockable{l3}) - fatalErr(t, err) - _, err = NewSimpleLockable(ctx, "Test Lockable 5", []Lockable{l4}) - fatalErr(t, err) - l6, err := NewSimpleLockable(ctx, "Test Lockable 6", []Lockable{}) - err = UpdateStates(ctx, []GraphNode{l6, l3}, func(nodes NodeMap) error { - err := LockLockables(ctx, []Lockable{l3}, l6, nodes) - return err + l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") + l1 := &l1_r + l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") + l2 := &l2_r + l3_r := NewSimpleLockable(RandID(), "Test Lockable 3") + l3 := &l3_r + l4_r := NewSimpleLockable(RandID(), "Test Lockable 4") + l4 := &l4_r + l5_r := NewSimpleLockable(RandID(), "Test Lockable 5") + l5 := &l5_r + l6_r := NewSimpleLockable(RandID(), "Test Lockable 6") + l6 := &l6_r + err := UpdateStates(ctx, []Node{l1, l2, l3, l4, l5, l6}, func(nodes NodeMap) error { + err := LinkLockables(ctx, l3, []Lockable{l1, l2}, nodes) + if err != nil { + return err + } + + err = LinkLockables(ctx, l4, []Lockable{l3}, nodes) + if err != nil { + return err + } + + err = LinkLockables(ctx, l5, []Lockable{l4}, nodes) + if err != nil { + return err + } + return LockLockables(ctx, []Lockable{l3}, l6, nodes) }) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{l3}, func(states NodeStateMap) error { - ser, err := json.MarshalIndent(states[l3.ID()], "", " ") + + err = UseStates(ctx, []Node{l3}, func(nodes NodeMap) error { + ser, err := l3.Serialize() fmt.Printf("\n%s\n\n", ser) return err }) + fatalErr(t, err) l3_loaded, err := LoadNode(ctx, l3.ID()) fatalErr(t, err) // TODO: add more equivalence checks - err = UseStates(ctx, []GraphNode{l3_loaded}, func(states NodeStateMap) error { - ser, err := json.MarshalIndent(states[l3_loaded.ID()], "", " ") + err = UseStates(ctx, []Node{l3_loaded}, func(nodes NodeMap) error { + ser, err := l3_loaded.Serialize() fmt.Printf("\n%s\n\n", ser) return err }) + fatalErr(t, err) } func TestLockableUnlink(t * testing.T){ ctx := logTestContext(t, []string{"lockable"}) - l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) - fatalErr(t, err) + l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") + l1 := &l1_r + l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") + l2 := &l2_r - l2, err := NewSimpleLockable(ctx, "Test Lockable 2", []Lockable{l1}) + err := UpdateStates(ctx, []Node{l1, l2}, func(nodes NodeMap) error { + return LinkLockables(ctx, l2, []Lockable{l1}, nodes) + }) fatalErr(t, err) err = UnlinkLockables(ctx, l2, l1) diff --git a/log.go b/log.go new file mode 100644 index 0000000..034a760 --- /dev/null +++ b/log.go @@ -0,0 +1,97 @@ +package graphvent + +import ( + "fmt" + "github.com/rs/zerolog" + "os" + "sync" + "encoding/json" +) + +// A Logger is passed around to record events happening to components enabled by SetComponents +type Logger interface { + SetComponents(components []string) error + // Log a formatted string + Logf(component string, format string, items ... interface{}) + // Log a map of attributes and a format string + Logm(component string, fields map[string]interface{}, format string, items ... interface{}) + // Log a structure to a file by marshalling and unmarshalling the json + Logj(component string, s interface{}, format string, items ... interface{}) +} + +func NewConsoleLogger(components []string) *ConsoleLogger { + logger := &ConsoleLogger{ + loggers: map[string]zerolog.Logger{}, + components: []string{}, + } + + logger.SetComponents(components) + + return logger +} + +// A ConsoleLogger logs to stdout +type ConsoleLogger struct { + loggers map[string]zerolog.Logger + components_lock sync.Mutex + components []string +} + +func (logger * ConsoleLogger) SetComponents(components []string) error { + logger.components_lock.Lock() + defer logger.components_lock.Unlock() + + component_enabled := func (component string) bool { + for _, c := range(components) { + if c == component { + return true + } + } + return false + } + + for c, _ := range(logger.loggers) { + if component_enabled(c) == false { + delete(logger.loggers, c) + } + } + + for _, c := range(components) { + _, exists := logger.loggers[c] + if component_enabled(c) == true && exists == false { + logger.loggers[c] = zerolog.New(os.Stdout).With().Timestamp().Str("component", c).Logger() + } + } + return nil +} + +func (logger * ConsoleLogger) Logm(component string, fields map[string]interface{}, format string, items ... interface{}) { + l, exists := logger.loggers[component] + if exists == true { + log := l.Log() + for key, value := range(fields) { + log = log.Str(key, fmt.Sprintf("%+v", value)) + } + log.Msg(fmt.Sprintf(format, items...)) + } +} + +func (logger * ConsoleLogger) Logf(component string, format string, items ... interface{}) { + l, exists := logger.loggers[component] + if exists == true { + l.Log().Msg(fmt.Sprintf(format, items...)) + } +} + +func (logger * ConsoleLogger) Logj(component string, s interface{}, format string, items ... interface{}) { + m := map[string]interface{}{} + ser, err := json.Marshal(s) + if err != nil { + panic("LOG_MARSHAL_ERR") + } + err = json.Unmarshal(ser, &m) + if err != nil { + panic("LOG_UNMARSHAL_ERR") + } + logger.Logm(component, m, format, items...) +} diff --git a/node.go b/node.go new file mode 100644 index 0000000..20a28ec --- /dev/null +++ b/node.go @@ -0,0 +1,388 @@ +package graphvent + +import ( + "sync" + "github.com/google/uuid" + badger "github.com/dgraph-io/badger/v3" + "fmt" + "encoding/binary" + "crypto/sha256" +) + +// IDs are how nodes are uniquely identified, and can be serialized for the database +type NodeID string + +func (id NodeID) Serialize() []byte { + return []byte(id) +} + +// Types are how nodes are associated with structs at runtime(and from the DB) +type NodeType string +func (node_type NodeType) Hash() uint64 { + hash := sha256.New() + hash.Write([]byte(node_type)) + bytes := hash.Sum(nil) + + return binary.BigEndian.Uint64(bytes[(len(bytes)-9):(len(bytes)-1)]) +} + +// Generate a random id +func RandID() NodeID { + uuid_str := uuid.New().String() + return NodeID(uuid_str) +} + +// A Node represents data that can be read by multiple goroutines and written to by one, with a unique ID attached, and a method to process updates(including propagating them to connected nodes) +// RegisterChannel and UnregisterChannel are used to connect arbitrary listeners to the node +type Node interface { + sync.Locker + RLock() + RUnlock() + Serialize() ([]byte, error) + ID() NodeID + Type() NodeType + Signal(ctx *Context, signal GraphSignal, nodes NodeMap) error + RegisterChannel(id NodeID, listener chan GraphSignal) + UnregisterChannel(id NodeID) +} + +// A GraphNode is an implementation of a Node that can be embedded into more complex structures +type GraphNode struct { + sync.RWMutex + listeners_lock sync.Mutex + + id NodeID + listeners map[NodeID]chan GraphSignal +} + +// GraphNode doesn't serialize any additional information by default +func (node * GraphNode) Serialize() ([]byte, error) { + return nil, nil +} + +func LoadGraphNode(ctx * Context, id NodeID, data []byte, nodes NodeMap)(Node, error) { + if len(data) > 0 { + return nil, fmt.Errorf("Attempted to load a graph_node with data %+v, should have been 0 length", string(data)) + } + node := NewGraphNode(id) + return &node, nil +} + +func (node * GraphNode) ID() NodeID { + return node.id +} + +func (node * GraphNode) Type() NodeType { + return NodeType("graph_node") +} + +func (node * GraphNode) Signal(ctx *Context, signal GraphSignal, nodes NodeMap) error { + ctx.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String()) + node.listeners_lock.Lock() + defer node.listeners_lock.Unlock() + closed := []NodeID{} + + for id, listener := range node.listeners { + ctx.Log.Logf("signal", "UPDATE_LISTENER %s: %p", node.ID(), listener) + select { + case listener <- signal: + default: + ctx.Log.Logf("signal", "CLOSED_LISTENER %s: %p", node.ID(), listener) + go func(node Node, listener chan GraphSignal) { + listener <- NewDirectSignal(node, "listener_closed") + close(listener) + }(node, listener) + closed = append(closed, id) + } + } + + for _, id := range(closed) { + delete(node.listeners, id) + } + return nil +} + +func (node * GraphNode) RegisterChannel(id NodeID, listener chan GraphSignal) { + node.listeners_lock.Lock() + _, exists := node.listeners[id] + if exists == false { + node.listeners[id] = listener + } + node.listeners_lock.Unlock() +} + +func (node * GraphNode) UnregisterChannel(id NodeID) { + node.listeners_lock.Lock() + _, exists := node.listeners[id] + if exists == false { + panic("Attempting to unregister non-registered listener") + } else { + delete(node.listeners, id) + } + node.listeners_lock.Unlock() +} + +func NewGraphNode(id NodeID) GraphNode { + return GraphNode{ + id: id, + listeners: map[NodeID]chan GraphSignal{}, + } +} + +const NODE_DB_MAGIC = 0x2491df14 +const NODE_DB_HEADER_LEN = 12 +type DBHeader struct { + Magic uint32 + TypeHash uint64 +} + +func (header DBHeader) Serialize() []byte { + if header.Magic != NODE_DB_MAGIC { + panic(fmt.Sprintf("Serializing header with invalid magic %0x", header.Magic)) + } + + ret := make([]byte, NODE_DB_HEADER_LEN) + binary.BigEndian.PutUint32(ret[0:4], header.Magic) + binary.BigEndian.PutUint64(ret[4:12], header.TypeHash) + return ret +} + +func NewDBHeader(node_type NodeType) DBHeader { + return DBHeader{ + Magic: NODE_DB_MAGIC, + TypeHash: node_type.Hash(), + } +} + +func getNodeBytes(ctx * Context, node Node) ([]byte, error) { + if node == nil { + return nil, fmt.Errorf("DB_SERIALIZE_ERROR: cannot serialize nil node") + } + ser, err := node.Serialize() + if err != nil { + return nil, fmt.Errorf("DB_SERIALIZE_ERROR: %e", err) + } + + header := NewDBHeader(node.Type()) + + db_data := append(header.Serialize(), ser...) + + return db_data, nil +} + +// Write a node to the database +func WriteNode(ctx * Context, node Node) error { + ctx.Log.Logf("db", "DB_WRITE: %+v", node) + + node_bytes, err := getNodeBytes(ctx, node) + if err != nil { + return err + } + + id_ser := node.ID().Serialize() + + err = ctx.DB.Update(func(txn *badger.Txn) error { + err := txn.Set(id_ser, node_bytes) + return err + }) + + return err +} + +// Write multiple nodes to the database in a single transaction +func WriteNodes(ctx * Context, nodes NodeMap) error { + ctx.Log.Logf("db", "DB_WRITES: %d", len(nodes)) + if nodes == nil { + return fmt.Errorf("Cannot write nil map to DB") + } + + serialized_bytes := make([][]byte, len(nodes)) + serialized_ids := make([][]byte, len(nodes)) + i := 0 + for _, node := range(nodes) { + node_bytes, err := getNodeBytes(ctx, node) + if err != nil { + return err + } + + id_ser := node.ID().Serialize() + + serialized_bytes[i] = node_bytes + serialized_ids[i] = id_ser + + i++ + } + + err := ctx.DB.Update(func(txn *badger.Txn) error { + for i, id := range(serialized_ids) { + err := txn.Set(id, serialized_bytes[i]) + if err != nil { + return err + } + } + return nil + }) + + return err +} + +// Get the bytes associates with `id` in the database, or error +func readNodeBytes(ctx * Context, id NodeID) (uint64, []byte, error) { + var bytes []byte + err := ctx.DB.View(func(txn *badger.Txn) error { + item, err := txn.Get(id.Serialize()) + if err != nil { + return err + } + + return item.Value(func(val []byte) error { + bytes = append([]byte{}, val...) + return nil + }) + }) + + if err != nil { + ctx.Log.Logf("db", "DB_READ_ERR: %s - %e", id, err) + return 0, nil, err + } + + if len(bytes) < NODE_DB_HEADER_LEN { + return 0, nil, fmt.Errorf("header for %s is %d/%d bytes", id, len(bytes), NODE_DB_HEADER_LEN) + } + + header := DBHeader{} + header.Magic = binary.BigEndian.Uint32(bytes[0:4]) + header.TypeHash = binary.BigEndian.Uint64(bytes[4:12]) + + if header.Magic != NODE_DB_MAGIC { + return 0, nil, fmt.Errorf("header for %s, invalid magic 0x%x", id, header.Magic) + } + + node_bytes := make([]byte, len(bytes) - NODE_DB_HEADER_LEN) + copy(node_bytes, bytes[NODE_DB_HEADER_LEN:]) + + ctx.Log.Logf("db", "DB_READ: %s - %s", id, string(bytes)) + + return header.TypeHash, node_bytes, nil +} + +func LoadNode(ctx * Context, id NodeID) (Node, error) { + nodes := NodeMap{} + return LoadNodeRecurse(ctx, id, nodes) +} + +func LoadNodeRecurse(ctx * Context, id NodeID, nodes NodeMap) (Node, error) { + node, exists := nodes[id] + if exists == false { + type_hash, bytes, err := readNodeBytes(ctx, id) + if err != nil { + return nil, err + } + + node_type, exists := ctx.Types[type_hash] + if exists == false { + return nil, fmt.Errorf("0x%x is not a known node type: %+s", type_hash, bytes) + } + + if node_type.Load == nil { + return nil, fmt.Errorf("0x%x is an invalid node type, nil Load", type_hash) + } + + node, err = node_type.Load(ctx, id, bytes, nodes) + if err != nil { + return nil, err + } + + ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id) + } + return node, nil +} + +func checkForDuplicate(nodes []Node) error { + found := map[NodeID]bool{} + for _, node := range(nodes) { + if node == nil { + return fmt.Errorf("Cannot get state of nil node") + } + + _, exists := found[node.ID()] + if exists == true { + return fmt.Errorf("Attempted to get state of %s twice", node.ID()) + } + found[node.ID()] = true + } + return nil +} + +func NodeList[K Node](list []K) []Node { + nodes := make([]Node, len(list)) + for i, node := range(list) { + nodes[i] = node + } + return nodes +} + +type NodeMap map[NodeID]Node +type NodesFn func(nodes NodeMap) error +func UseStates(ctx * Context, init_nodes []Node, nodes_fn NodesFn) error { + nodes := NodeMap{} + return UseMoreStates(ctx, init_nodes, nodes, nodes_fn) +} +func UseMoreStates(ctx * Context, new_nodes []Node, nodes NodeMap, nodes_fn NodesFn) error { + err := checkForDuplicate(new_nodes) + if err != nil { + return err + } + + locked_nodes := []Node{} + for _, node := range(new_nodes) { + _, locked := nodes[node.ID()] + if locked == false { + node.RLock() + nodes[node.ID()] = node + locked_nodes = append(locked_nodes, node) + } + } + + err = nodes_fn(nodes) + + for _, node := range(locked_nodes) { + delete(nodes, node.ID()) + node.RUnlock() + } + + return err +} + +func UpdateStates(ctx * Context, nodes []Node, nodes_fn NodesFn) error { + locked_nodes := NodeMap{} + err := UpdateMoreStates(ctx, nodes, locked_nodes, nodes_fn) + if err == nil { + err = WriteNodes(ctx, locked_nodes) + } + + for _, node := range(locked_nodes) { + node.Unlock() + } + return err +} +func UpdateMoreStates(ctx * Context, nodes []Node, locked_nodes NodeMap, nodes_fn NodesFn) error { + for _, node := range(nodes) { + _, locked := locked_nodes[node.ID()] + if locked == false { + node.Lock() + locked_nodes[node.ID()] = node + } + } + + return nodes_fn(locked_nodes) +} + +func UpdateChannel(node Node, buffer int, id NodeID) chan GraphSignal { + if node == nil { + panic("Cannot get an update channel to nil") + } + new_listener := make(chan GraphSignal, buffer) + node.RegisterChannel(id, new_listener) + return new_listener +} diff --git a/signal.go b/signal.go new file mode 100644 index 0000000..73a0b45 --- /dev/null +++ b/signal.go @@ -0,0 +1,82 @@ +package graphvent + +import ( + "encoding/json" +) + +type SignalDirection int +const ( + Up SignalDirection = iota + Down + Direct +) + +// GraphSignals are passed around the event tree/resource DAG and cast by Type() +type GraphSignal interface { + // How to propogate the signal + Direction() SignalDirection + Source() NodeID + Type() string + String() string +} + +// BaseSignal is the most basic type of signal, it has no additional data +type BaseSignal struct { + FDirection SignalDirection `json:"direction"` + FSource NodeID `json:"source"` + FType string `json:"type"` +} + +func (state BaseSignal) String() string { + ser, err := json.Marshal(state) + if err != nil { + return "STATE_SER_ERR" + } + return string(ser) +} + +func (signal BaseSignal) Direction() SignalDirection { + return signal.FDirection +} + +func (signal BaseSignal) Source() NodeID { + return signal.FSource +} + +func (signal BaseSignal) Type() string { + return signal.FType +} + +func NewBaseSignal(source Node, _type string, direction SignalDirection) BaseSignal { + var source_id NodeID = "nil" + if source != nil { + source_id = source.ID() + } + + signal := BaseSignal{ + FDirection: direction, + FSource: source_id, + FType: _type, + } + return signal +} + +func NewDownSignal(source Node, _type string) BaseSignal { + return NewBaseSignal(source, _type, Down) +} + +func NewSignal(source Node, _type string) BaseSignal { + return NewBaseSignal(source, _type, Up) +} + +func NewDirectSignal(source Node, _type string) BaseSignal { + return NewBaseSignal(source, _type, Direct) +} + +func AbortSignal(source Node) BaseSignal { + return NewBaseSignal(source, "abort", Down) +} + +func CancelSignal(source Node) BaseSignal { + return NewBaseSignal(source, "cancel", Down) +}