diff --git a/context.go b/context.go index ed9f712..2e24869 100644 --- a/context.go +++ b/context.go @@ -3,6 +3,7 @@ package graphvent import ( badger "github.com/dgraph-io/badger/v3" "fmt" + "errors" "runtime" ) @@ -86,21 +87,38 @@ func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFun return nil } -// Route a Signal to dest. Currently only local context routing is supported -func (ctx *Context) Send(source NodeID, dest NodeID, signal Signal) error { - target, exists := ctx.Nodes[dest] +var NodeNotFoundError = errors.New("Node not found in DB") + +func (ctx *Context) GetNode(id NodeID) (*Node, error) { + target, exists := ctx.Nodes[id] if exists == false { - return fmt.Errorf("%s does not exist, cannot signal it", dest) + var err error + target, err = LoadNode(ctx, id) + if err != nil { + return nil, err + } } - select { - case target.MsgChan <- Msg{source, signal}: - default: - buf := make([]byte, 4096) - n := runtime.Stack(buf, false) - stack_str := string(buf[:n]) - return fmt.Errorf("SIGNAL_OVERFLOW: %s - %s", dest, stack_str) + return target, nil +} + +// Route a Signal to dest. Currently only local context routing is supported +func (ctx *Context) Send(source NodeID, dest NodeID, signal Signal) error { + target, err := ctx.GetNode(dest) + if err == nil { + select { + case target.MsgChan <- Msg{source, signal}: + default: + buf := make([]byte, 4096) + n := runtime.Stack(buf, false) + stack_str := string(buf[:n]) + return fmt.Errorf("SIGNAL_OVERFLOW: %s - %s", dest, stack_str) + } + return nil + } else if errors.Is(err, NodeNotFoundError) { + // TODO: Handle finding nodes in other contexts + return err } - return nil + return err } // Create a new Context with the base library content added diff --git a/gql_test.go b/gql_test.go index a166acf..e775998 100644 --- a/gql_test.go +++ b/gql_test.go @@ -34,7 +34,7 @@ func TestGQLDB(t * testing.T) { err = ctx.Send(gql.ID, gql.ID, StopSignal) fatalErr(t, err) - (*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on listener") + (*GraphTester)(t).WaitForStatus(ctx, listener_ext, "stopped", 100*time.Millisecond, "Didn't receive stopped on listener") ser1, err := gql.Serialize() ser2, err := u1.Serialize() @@ -49,7 +49,7 @@ func TestGQLDB(t * testing.T) { fatalErr(t, err) err = ctx.Send(gql_loaded.ID, gql_loaded.ID, StopSignal) fatalErr(t, err) - (*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2") + (*GraphTester)(t).WaitForStatus(ctx, listener_ext, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2") } diff --git a/graph_test.go b/graph_test.go index d7a8641..7804a0d 100644 --- a/graph_test.go +++ b/graph_test.go @@ -13,11 +13,11 @@ import ( type GraphTester testing.T const listner_timeout = 50 * time.Millisecond -func (t * GraphTester) WaitForStatus(ctx * Context, listener chan Signal, status string, timeout time.Duration, str string) Signal { +func (t * GraphTester) WaitForStatus(ctx * Context, listener *ListenerExt, status string, timeout time.Duration, str string) Signal { timeout_channel := time.After(timeout) for true { select { - case signal := <- listener: + case signal := <- listener.Chan: if signal == nil { ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener) t.Fatal(str) @@ -25,6 +25,8 @@ func (t * GraphTester) WaitForStatus(ctx * Context, listener chan Signal, status if signal.Type() == "status" { sig, ok := signal.(StatusSignal) if ok == true { + + ctx.Log.Logf("test", "Status received: %s", sig.Status) if sig.Status == status { return signal @@ -42,10 +44,10 @@ func (t * GraphTester) WaitForStatus(ctx * Context, listener chan Signal, status return nil } -func (t * GraphTester) CheckForNone(listener chan Signal, str string) { +func (t * GraphTester) CheckForNone(listener *ListenerExt, str string) { timeout := time.After(listner_timeout) select { - case sig := <- listener: + case sig := <- listener.Chan: pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) t.Fatal(fmt.Sprintf("%s : %+v", str, sig)) case <-timeout: diff --git a/lockable_test.go b/lockable_test.go index ed8029c..9177bb2 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -4,10 +4,35 @@ import ( "testing" ) -func TestLockableLink(t *testing.T) { +const TestLockableType = NodeType("TEST_LOCKABLE") +func lockableTestContext(t *testing.T) *Context { ctx := logTestContext(t, []string{"lockable", "signal"}) - LockableType := NodeType("TEST_LOCKABLE") - err := ctx.RegisterNodeType(LockableType, []ExtType{LockableExtType}) + + err := ctx.RegisterNodeType(TestLockableType, []ExtType{ACLExtType, LockableExtType, ListenerExtType}) fatalErr(t, err) + + return ctx +} + + +var link_policy = NewAllNodesPolicy([]string{"link", "status"}) + +func Test(t *testing.T) { + ctx := lockableTestContext(t) + + l1_listener := NewListenerExt(10) + l1 := NewNode(ctx, RandID(), TestLockableType, nil, + l1_listener, + NewACLExt(&link_policy), + NewLockableExt(nil, nil, nil, nil), + ) + l2_listener := NewListenerExt(10) + l2 := NewNode(ctx, RandID(), TestLockableType, nil, + l2_listener, + NewACLExt(&link_policy), + NewLockableExt(nil, nil, nil, nil), + ) + + ctx.Send(l1.ID, l2.ID, NewLinkSignal("start", l1.ID)) } diff --git a/node.go b/node.go index 2289283..d4d0621 100644 --- a/node.go +++ b/node.go @@ -2,6 +2,7 @@ package graphvent import ( "time" + "errors" "reflect" "github.com/google/uuid" badger "github.com/dgraph-io/badger/v3" @@ -458,14 +459,7 @@ func WriteNode(ctx *Context, node *Node) error { } func LoadNode(ctx * Context, id NodeID) (*Node, error) { - ctx.Log.Logf("db", "LOOKING_FOR_NODE: %s", id) - node, exists := ctx.Nodes[id] - if exists == true { - ctx.Log.Logf("db", "NODE_ALREADY_LOADED: %s", id) - return node,nil - } ctx.Log.Logf("db", "LOADING_NODE: %s", id) - var bytes []byte err := ctx.DB.View(func(txn *badger.Txn) error { item, err := txn.Get(id.Serialize()) @@ -478,7 +472,9 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return nil }) }) - if err != nil { + if errors.Is(err, badger.ErrKeyNotFound) { + return nil, NodeNotFoundError + }else if err != nil { return nil, err } @@ -494,7 +490,7 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { } next_signal, timeout_chan := SoonestSignal(node_db.QueuedSignals) - node = &Node{ + node := &Node{ ID: id, Type: node_type.Type, Extensions: map[ExtType]Extension{},