From 8fb0cbc982a3d0a460182e763e0e162e66bb1d37 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sun, 23 Jul 2023 17:57:47 -0600 Subject: [PATCH] Reworked use/update to require a read/write context be initialized before starting, still need to figure out if brittle locking is the solution to potential deadlock, and implement if so --- gql.go | 42 +++--- gql_mutation.go | 14 +- gql_resolvers.go | 40 +++--- gql_test.go | 69 +++------- graph_test.go | 3 +- lockable.go | 73 +++++----- node.go | 336 ++++++++++++++++++++++++++++++----------------- thread.go | 112 +++++++++------- 8 files changed, 391 insertions(+), 298 deletions(-) diff --git a/gql.go b/gql.go index 38ae8d4..658abcc 100644 --- a/gql.go +++ b/gql.go @@ -204,7 +204,8 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt ctx.Log.Logf("gql", "AUTHORIZING NEW USER %s - %s", key_id, shared) new_user := NewUser(fmt.Sprintf("GQL_USER %s", key_id.String()), time.Now(), remote_id, shared, []string{"gql"}) - err := UpdateStates(ctx, server, NewLockMap(LockMap{ + context := NewWriteContext(ctx) + err := UpdateStates(context, server, NewLockMap(LockMap{ server.ID(): LockInfo{ Node: server, Resources: []string{"users"}, @@ -213,7 +214,7 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt Node: &new_user, Resources: []string{""}, }, - }), func(context *WriteContext) error { + }), func(context *StateContext) error { server.Users[key_id] = &new_user return nil }) @@ -873,9 +874,10 @@ var gql_actions ThreadActions = ThreadActions{ }(server) - err = UpdateStates(ctx, server, NewLockMap( + context := NewWriteContext(ctx) + err = UpdateStates(context, server, NewLockMap( NewLockInfo(server, []string{"http_server"}), - ), func(context *WriteContext) error { + ), func(context *StateContext) error { server.tcp_listener = listener server.http_server = http_server return nil @@ -885,9 +887,10 @@ var gql_actions ThreadActions = ThreadActions{ return "", err } - err = UseStates(ctx, server, NewLockMap( + context = NewReadContext(ctx) + err = UseStates(context, server, NewLockMap( NewLockInfo(server, []string{"signal"}), - ), func(context *ReadContext) error { + ), func(context *StateContext) error { return server.Signal(context, NewSignal("server_started")) }) @@ -897,14 +900,21 @@ var gql_actions ThreadActions = ThreadActions{ return "wait", nil }, + "finish": func(ctx *Context, thread Thread) (string, error) { + server := thread.(*GQLThread) + server.http_server.Shutdown(context.TODO()) + server.http_done.Wait() + return "", ThreadFinish(ctx, thread) + }, } var gql_handlers ThreadHandlers = ThreadHandlers{ "child_linked": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) { ctx.Log.Logf("gql", "GQL_THREAD_CHILD_ADDED: %+v", signal) - err := UpdateStates(ctx, thread, NewLockMap( + context := NewWriteContext(ctx) + err := UpdateStates(context, thread, NewLockMap( NewLockInfo(thread, []string{"children"}), - ), func(context *WriteContext) error { + ), func(context *StateContext) error { sig, ok := signal.(IDSignal) if ok == false { ctx.Log.Logf("gql", "GQL_THREAD_NODE_LINKED_BAD_CAST") @@ -945,19 +955,7 @@ var gql_handlers ThreadHandlers = ThreadHandlers{ return "wait", nil }, - "abort": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) { - ctx.Log.Logf("gql", "GQL_ABORT") - server := thread.(*GQLThread) - server.http_server.Shutdown(context.TODO()) - server.http_done.Wait() - return ThreadAbort(ctx, thread, signal) - }, - "cancel": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) { - ctx.Log.Logf("gql", "GQL_CANCEL") - server := thread.(*GQLThread) - server.http_server.Shutdown(context.TODO()) - server.http_done.Wait() - return ThreadCancel(ctx, thread, signal) - }, + "abort": ThreadAbort, + "cancel": ThreadCancel, } diff --git a/gql_mutation.go b/gql_mutation.go index d5d2354..c39cff6 100644 --- a/gql_mutation.go +++ b/gql_mutation.go @@ -29,14 +29,15 @@ var GQLMutationAbort = NewField(func()*graphql.Field { } var node Node = nil - err = UseStates(ctx.Context, ctx.User, NewLockMap( + context := NewReadContext(ctx.Context) + err = UseStates(context, ctx.User, NewLockMap( NewLockInfo(ctx.Server, []string{"children"}), - ), func(context *ReadContext) (error){ + ), func(context *StateContext) (error){ node = FindChild(context, ctx.User, ctx.Server, id) if node == nil { return fmt.Errorf("Failed to find ID: %s as child of server thread", id) } - return UseMoreStates(context, ctx.User, NewLockMap(NewLockInfo(node, []string{"signal"})), func(context *ReadContext) error { + return UseStates(context, ctx.User, NewLockInfo(node, []string{"signal"}), func(context *StateContext) error { return node.Signal(context, AbortSignal) }) }) @@ -88,9 +89,10 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{ } var signal GraphSignal - err = UseStates(ctx.Context, ctx.User, NewLockMap( + context := NewWriteContext(ctx.Context) + err = UseStates(context, ctx.User, NewLockMap( NewLockInfo(ctx.Server, []string{"children"}), - ), func(context *ReadContext) error { + ), func(context *StateContext) error { node := FindChild(context, ctx.User, ctx.Server, parent_id) if node == nil { return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id) @@ -101,7 +103,7 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{ return err } - return UseMoreStates(context, ctx.User, NewLockMap(NewLockInfo(node, []string{"start_child", "signal"})), func(context *ReadContext) error { + return UseStates(context, ctx.User, NewLockMap(NewLockInfo(node, []string{"start_child", "signal"})), func(context *StateContext) error { signal = NewStartChildSignal(child_id, action) return node.Signal(context, signal) }) diff --git a/gql_resolvers.go b/gql_resolvers.go index dcc0f62..4421ae6 100644 --- a/gql_resolvers.go +++ b/gql_resolvers.go @@ -42,24 +42,13 @@ func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) { return id, nil } +// TODO: think about what permissions should be needed to read ID, and if there's ever a case where they haven't already been granted func GQLNodeID(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - node, ok := p.Source.(Node) if ok == false || node == nil { return nil, fmt.Errorf("Failed to cast source to Node") } - err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"id"})), func(context *ReadContext) error { - return nil - }) - if err != nil { - return nil, err - } - return node.ID(), nil } @@ -76,7 +65,8 @@ func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) { listen := "" - err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"listen"})), func(context *ReadContext) error { + context := NewReadContext(ctx.Context) + err = UseStates(context, ctx.User, NewLockInfo(node, []string{"listen"}), func(context *StateContext) error { listen = node.Listen return nil }) @@ -100,7 +90,8 @@ func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) { } var parent Thread = nil - err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"parent"})), func(context *ReadContext) error { + context := NewReadContext(ctx.Context) + err = UseStates(context, ctx.User, NewLockInfo(node, []string{"parent"}), func(context *StateContext) error { parent = node.Parent() return nil }) @@ -124,7 +115,8 @@ func GQLThreadState(p graphql.ResolveParams) (interface{}, error) { } var state string - err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"state"})), func(context *ReadContext) error { + context := NewReadContext(ctx.Context) + err = UseStates(context, ctx.User, NewLockInfo(node, []string{"state"}), func(context *StateContext) error { state = node.State() return nil }) @@ -148,7 +140,8 @@ func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) { } var children []Thread = nil - err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"children"})), func(context *ReadContext) error { + context := NewReadContext(ctx.Context) + err = UseStates(context, ctx.User, NewLockInfo(node, []string{"children"}), func(context *StateContext) error { children = node.Children() return nil }) @@ -172,7 +165,8 @@ func GQLLockableName(p graphql.ResolveParams) (interface{}, error) { } name := "" - err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"name"})), func(context *ReadContext) error { + context := NewReadContext(ctx.Context) + err = UseStates(context, ctx.User, NewLockInfo(node, []string{"name"}), func(context *StateContext) error { name = node.Name() return nil }) @@ -196,7 +190,8 @@ func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) { } var requirements []Lockable = nil - err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"requirements"})), func(context *ReadContext) error { + context := NewReadContext(ctx.Context) + err = UseStates(context, ctx.User, NewLockInfo(node, []string{"requirements"}), func(context *StateContext) error { requirements = node.Requirements() return nil }) @@ -220,7 +215,8 @@ func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) { } var dependencies []Lockable = nil - err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"dependencies"})), func(context *ReadContext) error { + context := NewReadContext(ctx.Context) + err = UseStates(context, ctx.User, NewLockInfo(node, []string{"dependencies"}), func(context *StateContext) error { dependencies = node.Dependencies() return nil }) @@ -244,7 +240,8 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) { } var owner Node = nil - err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"owner"})), func(context *ReadContext) error { + context := NewReadContext(ctx.Context) + err = UseStates(context, ctx.User, NewLockInfo(node, []string{"owner"}), func(context *StateContext) error { owner = node.Owner() return nil }) @@ -268,7 +265,8 @@ func GQLThreadUsers(p graphql.ResolveParams) (interface{}, error) { } var users []*User - err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"users"})), func(context *ReadContext) error { + context := NewReadContext(ctx.Context) + err = UseStates(context, ctx.User, NewLockInfo(node, []string{"users"}), func(context *StateContext) error { users = make([]*User, len(node.Users)) i := 0 for _, user := range(node.Users) { diff --git a/gql_test.go b/gql_test.go index c6d4d23..2bb5c67 100644 --- a/gql_test.go +++ b/gql_test.go @@ -18,48 +18,8 @@ import ( "encoding/base64" ) -func TestGQLThread(t * testing.T) { - ctx := logTestContext(t, []string{}) - key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - fatalErr(t, err) - - gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) - gql_t := &gql_t_r - - t1_r := NewSimpleThread(RandID(), "Test thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) - t1 := &t1_r - t2_r := NewSimpleThread(RandID(), "Test thread 2", "init", nil, BaseThreadActions, BaseThreadHandlers) - t2 := &t2_r - - err = UpdateStates(ctx, gql_t, NewLockMap( - LockList([]Node{t1, t2}, []string{"parent"}), - NewLockInfo(gql_t, []string{"children"}), - ), func(context *WriteContext) error { - i1 := NewParentThreadInfo(true, "start", "restore") - err := LinkThreads(context, gql_t, gql_t, t1, &i1) - if err != nil { - return err - } - - i2 := NewParentThreadInfo(false, "start", "restore") - return LinkThreads(context, gql_t, gql_t, t2, &i2) - }) - fatalErr(t, err) - - go func(thread Thread){ - time.Sleep(10*time.Millisecond) - err := UseStates(ctx, thread, NewLockInfo(thread, []string{"signal"}), func(context *ReadContext) error { - return thread.Signal(context, CancelSignal) - }) - fatalErr(t, err) - }(gql_t) - - err = ThreadLoop(ctx, gql_t, "start") - fatalErr(t, err) -} - func TestGQLDBLoad(t * testing.T) { - ctx := logTestContext(t, []string{}) + ctx := logTestContext(t, []string{"policy", "mutex"}) l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") l1 := &l1_r @@ -84,9 +44,10 @@ func TestGQLDBLoad(t * testing.T) { gql := &gql_r info := NewParentThreadInfo(true, "start", "restore") - err = UpdateStates(ctx, gql, NewLockMap( + context := NewWriteContext(ctx) + err = UpdateStates(context, gql, NewLockMap( NewLockInfo(gql, []string{"policies", "users"}), - ), func(context *WriteContext) error { + ), func(context *StateContext) error { err := gql.AddPolicy(p1) if err != nil { return err @@ -102,7 +63,8 @@ func TestGQLDBLoad(t * testing.T) { }) fatalErr(t, err) - err = UseStates(ctx, gql, NewLockInfo(gql, []string{"signal"}), func(context *ReadContext) error { + context = NewReadContext(ctx) + err = UseStates(context, gql, NewLockInfo(gql, []string{"signal"}), func(context *StateContext) error { err := gql.Signal(context, NewStatusSignal("child_linked", t1.ID())) if err != nil { return nil @@ -122,7 +84,8 @@ func TestGQLDBLoad(t * testing.T) { (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_aborted", 100*time.Millisecond, "Didn't receive thread_abort from t1 on t1") - err = UseStates(ctx, gql, LockList([]Node{gql, u1}, nil), func(context *ReadContext) error { + context = NewReadContext(ctx) + err = UseStates(context, gql, LockList([]Node{gql, u1}, nil), func(context *StateContext) error { ser1, err := gql.Serialize() ser2, err := u1.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser1) @@ -135,14 +98,15 @@ func TestGQLDBLoad(t * testing.T) { var t1_loaded *SimpleThread = nil var update_channel_2 chan GraphSignal - err = UseStates(ctx, gql, NewLockInfo(gql_loaded, []string{"users", "children"}), func(context *ReadContext) error { + context = NewReadContext(ctx) + err = UseStates(context, gql, NewLockInfo(gql_loaded, []string{"users", "children"}), func(context *StateContext) error { ser, err := gql_loaded.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser) u_loaded := gql_loaded.(*GQLThread).Users[u1.ID()] child := gql_loaded.(Thread).Children()[0].(*SimpleThread) t1_loaded = child update_channel_2 = UpdateChannel(t1_loaded, 10, NodeID{}) - err = UseMoreStates(context, gql, NewLockInfo(u_loaded, nil), func(context *ReadContext) error { + err = UseStates(context, gql, NewLockInfo(u_loaded, nil), func(context *StateContext) error { ser, err := u_loaded.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser) return err @@ -162,7 +126,7 @@ func TestGQLDBLoad(t * testing.T) { } func TestGQLAuth(t * testing.T) { - ctx := logTestContext(t, []string{"test", "gql"}) + ctx := logTestContext(t, []string{"policy", "mutex"}) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) @@ -173,14 +137,16 @@ func TestGQLAuth(t * testing.T) { gql_t := &gql_t_r // p1 not written to DB, TODO: update write to follow links maybe - err = UpdateStates(ctx, gql_t, NewLockInfo(gql_t, []string{"policies"}), func(context *WriteContext) error { + context := NewWriteContext(ctx) + err = UpdateStates(context, gql_t, NewLockInfo(gql_t, []string{"policies"}), func(context *StateContext) error { return gql_t.AddPolicy(p1) }) done := make(chan error, 1) var update_channel chan GraphSignal - err = UseStates(ctx, gql_t, NewLockInfo(gql_t, nil), func(context *ReadContext) error { + context = NewReadContext(ctx) + err = UseStates(context, gql_t, NewLockInfo(gql_t, nil), func(context *StateContext) error { update_channel = UpdateChannel(gql_t, 10, NodeID{}) return nil }) @@ -194,7 +160,8 @@ func TestGQLAuth(t * testing.T) { case <-done: ctx.Log.Logf("test", "DONE") } - err := UseStates(ctx, gql_t, NewLockInfo(gql_t, []string{"signal}"}), func(context *ReadContext) error { + context := NewReadContext(ctx) + err := UseStates(context, gql_t, NewLockInfo(gql_t, []string{"signal}"}), func(context *StateContext) error { return thread.Signal(context, CancelSignal) }) fatalErr(t, err) diff --git a/graph_test.go b/graph_test.go index 54a7521..6f440e0 100644 --- a/graph_test.go +++ b/graph_test.go @@ -5,6 +5,7 @@ import ( "fmt" "time" "runtime/pprof" + "runtime/debug" "os" badger "github.com/dgraph-io/badger/v3" ) @@ -63,7 +64,7 @@ func testContext(t * testing.T) * Context { func fatalErr(t * testing.T, err error) { if err != nil { - pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) + debug.PrintStack() t.Fatal(err) } } diff --git a/lockable.go b/lockable.go index 2df2cc5..b5d73de 100644 --- a/lockable.go +++ b/lockable.go @@ -216,7 +216,7 @@ func (lockable * SimpleLockable) CanUnlock(new_owner Lockable) error { } // Assumed that lockable is already locked for signal -func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal) error { +func (lockable * SimpleLockable) Signal(context *StateContext, signal GraphSignal) error { err := lockable.GraphNode.Signal(context, signal) if err != nil { return err @@ -224,10 +224,10 @@ func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal switch signal.Direction() { case Up: - err = UseMoreStates(context, lockable, NewLockMap( + err = UseStates(context, lockable, NewLockMap( NewLockInfo(lockable, []string{"dependencies", "owner"}), LockList(lockable.requirements, []string{"signal"}), - ), func(context *ReadContext) error { + ), func(context *StateContext) error { owner_sent := false for _, dependency := range(lockable.dependencies) { context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID()) @@ -241,7 +241,7 @@ func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal if lockable.owner != nil && owner_sent == false { if lockable.owner.ID() != lockable.ID() { context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID()) - return UseMoreStates(context, lockable, NewLockMap(NewLockInfo(lockable.owner, []string{"signal"})), func(context *ReadContext) error { + return UseStates(context, lockable, NewLockMap(NewLockInfo(lockable.owner, []string{"signal"})), func(context *StateContext) error { return lockable.owner.Signal(context, signal) }) } @@ -249,10 +249,10 @@ func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal return nil }) case Down: - err = UseMoreStates(context, lockable, NewLockMap( + err = UseStates(context, lockable, NewLockMap( NewLockInfo(lockable, []string{"requirements"}), LockList(lockable.requirements, []string{"signal"}), - ), func(context *ReadContext) error { + ), func(context *StateContext) error { for _, requirement := range(lockable.requirements) { err := requirement.Signal(context, signal) if err != nil { @@ -270,29 +270,36 @@ func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal } // Removes requirement as a requirement from lockable -// Requires lockable and requirement be locked for write -func UnlinkLockables(ctx * Context, lockable Lockable, requirement Lockable) error { - var found Node = nil - for _, req := range(lockable.Requirements()) { - if requirement.ID() == req.ID() { - found = req - break +// Continues the write context with princ, getting requirents for lockable and dependencies for requirement +// Assumes that an active write context exists with princ locked so that princ's state can be used in checks +func UnlinkLockables(context *StateContext, princ Node, lockable Lockable, requirement Lockable) error { + return UpdateStates(context, princ, LockMap{ + lockable.ID(): LockInfo{Node: lockable, Resources: []string{"requirements"}}, + requirement.ID(): LockInfo{Node: requirement, Resources: []string{"dependencies"}}, + }, func(context *StateContext) error { + var found Node = nil + for _, req := range(lockable.Requirements()) { + if requirement.ID() == req.ID() { + found = req + break + } } - } - if found == nil { - return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID(), lockable.ID()) - } + if found == nil { + return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID(), lockable.ID()) + } - requirement.RemoveDependency(lockable) - lockable.RemoveRequirement(requirement) + requirement.RemoveDependency(lockable) + lockable.RemoveRequirement(requirement) - return nil + return nil + }) } // Link requirements as requirements to lockable -// Requires lockable and requirements to be locked for write, nodes passed because requirement check recursively locks -func LinkLockables(context *WriteContext, princ Node, lockable Lockable, requirements []Lockable) error { +// Continues the wrtie context with princ, getting requirements for lockable and dependencies for requirements +// Assumes that an active write context exists with princ locked so that princ's state can be used in checks +func LinkLockables(context *StateContext, princ Node, lockable Lockable, requirements []Lockable) error { if lockable == nil { return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements") } @@ -318,10 +325,10 @@ func LinkLockables(context *WriteContext, princ Node, lockable Lockable, require found[requirement.ID()] = true } - return UpdateMoreStates(context, princ, NewLockMap( + return UpdateStates(context, princ, NewLockMap( NewLockInfo(lockable, []string{"requirements"}), LockList(requirements, []string{"dependencies"}), - ), func(context *WriteContext) error { + ), func(context *StateContext) error { // Check that all the requirements can be added // If the lockable is already locked, need to lock this resource as well before we can add it for _, requirement := range(requirements) { @@ -365,13 +372,13 @@ func LinkLockables(context *WriteContext, princ Node, lockable Lockable, require } // Must be called withing update context -func checkIfRequirement(context *WriteContext, r Lockable, cur Lockable) bool { +func checkIfRequirement(context *StateContext, r Lockable, cur Lockable) bool { for _, c := range(cur.Requirements()) { if c.ID() == r.ID() { return true } is_requirement := false - UpdateMoreStates(context, cur, NewLockMap(NewLockInfo(c, []string{"requirements"})), func(context *WriteContext) error { + UpdateStates(context, cur, NewLockMap(NewLockInfo(c, []string{"requirements"})), func(context *StateContext) error { is_requirement = checkIfRequirement(context, cur, c) return nil }) @@ -386,7 +393,7 @@ func checkIfRequirement(context *WriteContext, r Lockable, cur Lockable) bool { // Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error // Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call -func LockLockables(context *WriteContext, to_lock []Lockable, new_owner Lockable) error { +func LockLockables(context *StateContext, to_lock []Lockable, new_owner Lockable) error { if to_lock == nil { return fmt.Errorf("LOCKABLE_LOCK_ERR: no list provided") } @@ -406,10 +413,10 @@ func LockLockables(context *WriteContext, to_lock []Lockable, new_owner Lockable return nil } - return UpdateMoreStates(context, new_owner, NewLockMap( + return UpdateStates(context, new_owner, NewLockMap( LockList(to_lock, []string{"lock"}), NewLockInfo(new_owner, []string{}), - ), func(context *WriteContext) error { + ), func(context *StateContext) error { // First loop is to check that the states can be locked, and locks all requirements for _, req := range(to_lock) { context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID()) @@ -426,7 +433,7 @@ func LockLockables(context *WriteContext, to_lock []Lockable, new_owner Lockable if owner.ID() == new_owner.ID() { continue } else { - err := UpdateMoreStates(context, new_owner, NewLockMap(NewLockInfo(owner, []string{"take_lock"})), func(context *WriteContext)(error){ + err := UpdateStates(context, new_owner, NewLockMap(NewLockInfo(owner, []string{"take_lock"})), func(context *StateContext)(error){ return LockLockables(context, req.Requirements(), req) }) if err != nil { @@ -464,7 +471,7 @@ func LockLockables(context *WriteContext, to_lock []Lockable, new_owner Lockable } -func UnlockLockables(context *WriteContext, to_unlock []Lockable, old_owner Lockable) error { +func UnlockLockables(context *StateContext, to_unlock []Lockable, old_owner Lockable) error { if to_unlock == nil { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided") } @@ -484,10 +491,10 @@ func UnlockLockables(context *WriteContext, to_unlock []Lockable, old_owner Lock return nil } - return UpdateMoreStates(context, old_owner, NewLockMap( + return UpdateStates(context, old_owner, NewLockMap( LockList(to_unlock, []string{"lock"}), NewLockInfo(old_owner, []string{}), - ), func(context *WriteContext) error { + ), func(context *StateContext) error { // First loop is to check that the states can be locked, and locks all requirements for _, req := range(to_unlock) { context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID()) diff --git a/node.go b/node.go index 8625c22..fc95251 100644 --- a/node.go +++ b/node.go @@ -74,7 +74,7 @@ type Node interface { RemovePolicy(Policy) error // Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate - Signal(context *ReadContext, signal GraphSignal) error + Signal(context *StateContext, signal GraphSignal) error // Register a channel to receive updates sent to the node RegisterChannel(id NodeID, listener chan GraphSignal) // Unregister a channel from receiving updates sent to the node @@ -193,7 +193,7 @@ func (node * GraphNode) Type() NodeType { // Propagate the signal to registered listeners, if a listener isn't ready to receive the update // send it a notification that it was closed and then close it -func (node * GraphNode) Signal(context *ReadContext, signal GraphSignal) error { +func (node * GraphNode) Signal(context *StateContext, signal GraphSignal) error { context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String()) node.listeners_lock.Lock() defer node.listeners_lock.Unlock() @@ -293,20 +293,18 @@ func getNodeBytes(node Node) ([]byte, error) { } // Write multiple nodes to the database in a single transaction -func WriteNodes(context *WriteContext) error { - if context == nil { - return fmt.Errorf("Cannot write nil to DB") - } - if context.Locked == nil { - return fmt.Errorf("Cannot write nil map to DB") +func WriteNodes(context *StateContext) error { + err := ValidateStateContext(context, "write", true) + if err != nil { + return err } + context.Graph.Log.Logf("db", "DB_WRITES: %d", len(context.Locked)) serialized_bytes := make([][]byte, len(context.Locked)) serialized_ids := make([][]byte, len(context.Locked)) i := 0 - for _, lock := range(context.Locked) { - node := lock.Node + for _, node := range(context.Locked) { node_bytes, err := getNodeBytes(node) context.Graph.Log.Logf("db", "DB_WRITE: %+v", node) if err != nil { @@ -321,7 +319,7 @@ func WriteNodes(context *WriteContext) error { i++ } - err := context.Graph.DB.Update(func(txn *badger.Txn) error { + return context.Graph.DB.Update(func(txn *badger.Txn) error { for i, id := range(serialized_ids) { err := txn.Set(id, serialized_bytes[i]) if err != nil { @@ -330,8 +328,6 @@ func WriteNodes(context *WriteContext) error { } return nil }) - - return err } // Get the bytes associates with `id` from the database after unwrapping the header, or error @@ -450,17 +446,54 @@ type LockInfo struct { type LockMap map[NodeID]LockInfo -type ReadContext struct { +type StateContext struct { + Type string Graph *Context - Locked LockMap + Permissions map[NodeID]LockMap + Locked NodeMap + Started bool + Finished bool } -type ReadFn func(*ReadContext)(error) -type WriteContext struct { - Graph *Context - Locked LockMap +func ValidateStateContext(context *StateContext, Type string, Finished bool) error { + if context == nil { + return fmt.Errorf("context is nil") + } + if context.Finished != Finished { + return fmt.Errorf("context in wrong Finished state") + } + if context.Type != Type { + return fmt.Errorf("%s is not a %s context", context.Type, Type) + } + if context.Locked == nil || context.Graph == nil || context.Permissions == nil { + return fmt.Errorf("context is not initialized correctly") + } + return nil +} + +func NewReadContext(ctx *Context) *StateContext { + return &StateContext{ + Type: "read", + Graph: ctx, + Permissions: map[NodeID]LockMap{}, + Locked: NodeMap{}, + Started: false, + Finished: false, + } +} + +func NewWriteContext(ctx *Context) *StateContext { + return &StateContext{ + Type: "write", + Graph: ctx, + Permissions: map[NodeID]LockMap{}, + Locked: NodeMap{}, + Started: false, + Finished: false, + } } -type WriteFn func(*WriteContext)(error) + +type StateFn func(*StateContext)(error) func del[K comparable](list []K, val K) []K { idx := -1 @@ -478,146 +511,211 @@ func del[K comparable](list []K, val K) []K { return list[:len(list)-1] } -// Start a read context for node under ctx for the resources specified in init_nodes, then run nodes_fn -func UseStates(ctx *Context, node Node, nodes LockMap, read_fn ReadFn) error { - context := &ReadContext{ - Graph: ctx, - Locked: LockMap{}, - } - return UseMoreStates(context, node, nodes, read_fn) -} - // Add nodes to an existing read context and call nodes_fn with new_nodes locked for read // Check that the node has read permissions for the nodes, then add them to the read context and call nodes_fn with the nodes locked for read -func UseMoreStates(context *ReadContext, node Node, new_nodes LockMap, read_fn ReadFn) error { - locked_nodes := []Node{} +func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn StateFn) error { + if princ == nil || new_nodes == nil || state_fn == nil { + return fmt.Errorf("nil passed to UseStates") + } + + err := ValidateStateContext(context, "read", false) + if err != nil { + return err + } + + final := false + if context.Started == false { + context.Started = true + final = true + } + + new_locks := []Node{} + _, princ_locked := context.Locked[princ.ID()] + if princ_locked == false { + new_locks = append(new_locks, princ) + context.Graph.Log.Logf("mutex", "RLOCKING_PRINC %s", princ.ID().String()) + princ.RLock() + } + + princ_permissions, princ_exists := context.Permissions[princ.ID()] new_permissions := LockMap{} - for _, request := range(new_nodes) { - id := request.Node.ID() - new_permissions[id] = LockInfo{Node: request.Node, Resources: []string{}} - for _, resource := range(request.Resources) { - // If the permission for this resource is already granted, continue - current_permissions, exists := context.Locked[id] - if exists == true { - already_granted := false - for _, r := range(current_permissions.Resources) { - if r == resource { - already_granted = true - break - } - } - if already_granted == true { - continue - } - } + if princ_exists == true { + for id, info := range(princ_permissions) { + new_permissions[id] = info + } + } - err := request.Node.Allowed("read", resource, node) - if err != nil { - return err + for _, request := range(new_nodes) { + node := request.Node + if node == nil { + return fmt.Errorf("node in request list is nil") + } + id := node.ID() + + if id != princ.ID() { + _, locked := context.Locked[id] + if locked == false { + new_locks = append(new_locks, node) + context.Graph.Log.Logf("mutex", "RLOCKING %s", id.String()) + node.RLock() } + } - tmp := new_permissions[id] - tmp.Resources = append(tmp.Resources, resource) - new_permissions[id] = tmp + node_permissions, node_exists := new_permissions[id] + if node_exists == false { + node_permissions = LockInfo{Node: node, Resources: []string{}} } + for _, resource := range(request.Resources) { + already_granted := false + for _, granted := range(node_permissions.Resources) { + if resource == granted { + already_granted = true + } + } - req_perms, exists := new_permissions[id] - if exists == true { - cur_perms, already_locked := context.Locked[id] - if already_locked == false { - request.Node.RLock() - context.Locked[id] = req_perms - locked_nodes = append(locked_nodes, request.Node) + if already_granted == false { + err := node.Allowed("read", resource, princ) + if err != nil { + context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.write", id.String(), resource) + for _, n := range(new_locks) { + context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String()) + n.RUnlock() + } + return err + } + context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.write", id.String(), resource) } else { - cur_perms.Resources = append(cur_perms.Resources, req_perms.Resources...) + context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.write", id.String(), resource) } } + new_permissions[id] = node_permissions } - err := read_fn(context) - - for _, request := range(new_permissions) { - cur_perms := context.Locked[request.Node.ID()] - new_perms := cur_perms.Resources - for _, resource := range(cur_perms.Resources) { - new_perms = del(new_perms, resource) - } - cur_perms.Resources = new_perms - context.Locked[request.Node.ID()] = cur_perms + for _, node := range(new_locks) { + context.Locked[node.ID()] = node } - for _, node := range(locked_nodes) { + context.Permissions[princ.ID()] = new_permissions + + err = state_fn(context) + + context.Permissions[princ.ID()] = princ_permissions + + for _, node := range(new_locks) { + context.Graph.Log.Logf("mutex", "RUNLOCKING %s", node.ID().String()) delete(context.Locked, node.ID()) node.RUnlock() } + if final == true { + context.Finished = true + } + return err } -// Initiate a write context for nodes and call nodes_fn with nodes locked for read -func UpdateStates(ctx *Context, node Node, nodes LockMap, write_fn WriteFn) error { - context := &WriteContext{ - Graph: ctx, - Locked: LockMap{}, +// Add nodes to an existing write context and call nodes_fn with nodes locked for read +// If context is nil +func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn StateFn) error { + if princ == nil || new_nodes == nil || state_fn == nil { + return fmt.Errorf("nil passed to UpdateStates") } - err := UpdateMoreStates(context, node, nodes, write_fn) - if err == nil { - err = WriteNodes(context) + + err := ValidateStateContext(context, "write", false) + if err != nil { + return err } - for _, lock := range(context.Locked) { - lock.Node.Unlock() + final := false + if context.Started == false { + context.Started = true + final = true } - return err -} + new_locks := []Node{} + _, princ_locked := context.Locked[princ.ID()] + if princ_locked == false { + new_locks = append(new_locks, princ) + context.Graph.Log.Logf("mutex", "LOCKING_PRINC %s", princ.ID().String()) + princ.Lock() + } -// Add nodes to an existing write context and call nodes_fn with nodes locked for read -func UpdateMoreStates(context *WriteContext, node Node, new_nodes LockMap, write_fn WriteFn) error { + princ_permissions, princ_exists := context.Permissions[princ.ID()] new_permissions := LockMap{} + if princ_exists == true { + for id, info := range(princ_permissions) { + new_permissions[id] = info + } + } + for _, request := range(new_nodes) { - id := request.Node.ID() - new_permissions[id] = LockInfo{Node: request.Node, Resources: []string{}} + node := request.Node + if node == nil { + return fmt.Errorf("node in request list is nil") + } + id := node.ID() + + if id != princ.ID() { + _, locked := context.Locked[id] + if locked == false { + new_locks = append(new_locks, node) + context.Graph.Log.Logf("mutex", "LOCKING %s", id.String()) + node.Lock() + } + } + + node_permissions, node_exists := new_permissions[id] + if node_exists == false { + node_permissions = LockInfo{Node: node, Resources: []string{}} + } + for _, resource := range(request.Resources) { - current_permissions, exists := context.Locked[id] - if exists == true { - already_granted := false - for _, r := range(current_permissions.Resources) { - if r == resource { - already_granted = true - break - } - } - if already_granted == true { - continue + already_granted := false + for _, granted := range(node_permissions.Resources) { + if resource == granted { + already_granted = true } } - err := request.Node.Allowed("write", resource, node) - if err != nil { - return err + if already_granted == false { + err := node.Allowed("write", resource, princ) + if err != nil { + context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.write", id.String(), resource) + for _, n := range(new_locks) { + context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String()) + n.Unlock() + } + return err + } + context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.write", id.String(), resource) + } else { + context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.write", id.String(), resource) } - - tmp := new_permissions[id] - tmp.Resources = append(tmp.Resources, resource) - new_permissions[id] = tmp } + new_permissions[id] = node_permissions + } - req_perms, exists := new_permissions[id] - if exists == true { - cur_perms, already_locked := context.Locked[id] - if already_locked == false { - request.Node.Lock() - context.Locked[id] = req_perms - } else { - cur_perms.Resources = append(cur_perms.Resources, req_perms.Resources...) - context.Locked[id] = cur_perms - } + for _, node := range(new_locks) { + context.Locked[node.ID()] = node + } + + context.Permissions[princ.ID()] = new_permissions + + err = state_fn(context) + + if final == true { + context.Finished = true + if err == nil { + err = WriteNodes(context) + } + for id, node := range(context.Locked) { + context.Graph.Log.Logf("mutex", "UNLOCKING %s", id.String()) + node.Unlock() } } - return write_fn(context) + return err } // Create a new channel with a buffer the size of buffer, and register it to node with the id diff --git a/thread.go b/thread.go index 4959dd2..3328f7f 100644 --- a/thread.go +++ b/thread.go @@ -10,7 +10,7 @@ import ( ) // Assumed that thread is already locked for signal -func (thread *SimpleThread) Signal(context *ReadContext, signal GraphSignal) error { +func (thread *SimpleThread) Signal(context *StateContext, signal GraphSignal) error { err := thread.SimpleLockable.Signal(context, signal) if err != nil { return err @@ -18,9 +18,9 @@ func (thread *SimpleThread) Signal(context *ReadContext, signal GraphSignal) err switch signal.Direction() { case Up: - err = UseMoreStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *ReadContext) error { + err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error { if thread.parent != nil { - return UseMoreStates(context, thread, NewLockInfo(thread.parent, []string{"signal"}), func(context *ReadContext) error { + return UseStates(context, thread, NewLockInfo(thread.parent, []string{"signal"}), func(context *StateContext) error { return thread.parent.Signal(context, signal) }) } else { @@ -28,10 +28,10 @@ func (thread *SimpleThread) Signal(context *ReadContext, signal GraphSignal) err } }) case Down: - err = UseMoreStates(context, thread, NewLockMap( + err = UseStates(context, thread, NewLockMap( NewLockInfo(thread, []string{"children"}), LockList(thread.children, []string{"signal"}), - ), func(context *ReadContext) error { + ), func(context *StateContext) error { for _, child := range(thread.children) { err := child.Signal(context, signal) if err != nil { @@ -169,15 +169,15 @@ func (thread * SimpleThread) AddChild(child Thread, info ThreadInfo) error { return nil } -func checkIfChild(context *WriteContext, target Thread, cur Thread) bool { +func checkIfChild(context *StateContext, target Thread, cur Thread) bool { for _, child := range(cur.Children()) { if child.ID() == target.ID() { return true } is_child := false - UpdateMoreStates(context, cur, NewLockMap( + UpdateStates(context, cur, NewLockMap( NewLockInfo(child, []string{"children"}), - ), func(context *WriteContext) error { + ), func(context *StateContext) error { is_child = checkIfChild(context, target, child) return nil }) @@ -189,7 +189,9 @@ func checkIfChild(context *WriteContext, target Thread, cur Thread) bool { return false } -func LinkThreads(context *WriteContext, princ Node, thread Thread, child Thread, info ThreadInfo) error { +// Links child to parent with info as the associated info +// Continues the write context with princ, getting children for thread and parent for child +func LinkThreads(context *StateContext, princ Node, thread Thread, child Thread, info ThreadInfo) error { if context == nil || thread == nil || child == nil { return fmt.Errorf("invalid input") } @@ -198,7 +200,10 @@ func LinkThreads(context *WriteContext, princ Node, thread Thread, child Thread, return fmt.Errorf("Will not link %s as a child of itself", thread.ID()) } - return UpdateMoreStates(context, princ, LockList([]Node{child, thread}, []string{"parent", "children"}), func(context *WriteContext) error { + return UpdateStates(context, princ, LockMap{ + child.ID(): LockInfo{Node: child, Resources: []string{"parent"}}, + thread.ID(): LockInfo{Node: thread, Resources: []string{"children"}}, + }, func(context *StateContext) error { if child.Parent() != nil { return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID()) } @@ -449,7 +454,7 @@ func NewSimpleThread(id NodeID, name string, state_name string, info_type reflec } // Requires the read permission of threads children -func FindChild(context *ReadContext, princ Node, thread Thread, id NodeID) Thread { +func FindChild(context *StateContext, princ Node, thread Thread, id NodeID) Thread { if thread == nil { panic("cannot recurse through nil") } @@ -459,7 +464,7 @@ func FindChild(context *ReadContext, princ Node, thread Thread, id NodeID) Threa for _, child := range thread.Children() { var result Thread = nil - UseMoreStates(context, princ, NewLockInfo(child, []string{"children"}), func(context *ReadContext) error { + UseStates(context, princ, NewLockInfo(child, []string{"children"}), func(context *StateContext) error { result = FindChild(context, princ, child, id) return nil }) @@ -485,7 +490,7 @@ func ChildGo(ctx * Context, thread Thread, child Thread, first_action string) { }(child) } -// Main Loop for Threads +// Main Loop for Threads, starts a write context, so cannot be called from a write or read context func ThreadLoop(ctx * Context, thread Thread, first_action string) error { // Start the thread, error if double-started ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID(), first_action) @@ -515,18 +520,8 @@ func ThreadLoop(ctx * Context, thread Thread, first_action string) error { return err } - err = UpdateStates(ctx, thread, NewLockInfo(thread, []string{"state"}), func(context *WriteContext) error { - err := thread.SetState("finished") - if err != nil { - return err - } - return UnlockLockables(context, []Lockable{thread}, thread) - }) - if err != nil { - ctx.Log.Logf("thread", "THREAD_LOOP_UNLOCK_ERR: %e", err) - return err - } + ctx.Log.Logf("thread", "THREAD_LOOP_DONE: %s", thread.ID()) @@ -578,13 +573,16 @@ func (thread * SimpleThread) AllowedToTakeLock(new_owner Lockable, lockable Lock return false } +// Helper function to start a child from a thread during a signal handler +// Starts a write context, so cannot be called from either a write or read context func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) error { - return UpdateStates(ctx, thread, NewLockInfo(thread, []string{"children"}), func(context *WriteContext) error { + context := NewWriteContext(ctx) + return UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error { child := thread.Child(signal.ID) if child == nil { return fmt.Errorf("%s is not a child of %s", signal.ID, thread.ID()) } - return UpdateMoreStates(context, thread, NewLockInfo(child, []string{"start"}), func(context *WriteContext) error { + return UpdateStates(context, thread, NewLockInfo(child, []string{"start"}), func(context *StateContext) error { info := thread.ChildInfo(signal.ID).(*ParentThreadInfo) info.Start = true @@ -595,11 +593,14 @@ func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) erro }) } +// Helper function to restore threads that should be running from a parents restore action +// Starts a write context, so cannot be called from either a write or read context func ThreadRestore(ctx * Context, thread Thread) { - UpdateStates(ctx, thread, NewLockMap( + context := NewWriteContext(ctx) + UpdateStates(context, thread, NewLockMap( NewLockInfo(thread, []string{"children"}), LockList(thread.Children(), []string{"start"}), - ), func(context *WriteContext)(error) { + ), func(context *StateContext)(error) { for _, child := range(thread.Children()) { should_run := (thread.ChildInfo(child.ID())).(ParentInfo).Parent() ctx.Log.Logf("thread", "THREAD_RESTORE: %s -> %s: %+v", thread.ID(), child.ID(), should_run) @@ -612,17 +613,14 @@ func ThreadRestore(ctx * Context, thread Thread) { }) } +// Helper function to be called during a threads start action, sets the thread state to started +// Starts a write context, so cannot be called from either a write or read context func ThreadStart(ctx * Context, thread Thread) error { - return UpdateStates(ctx, thread, NewLockInfo(thread, []string{"start", "lock"}), func(context *WriteContext) error { - owner_id := NodeID{} - if thread.Owner() != nil { - owner_id = thread.Owner().ID() - } - if owner_id != thread.ID() { - err := LockLockables(context, []Lockable{thread}, thread) - if err != nil { - return err - } + context := NewWriteContext(ctx) + return UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error { + err := LockLockables(context, []Lockable{thread}, thread) + if err != nil { + return err } return thread.SetState("started") }) @@ -657,7 +655,8 @@ func ThreadWait(ctx * Context, thread Thread) (string, error) { } case <- thread.Timeout(): timeout_action := "" - err := UpdateStates(ctx, thread, NewLockMap(NewLockInfo(thread, []string{"timeout"})), func(context *WriteContext) error { + context := NewWriteContext(ctx) + err := UpdateStates(context, thread, NewLockMap(NewLockInfo(thread, []string{"timeout"})), func(context *StateContext) error { timeout_action = thread.TimeoutAction() thread.ClearTimeout() return nil @@ -671,27 +670,46 @@ func ThreadWait(ctx * Context, thread Thread) (string, error) { } } +func ThreadDefaultFinish(ctx *Context, thread Thread) (string, error) { + ctx.Log.Logf("thread", "THREAD_DEFAULT_FINISH: %s", thread.ID().String()) + return "", ThreadFinish(ctx, thread) +} + +func ThreadFinish(ctx *Context, thread Thread) error { + context := NewWriteContext(ctx) + return UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error { + err := thread.SetState("finished") + if err != nil { + return err + } + return UnlockLockables(context, []Lockable{thread}, thread) + }) +} + var ThreadAbortedError = errors.New("Thread aborted by signal") -// Default thread abort is to return a ThreadAbortedError +// Default thread action function for "abort", sends a signal and returns a ThreadAbortedError func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) { - err := UseStates(ctx, thread, NewLockInfo(thread, []string{"signal"}), func(context *ReadContext) error { + context := NewReadContext(ctx) + err := UseStates(context, thread, NewLockInfo(thread, []string{"signal"}), func(context *StateContext) error { return thread.Signal(context, NewStatusSignal("aborted", thread.ID())) }) if err != nil { return "", err } - return "", ThreadAbortedError + return "finish", ThreadAbortedError } -// Default thread cancel is to finish the thread +// Default thread action for "cancel", sends a signal and returns no error func ThreadCancel(ctx * Context, thread Thread, signal GraphSignal) (string, error) { - err := UseStates(ctx, thread, NewLockInfo(thread, []string{"signal"}), func(context *ReadContext) error { + context := NewReadContext(ctx) + err := UseStates(context, thread, NewLockInfo(thread, []string{"signal"}), func(context *StateContext) error { return thread.Signal(context, NewSignal("cancelled")) }) - return "", err + return "finish", err } +// Copy the default thread actions to a new ThreadActions map func NewThreadActions() ThreadActions{ actions := ThreadActions{} for k, v := range(BaseThreadActions) { @@ -701,6 +719,7 @@ func NewThreadActions() ThreadActions{ return actions } +// Copy the defult thread handlers to a new ThreadAction map func NewThreadHandlers() ThreadHandlers{ handlers := ThreadHandlers{} for k, v := range(BaseThreadHandlers) { @@ -710,12 +729,15 @@ func NewThreadHandlers() ThreadHandlers{ return handlers } +// Default thread actions var BaseThreadActions = ThreadActions{ "wait": ThreadWait, "start": ThreadDefaultStart, + "finish": ThreadDefaultFinish, "restore": ThreadDefaultRestore, } +// Default thread signal handlers var BaseThreadHandlers = ThreadHandlers{ "abort": ThreadAbort, "cancel": ThreadCancel,