diff --git a/gql.go b/gql.go index 658abcc..f2beb47 100644 --- a/gql.go +++ b/gql.go @@ -809,7 +809,10 @@ var gql_actions ThreadActions = ThreadActions{ "restore": func(ctx * Context, thread Thread) (string, error) { // Start all the threads that should be "started" ctx.Log.Logf("gql", "GQL_THREAD_RESTORE: %s", thread.ID()) - ThreadRestore(ctx, thread) + err := ThreadRestore(ctx, thread) + if err != nil { + return "", err + } return "start_server", nil }, "start": func(ctx * Context, thread Thread) (string, error) { @@ -819,6 +822,21 @@ var gql_actions ThreadActions = ThreadActions{ return "", err } + // Start all the threads that have "Start" as true + context := NewWriteContext(ctx) + err = UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error { + return UpdateStates(context, thread, LockList(thread.Children(), []string{"start"}), func(context *StateContext) error { + for _, child := range(thread.Children()) { + info := thread.ChildInfo(child.ID()).(ParentInfo).Parent() + if info.Start == true { + ctx.Log.Logf("thread", "THREAD_START_CHILD: %s -> %s", thread.ID(), child.ID()) + ChildGo(ctx, thread, child, info.StartAction) + } + } + return nil + }) + }) + return "start_server", nil }, "start_server": func(ctx * Context, thread Thread) (string, error) { @@ -956,6 +974,6 @@ var gql_handlers ThreadHandlers = ThreadHandlers{ return "wait", nil }, "abort": ThreadAbort, - "cancel": ThreadCancel, + "stop": ThreadStop, } diff --git a/gql_test.go b/gql_test.go index 2bb5c67..a48398e 100644 --- a/gql_test.go +++ b/gql_test.go @@ -3,7 +3,6 @@ package graphvent import ( "testing" "time" - "errors" "net" "net/http" "io" @@ -19,13 +18,17 @@ import ( ) func TestGQLDBLoad(t * testing.T) { - ctx := logTestContext(t, []string{"policy", "mutex"}) + ctx := logTestContext(t, []string{"test", "signal", "thread"}) l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") l1 := &l1_r + ctx.Log.Logf("test", "L1_ID: %s", l1.ID().String()) t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) t1 := &t1_r - update_channel := UpdateChannel(t1, 10, NodeID{}) + ctx.Log.Logf("test", "T1_ID: %s", t1.ID().String()) + listen_id := RandID() + ctx.Log.Logf("test", "LISTENER_ID: %s", listen_id.String()) + update_channel := UpdateChannel(t1, 10, listen_id) u1_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) @@ -34,28 +37,52 @@ func TestGQLDBLoad(t * testing.T) { u1_r := NewUser("Test User", time.Now(), &u1_key.PublicKey, u1_shared, []string{"gql"}) u1 := &u1_r - - p1_r := NewSimplePolicy(RandID(), NewNodeActions(nil, []string{"enumerate"})) - p1 := &p1_r + ctx.Log.Logf("test", "U1_ID: %s", u1.ID().String()) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) gql_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) gql := &gql_r + ctx.Log.Logf("test", "GQL_ID: %s", gql.ID().String()) + + // Policy to allow gql to perform all action on all resources + p1_r := NewPerNodePolicy(RandID(), map[NodeID]NodeActions{ + gql.ID(): NewNodeActions(nil, []string{"*"}), + }) + p1 := &p1_r + p2_r := NewSimplePolicy(RandID(), NewNodeActions(NodeActions{ + "signal": []string{"read"}, + }, nil)) + p2 := &p2_r - info := NewParentThreadInfo(true, "start", "restore") context := NewWriteContext(ctx) + err = UpdateStates(context, gql, LockMap{ + p1.ID(): LockInfo{p1, nil}, + p2.ID(): LockInfo{p2, nil}, + }, func(context *StateContext) error { + return nil + }) + fatalErr(t, err) + + ctx.Log.Logf("test", "P1_ID: %s", p1.ID().String()) + ctx.Log.Logf("test", "P2_ID: %s", p2.ID().String()) + err = AttachPolicies(ctx, gql, p1, p2) + fatalErr(t, err) + err = AttachPolicies(ctx, l1, p1, p2) + fatalErr(t, err) + err = AttachPolicies(ctx, t1, p1, p2) + fatalErr(t, err) + err = AttachPolicies(ctx, u1, p1, p2) + fatalErr(t, err) + + info := NewParentThreadInfo(true, "start", "restore") + context = NewWriteContext(ctx) err = UpdateStates(context, gql, NewLockMap( - NewLockInfo(gql, []string{"policies", "users"}), + NewLockInfo(gql, []string{"users"}), ), func(context *StateContext) error { - err := gql.AddPolicy(p1) - if err != nil { - return err - } - gql.Users[KeyID(&u1_key.PublicKey)] = u1 - err = LinkThreads(context, gql, gql, t1, &info) + err := LinkThreads(context, gql, gql, t1, &info) if err != nil { return err } @@ -69,20 +96,14 @@ func TestGQLDBLoad(t * testing.T) { if err != nil { return nil } - return gql.Signal(context, CancelSignal) + return gql.Signal(context, StopSignal) }) fatalErr(t, err) err = ThreadLoop(ctx, gql, "start") - if errors.Is(err, ThreadAbortedError) { - ctx.Log.Logf("test", "Main thread aborted by signal: %s", err) - } else if err != nil{ - fatalErr(t, err) - } else { - ctx.Log.Logf("test", "Main thread cancelled by signal") - } + fatalErr(t, err) - (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_aborted", 100*time.Millisecond, "Didn't receive thread_abort from t1 on t1") + (*GraphTester)(t).WaitForValue(ctx, update_channel, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel") context = NewReadContext(ctx) err = UseStates(context, gql, LockList([]Node{gql, u1}, nil), func(context *StateContext) error { @@ -105,28 +126,24 @@ func TestGQLDBLoad(t * testing.T) { u_loaded := gql_loaded.(*GQLThread).Users[u1.ID()] child := gql_loaded.(Thread).Children()[0].(*SimpleThread) t1_loaded = child - update_channel_2 = UpdateChannel(t1_loaded, 10, NodeID{}) + update_channel_2 = UpdateChannel(t1_loaded, 10, RandID()) err = UseStates(context, gql, NewLockInfo(u_loaded, nil), func(context *StateContext) error { ser, err := u_loaded.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser) return err }) - gql_loaded.Signal(context, AbortSignal) + gql_loaded.Signal(context, StopSignal) return err }) - err = ThreadLoop(ctx, gql_loaded.(Thread), "restore") - if errors.Is(err, ThreadAbortedError) { - ctx.Log.Logf("test", "Main thread aborted by signal: %s", err) - } else { - fatalErr(t, err) - } - (*GraphTester)(t).WaitForValue(ctx, update_channel_2, "thread_aborted", 100*time.Millisecond, "Didn't received thread_aborted on t1_loaded from t1_loaded") + err = ThreadLoop(ctx, gql_loaded.(Thread), "start") + fatalErr(t, err) + (*GraphTester)(t).WaitForValue(ctx, update_channel_2, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2") } func TestGQLAuth(t * testing.T) { - ctx := logTestContext(t, []string{"policy", "mutex"}) + ctx := logTestContext(t, []string{"policy"}) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) @@ -162,7 +179,7 @@ func TestGQLAuth(t * testing.T) { } context := NewReadContext(ctx) err := UseStates(context, gql_t, NewLockInfo(gql_t, []string{"signal}"}), func(context *StateContext) error { - return thread.Signal(context, CancelSignal) + return thread.Signal(context, StopSignal) }) fatalErr(t, err) }(done, gql_t) diff --git a/node.go b/node.go index fc95251..35a8977 100644 --- a/node.go +++ b/node.go @@ -108,7 +108,7 @@ func (node *GraphNode) Allowed(action string, resource string, principal Node) e return nil } for _, policy := range(node.policies) { - if policy.Allows(action, resource, principal) == true { + if policy.Allows(resource, action, principal) == true { return nil } } @@ -200,11 +200,11 @@ func (node * GraphNode) Signal(context *StateContext, signal GraphSignal) error closed := []NodeID{} for id, listener := range node.listeners { - context.Graph.Log.Logf("signal", "UPDATE_LISTENER %s: %p", node.ID(), listener) + context.Graph.Log.Logf("signal", "UPDATE_LISTENER %s: %s", node.ID(), id) select { case listener <- signal: default: - context.Graph.Log.Logf("signal", "CLOSED_LISTENER %s: %p", node.ID(), listener) + context.Graph.Log.Logf("signal", "CLOSED_LISTENER %s: %s", node.ID(), id) go func(node Node, listener chan GraphSignal) { listener <- NewDirectSignal("listener_closed") close(listener) @@ -239,6 +239,19 @@ func (node * GraphNode) UnregisterChannel(id NodeID) { node.listeners_lock.Unlock() } +func AttachPolicies(ctx *Context, node Node, policies ...Policy) error { + context := NewWriteContext(ctx) + return UpdateStates(context, node, NewLockInfo(node, []string{"policies"}), func(context *StateContext) error { + for _, policy := range(policies) { + err := node.AddPolicy(policy) + if err != nil { + return err + } + } + return nil + }) +} + func NewGraphNode(id NodeID) GraphNode { return GraphNode{ id: id, @@ -577,16 +590,16 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St if already_granted == false { err := node.Allowed("read", resource, princ) if err != nil { - context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.write", id.String(), resource) + context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.read", princ.ID().String(), id.String(), resource) for _, n := range(new_locks) { context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String()) n.RUnlock() } return err } - context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.write", id.String(), resource) + context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.read", princ.ID().String(), id.String(), resource) } else { - context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.write", id.String(), resource) + context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.%s.read", princ.ID().String(), id.String(), resource) } } new_permissions[id] = node_permissions @@ -681,16 +694,16 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn if already_granted == false { err := node.Allowed("write", resource, princ) if err != nil { - context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.write", id.String(), resource) + context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.write", princ.ID().String(), id.String(), resource) for _, n := range(new_locks) { context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String()) n.Unlock() } return err } - context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.write", id.String(), resource) + context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.write", princ.ID().String(), id.String(), resource) } else { - context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.write", id.String(), resource) + context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.%s.write", princ.ID().String(), id.String(), resource) } } new_permissions[id] = node_permissions diff --git a/signal.go b/signal.go index 71cce2c..81446c1 100644 --- a/signal.go +++ b/signal.go @@ -62,7 +62,7 @@ func NewDirectSignal(_type string) BaseSignal { } var AbortSignal = NewBaseSignal("abort", Down) -var CancelSignal = NewBaseSignal("cancel", Down) +var StopSignal = NewBaseSignal("stop", Down) type IDSignal struct { BaseSignal diff --git a/thread.go b/thread.go index 3328f7f..a0688d1 100644 --- a/thread.go +++ b/thread.go @@ -595,21 +595,20 @@ func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) erro // Helper function to restore threads that should be running from a parents restore action // Starts a write context, so cannot be called from either a write or read context -func ThreadRestore(ctx * Context, thread Thread) { +func ThreadRestore(ctx * Context, thread Thread) error { context := NewWriteContext(ctx) - UpdateStates(context, thread, NewLockMap( - NewLockInfo(thread, []string{"children"}), - LockList(thread.Children(), []string{"start"}), - ), func(context *StateContext)(error) { - for _, child := range(thread.Children()) { - should_run := (thread.ChildInfo(child.ID())).(ParentInfo).Parent() - ctx.Log.Logf("thread", "THREAD_RESTORE: %s -> %s: %+v", thread.ID(), child.ID(), should_run) - if should_run.Start == true && child.State() != "finished" { - ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID(), child.ID()) - ChildGo(ctx, thread, child, should_run.RestoreAction) + return UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error { + return UpdateStates(context, thread, LockList(thread.Children(), []string{"start"}), func(context *StateContext) error { + for _, child := range(thread.Children()) { + should_run := (thread.ChildInfo(child.ID())).(ParentInfo).Parent() + ctx.Log.Logf("thread", "THREAD_RESTORE: %s -> %s: %+v", thread.ID(), child.ID(), should_run) + if should_run.Start == true && child.State() != "finished" { + ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID(), child.ID()) + ChildGo(ctx, thread, child, should_run.RestoreAction) + } } - } - return nil + return nil + }) }) } @@ -697,14 +696,14 @@ func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, erro if err != nil { return "", err } - return "finish", ThreadAbortedError + return "", ThreadAbortedError } -// Default thread action for "cancel", sends a signal and returns no error -func ThreadCancel(ctx * Context, thread Thread, signal GraphSignal) (string, error) { +// Default thread action for "stop", sends a signal and returns no error +func ThreadStop(ctx * Context, thread Thread, signal GraphSignal) (string, error) { context := NewReadContext(ctx) err := UseStates(context, thread, NewLockInfo(thread, []string{"signal"}), func(context *StateContext) error { - return thread.Signal(context, NewSignal("cancelled")) + return thread.Signal(context, NewSignal("stopped")) }) return "finish", err } @@ -740,5 +739,5 @@ var BaseThreadActions = ThreadActions{ // Default thread signal handlers var BaseThreadHandlers = ThreadHandlers{ "abort": ThreadAbort, - "cancel": ThreadCancel, + "stop": ThreadStop, }