From e347a3f2322b2d5ed8f1ef00d411047aa4d24150 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sat, 22 Jul 2023 20:21:17 -0600 Subject: [PATCH] start maniacal rewrite, main goal is to combine node and lockable to remove any sync mutex deadlocks. Another goal is to make read contexts get copies of the state to ensure they don't modify and no lock is required to ensure no value changes, and write contexts use the lockable locks instead of mutex --- context.go | 2 +- gql.go | 56 +++++- gql_mutation.go | 63 +++--- gql_query.go | 4 +- gql_resolvers.go | 46 ++--- gql_test.go | 93 +++++---- gql_types.go | 23 +-- graph_test.go | 13 +- lockable.go | 314 +++++++++++++++--------------- lockable_test.go | 495 ----------------------------------------------- node.go | 284 +++++++++++++++++++++------ policy.go | 16 +- signal.go | 53 +++-- thread.go | 177 +++++++++-------- thread_test.go | 122 ------------ 15 files changed, 656 insertions(+), 1105 deletions(-) delete mode 100644 lockable_test.go delete mode 100644 thread_test.go diff --git a/context.go b/context.go index 2a130b3..d6fd5fa 100644 --- a/context.go +++ b/context.go @@ -212,7 +212,7 @@ func NewContext(db * badger.DB, log Logger) * Context { ctx.GQL.Subscription.AddFieldConfig("Update", GQLSubscriptionUpdate) ctx.GQL.Subscription.AddFieldConfig("Self", GQLSubscriptionSelf) - ctx.GQL.Mutation.AddFieldConfig("sendUpdate", GQLMutationSendUpdate) + ctx.GQL.Mutation.AddFieldConfig("abort", GQLMutationAbort) ctx.GQL.Mutation.AddFieldConfig("startChild", GQLMutationStartChild) err = ctx.RebuildSchema() diff --git a/gql.go b/gql.go index 319dc97..a38cf15 100644 --- a/gql.go +++ b/gql.go @@ -204,7 +204,16 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt ctx.Log.Logf("gql", "AUTHORIZING NEW USER %s - %s", key_id, shared) new_user := NewUser(fmt.Sprintf("GQL_USER %s", key_id.String()), time.Now(), remote_id, shared, []string{"gql"}) - err := UpdateStates(ctx, []Node{server, &new_user}, func(nodes NodeMap) error { + err := UpdateStates(ctx, server, NewLockMap(LockList{ + LockInfo{ + Node: server, + Resources: []string{"users"}, + }, + LockInfo{ + Node: &new_user, + Resources: []string{""}, + }, + }), func(context *WriteContext) error { server.Users[key_id] = &new_user return nil }) @@ -864,30 +873,59 @@ var gql_actions ThreadActions = ThreadActions{ }(server) - UseStates(ctx, []Node{server}, func(nodes NodeMap)(error){ + err = UpdateStates(ctx, server, NewLockMap( + NewLockInfo(server, []string{"http_server"}), + ), func(context *WriteContext) error { server.tcp_listener = listener server.http_server = http_server - return server.Signal(ctx, NewSignal(server, "server_started"), nodes) + return nil + }) + + if err != nil { + return "", err + } + + err = UseStates(ctx, server, NewLockMap( + NewLockInfo(server, []string{"signal"}), + ), func(context *ReadContext) error { + return server.Signal(context, NewSignal("server_started")) }) + if err != nil { + return "", err + } + return "wait", nil }, } var gql_handlers ThreadHandlers = ThreadHandlers{ - "child_added": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) { + "child_linked": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) { ctx.Log.Logf("gql", "GQL_THREAD_CHILD_ADDED: %+v", signal) - UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { - should_run, exists := thread.ChildInfo(signal.Source()).(*ParentThreadInfo) + err := UpdateStates(ctx, thread, NewLockMap( + NewLockInfo(thread, []string{"children"}), + ), func(context *WriteContext) error { + sig, ok := signal.(IDSignal) + if ok == false { + ctx.Log.Logf("gql", "GQL_THREAD_NODE_LINKED_BAD_CAST") + return nil + } + should_run, exists := thread.ChildInfo(sig.ID).(*ParentThreadInfo) if exists == false { - ctx.Log.Logf("gql", "GQL_THREAD_CHILD_ADDED: tried to start %s whis is not a child") + ctx.Log.Logf("gql", "GQL_THREAD_NODE_LINKED: %s is not a child of %s", sig.ID) return nil } if should_run.Start == true { - ChildGo(ctx, thread, thread.Child(signal.Source()), should_run.StartAction) + ChildGo(ctx, thread, thread.Child(sig.ID), should_run.StartAction) } return nil }) + + if err != nil { + + } else { + + } return "wait", nil }, "start_child": func(ctx *Context, thread Thread, signal GraphSignal) (string, error) { @@ -902,7 +940,7 @@ var gql_handlers ThreadHandlers = ThreadHandlers{ if err != nil { ctx.Log.Logf("gql", "GQL_START_CHILD_ERR: %s", err) } else { - ctx.Log.Logf("gql", "GQL_START_CHILD: %s", sig.ChildID.String()) + ctx.Log.Logf("gql", "GQL_START_CHILD: %s", sig.ID.String()) } return "wait", nil diff --git a/gql_mutation.go b/gql_mutation.go index 24f0974..b85f261 100644 --- a/gql_mutation.go +++ b/gql_mutation.go @@ -4,16 +4,13 @@ import ( "github.com/graphql-go/graphql" ) -var GQLMutationSendUpdate = NewField(func()*graphql.Field { - gql_mutation_send_update := &graphql.Field{ +var GQLMutationAbort = NewField(func()*graphql.Field { + gql_mutation_abort := &graphql.Field{ Type: GQLTypeSignal.Type, Args: graphql.FieldConfigArgument{ "id": &graphql.ArgumentConfig{ Type: graphql.String, }, - "signal": &graphql.ArgumentConfig{ - Type: GQLTypeSignalInput.Type, - }, }, Resolve: func(p graphql.ResolveParams) (interface{}, error) { ctx, err := PrepResolve(p) @@ -21,50 +18,37 @@ var GQLMutationSendUpdate = NewField(func()*graphql.Field { return nil, err } - err = ctx.Server.Allowed("signal", "self", ctx.User) - if err != nil { - return nil, err - } - - signal_map, err := ExtractParam[map[string]interface{}](p, "signal") + err = ctx.Server.Allowed("signal", "", ctx.User) if err != nil { return nil, err } - var signal GraphSignal = nil - if signal_map["Direction"] == "up" { - signal = NewSignal(ctx.Server, signal_map["Type"].(string)) - } else if signal_map["Direction"] == "down" { - signal = NewDownSignal(ctx.Server, signal_map["Type"].(string)) - } else if signal_map["Direction"] == "direct" { - signal = NewDirectSignal(ctx.Server, signal_map["Type"].(string)) - } else { - return nil, fmt.Errorf("Bad direction: %d", signal_map["Direction"]) - } - id, err := ExtractID(p, "id") if err != nil { return nil, err } var node Node = nil - err = UseStates(ctx.Context, []Node{ctx.Server}, func(nodes NodeMap) (error){ - node = FindChild(ctx.Context, ctx.Server, id, nodes) + err = UseStates(ctx.Context, ctx.User, NewLockMap( + NewLockInfo(ctx.Server, []string{"children"}), + ), func(context *ReadContext) (error){ + node = FindChild(ctx.Context, ctx.User, ctx.Server, id, locked) if node == nil { return fmt.Errorf("Failed to find ID: %s as child of server thread", id) } - node.Signal(ctx.Context, signal, nodes) - return nil + return UseMoreStates(ctx.Context, locked, ctx.User, NewLockInfo(node, []string{"signal"}), func(locked NodeLockMap) error { + return node.Signal(ctx.Context, AbortSignal, locked) + }) }) if err != nil { return nil, err } - return signal, nil + return AbortSignal, nil }, } - return gql_mutation_send_update + return gql_mutation_abort }) var GQLMutationStartChild = NewField(func()*graphql.Field{ @@ -88,11 +72,6 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{ return nil, err } - err = ctx.Server.Allowed("start_child", "self", ctx.User) - if err != nil { - return nil, err - } - parent_id, err := ExtractID(p, "parent_id") if err != nil { return nil, err @@ -109,14 +88,22 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{ } var signal GraphSignal - err = UseStates(ctx.Context, []Node{ctx.Server}, func(nodes NodeMap) (error){ - node := FindChild(ctx.Context, ctx.Server, parent_id, nodes) + err = UseStates(ctx.Context, ctx.User, NewLockMap( + NewLockInfo(ctx.Server, []string{"children"}), + ), func(context *ReadContext) error { + node := FindChild(ctx.Context, ctx.User, ctx.Server, parent_id, locked) if node == nil { return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id) } - return UseMoreStates(ctx.Context, []Node{node}, nodes, func(NodeMap) error { - signal = NewStartChildSignal(ctx.Server, child_id, action) - return node.Signal(ctx.Context, signal, nodes) + + err := node.Allowed("signal", "", ctx.User) + if err != nil { + return err + } + + return UseMoreStates(ctx.Context, locked, ctx.User, NewLockInfo(node, []string{"start_child", "signal"}), func(locked NodeLockMap) error { + signal = NewStartChildSignal(child_id, action) + return node.Signal(ctx.Context, signal, locked) }) }) if err != nil { diff --git a/gql_query.go b/gql_query.go index ca8ea27..4593e10 100644 --- a/gql_query.go +++ b/gql_query.go @@ -11,7 +11,7 @@ var GQLQuerySelf = &graphql.Field{ return nil, err } - err = ctx.Server.Allowed("enumerate", "self", ctx.User) + err = ctx.Server.Allowed("read", "", ctx.User) if err != nil { return nil, err } @@ -28,7 +28,7 @@ var GQLQueryUser = &graphql.Field{ return nil, err } - err = ctx.User.Allowed("enumerate", "self", ctx.User) + err = ctx.User.Allowed("read", "", ctx.User) if err != nil { return nil, err } diff --git a/gql_resolvers.go b/gql_resolvers.go index 4c57376..a8d897f 100644 --- a/gql_resolvers.go +++ b/gql_resolvers.go @@ -53,8 +53,8 @@ func GQLNodeID(p graphql.ResolveParams) (interface{}, error) { return nil, fmt.Errorf("Failed to cast source to Node") } - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) error { - return node.Allowed("read", "id", ctx.User) + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"id"}), func(locked NodeLockMap) error { + return nil }) if err != nil { return nil, err @@ -76,9 +76,9 @@ func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) { listen := "" - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) error { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"listen"}), func(locked NodeLockMap) error { listen = node.Listen - return node.Allowed("read", "listen", ctx.User) + return nil }) if err != nil { @@ -100,9 +100,9 @@ func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) { } var parent Thread = nil - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) (error) { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"parent"}), func(locked NodeLockMap) error { parent = node.Parent() - return node.Allowed("read", "parent", ctx.User) + return nil }) if err != nil { @@ -124,9 +124,9 @@ func GQLThreadState(p graphql.ResolveParams) (interface{}, error) { } var state string - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) (error) { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"state"}), func(locked NodeLockMap) error { state = node.State() - return node.Allowed("read", "state", ctx.User) + return nil }) if err != nil { @@ -148,9 +148,9 @@ func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) { } var children []Thread = nil - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) (error) { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"children"}), func(locked NodeLockMap) error { children = node.Children() - return node.Allowed("read", "children", ctx.User) + return nil }) if err != nil { @@ -172,9 +172,9 @@ func GQLLockableName(p graphql.ResolveParams) (interface{}, error) { } name := "" - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) error { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"name"}), func(locked NodeLockMap) error { name = node.Name() - return node.Allowed("read", "name", ctx.User) + return nil }) if err != nil { @@ -196,9 +196,9 @@ func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) { } var requirements []Lockable = nil - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) (error) { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"requirements"}), func(locked NodeLockMap) error { requirements = node.Requirements() - return node.Allowed("read", "requirements", ctx.User) + return nil }) if err != nil { @@ -220,9 +220,9 @@ func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) { } var dependencies []Lockable = nil - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) (error) { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"dependencies"}), func(locked NodeLockMap) error { dependencies = node.Dependencies() - return node.Allowed("read", "dependencies", ctx.User) + return nil }) if err != nil { @@ -244,9 +244,9 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) { } var owner Node = nil - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) (error) { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"owner"}), func(locked NodeLockMap) error { owner = node.Owner() - return node.Allowed("read", "owner", ctx.User) + return nil }) if err != nil { @@ -268,14 +268,14 @@ func GQLThreadUsers(p graphql.ResolveParams) (interface{}, error) { } var users []*User - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) error { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"users"}), func(locked NodeLockMap) error { users = make([]*User, len(node.Users)) i := 0 for _, user := range(node.Users) { users[i] = user i += 1 } - return node.Allowed("read", "users", ctx.User) + return nil }) if err != nil { @@ -298,12 +298,6 @@ func GQLSignalType(p graphql.ResolveParams) (interface{}, error) { }) } -func GQLSignalSource(p graphql.ResolveParams) (interface{}, error) { - return GQLSignalFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error){ - return signal.Source(), nil - }) -} - func GQLSignalDirection(p graphql.ResolveParams) (interface{}, error) { return GQLSignalFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error){ direction := signal.Direction() diff --git a/gql_test.go b/gql_test.go index 7a92aaf..385b07b 100644 --- a/gql_test.go +++ b/gql_test.go @@ -31,22 +31,24 @@ func TestGQLThread(t * testing.T) { t2_r := NewSimpleThread(RandID(), "Test thread 2", "init", nil, BaseThreadActions, BaseThreadHandlers) t2 := &t2_r - err = UpdateStates(ctx, []Node{gql_t, t1, t2}, func(nodes NodeMap) error { - i1 := NewParentThreadInfo(true, "start", "restore") - err := LinkThreads(ctx, gql_t, t1, &i1, nodes) - if err != nil { - return err - } - - i2 := NewParentThreadInfo(false, "start", "restore") - return LinkThreads(ctx, gql_t, t2, &i2, nodes) + err = UpdateStates(ctx, gql_t, RequestList([]Node{t1, t2}, []string{"parent"}), func(locked NodeLockMap) error { + return UpdateMoreStates(ctx, locked, gql_t, NewLockRequest(gql_t, []string{"children"}), func(locked NodeLockMap) error { + i1 := NewParentThreadInfo(true, "start", "restore") + err := LinkThreads(ctx, gql_t, t1, &i1, locked) + if err != nil { + return err + } + + i2 := NewParentThreadInfo(false, "start", "restore") + return LinkThreads(ctx, gql_t, t2, &i2, locked) + }) }) fatalErr(t, err) go func(thread Thread){ time.Sleep(10*time.Millisecond) - err := UseStates(ctx, []Node{thread}, func(nodes NodeMap) error { - return thread.Signal(ctx, CancelSignal(nil), nodes) + err := UseStates(ctx, thread, NewLockRequest(thread, []string{"signal"}), func(locked NodeLockMap) error { + return thread.Signal(ctx, CancelSignal, locked) }) fatalErr(t, err) }(gql_t) @@ -81,39 +83,49 @@ func TestGQLDBLoad(t * testing.T) { gql := &gql_r info := NewParentThreadInfo(true, "start", "restore") - err = UpdateStates(ctx, []Node{gql, t1, l1, u1, p1}, func(nodes NodeMap) error { - err := gql.AddPolicy(p1) - if err != nil { - return err - } - gql.Users[KeyID(&u1_key.PublicKey)] = u1 - err = LinkLockables(ctx, gql, []Lockable{l1}, nodes) - if err != nil { - return err - } - return LinkThreads(ctx, gql, t1, &info, nodes) + err = UpdateStates(ctx, gql, NewLockRequest(gql, []string{"policies", "users", "requirements", "children"}), func(locked NodeLockMap) error { + return UpdateMoreStates(ctx, locked, gql, RequestList([]Node{u1, p1}, []string{}), func(locked NodeLockMap) error { + err := gql.AddPolicy(p1) + if err != nil { + return err + } + + gql.Users[KeyID(&u1_key.PublicKey)] = u1 + + return UpdateMoreStates(ctx, locked, gql, NewLockRequest(t1, []string{"parent"}), func(locked NodeLockMap) error { + err := LinkThreads(ctx, gql, t1, &info, locked) + if err != nil { + return err + } + return UpdateMoreStates(ctx, locked, gql, NewLockRequest(l1, []string{"dependencies"}), func(locked NodeLockMap) error { + return LinkLockables(ctx, gql, gql, []Lockable{l1}, locked) + }) + }) + }) }) fatalErr(t, err) - err = UseStates(ctx, []Node{gql}, func(nodes NodeMap) error { - err := gql.Signal(ctx, NewSignal(t1, "child_added"), nodes) + err = UseStates(ctx, gql, NewLockRequest(gql, []string{"signal"}), func(locked NodeLockMap) error { + err := gql.Signal(ctx, NewStatusSignal("child_linked", t1.ID()), locked) if err != nil { return nil } - return gql.Signal(ctx, AbortSignal(nil), nodes) + return gql.Signal(ctx, CancelSignal, locked) }) fatalErr(t, err) err = ThreadLoop(ctx, gql, "start") - if errors.Is(err, NewThreadAbortedError(NodeID{})) { + if errors.Is(err, ThreadAbortedError) { ctx.Log.Logf("test", "Main thread aborted by signal: %s", err) - } else { + } else if err != nil{ fatalErr(t, err) + } else { + ctx.Log.Logf("test", "Main thread cancelled by signal") } - (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_aborted", t1, 100*time.Millisecond, "Didn't receive thread_abort from t1 on t1") + (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_aborted", 100*time.Millisecond, "Didn't receive thread_abort from t1 on t1") - err = UseStates(ctx, []Node{gql, u1}, func(nodes NodeMap) error { + err = UseStates(ctx, gql, RequestList([]Node{gql, u1}, nil), func(locked NodeLockMap) error { ser1, err := gql.Serialize() ser2, err := u1.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser1) @@ -126,29 +138,29 @@ func TestGQLDBLoad(t * testing.T) { var t1_loaded *SimpleThread = nil var update_channel_2 chan GraphSignal - err = UseStates(ctx, []Node{gql_loaded}, func(nodes NodeMap) error { + err = UseStates(ctx, gql, NewLockRequest(gql_loaded, []string{"users", "children"}), func(locked NodeLockMap) error { ser, err := gql_loaded.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser) u_loaded := gql_loaded.(*GQLThread).Users[u1.ID()] child := gql_loaded.(Thread).Children()[0].(*SimpleThread) t1_loaded = child update_channel_2 = UpdateChannel(t1_loaded, 10, NodeID{}) - err = UseMoreStates(ctx, []Node{u_loaded}, nodes, func(nodes NodeMap) error { + err = UseMoreStates(ctx, locked, gql, NewLockRequest(u_loaded, nil), func(locked NodeLockMap) error { ser, err := u_loaded.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser) return err }) - gql_loaded.Signal(ctx, AbortSignal(nil), nodes) + gql_loaded.Signal(ctx, AbortSignal, locked) return err }) err = ThreadLoop(ctx, gql_loaded.(Thread), "restore") - if errors.Is(err, NewThreadAbortedError(NodeID{})) { + if errors.Is(err, ThreadAbortedError) { ctx.Log.Logf("test", "Main thread aborted by signal: %s", err) } else { fatalErr(t, err) } - (*GraphTester)(t).WaitForValue(ctx, update_channel_2, "thread_aborted", t1_loaded, 100*time.Millisecond, "Didn't received thread_aborted on t1_loaded from t1_loaded") + (*GraphTester)(t).WaitForValue(ctx, update_channel_2, "thread_aborted", 100*time.Millisecond, "Didn't received thread_aborted on t1_loaded from t1_loaded") } @@ -157,20 +169,21 @@ func TestGQLAuth(t * testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) - p1_r := NewPerTagPolicy(RandID(), map[string]NodeActions{"gql": NewNodeActions(nil, []string{"*"})}) + p1_r := NewPerTagPolicy(RandID(), map[string]NodeActions{"gql": NewNodeActions(nil, []string{"read"})}) p1 := &p1_r gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) gql_t := &gql_t_r - err = UpdateStates(ctx, []Node{gql_t, p1}, func(nodes NodeMap) error { + // p1 not written to DB, TODO: update write to follow links maybe + err = UpdateStates(ctx, gql_t, NewLockRequest(gql_t, []string{"policies"}), func(locked NodeLockMap) error { return gql_t.AddPolicy(p1) }) done := make(chan error, 1) var update_channel chan GraphSignal - err = UseStates(ctx, []Node{gql_t}, func(nodes NodeMap) error { + err = UseStates(ctx, gql_t, NewLockRequest(gql_t, nil), func(locked NodeLockMap) error { update_channel = UpdateChannel(gql_t, 10, NodeID{}) return nil }) @@ -184,14 +197,14 @@ func TestGQLAuth(t * testing.T) { case <-done: ctx.Log.Logf("test", "DONE") } - err := UseStates(ctx, []Node{gql_t}, func(nodes NodeMap) error { - return thread.Signal(ctx, CancelSignal(nil), nodes) + err := UseStates(ctx, gql_t, NewLockRequest(gql_t, []string{"signal}"}), func(locked NodeLockMap) error { + return thread.Signal(ctx, CancelSignal, locked) }) fatalErr(t, err) }(done, gql_t) go func(thread Thread){ - (*GraphTester)(t).WaitForValue(ctx, update_channel, "server_started", gql_t, 100*time.Millisecond, "Server didn't start") + (*GraphTester)(t).WaitForValue(ctx, update_channel, "server_started", 100*time.Millisecond, "Server didn't start") port := gql_t.tcp_listener.Addr().(*net.TCPAddr).Port ctx.Log.Logf("test", "GQL_PORT: %d", port) diff --git a/gql_types.go b/gql_types.go index f8bb852..ecbf1db 100644 --- a/gql_types.go +++ b/gql_types.go @@ -162,7 +162,7 @@ var GQLTypeGraphNode = NewSingleton(func() *graphql.Object { var GQLTypeSignal = NewSingleton(func() *graphql.Object { gql_type_signal := graphql.NewObject(graphql.ObjectConfig{ - Name: "SignalOut", + Name: "Signal", IsTypeOf: func(p graphql.IsTypeOfParams) bool { _, ok := p.Value.(GraphSignal) return ok @@ -174,10 +174,6 @@ var GQLTypeSignal = NewSingleton(func() *graphql.Object { Type: graphql.String, Resolve: GQLSignalType, }) - gql_type_signal.AddFieldConfig("Source", &graphql.Field{ - Type: graphql.String, - Resolve: GQLSignalSource, - }) gql_type_signal.AddFieldConfig("Direction", &graphql.Field{ Type: graphql.String, Resolve: GQLSignalDirection, @@ -189,20 +185,3 @@ var GQLTypeSignal = NewSingleton(func() *graphql.Object { return gql_type_signal }, nil) -var GQLTypeSignalInput = NewSingleton(func()*graphql.InputObject { - gql_type_signal_input := graphql.NewInputObject(graphql.InputObjectConfig{ - Name: "SignalIn", - Fields: graphql.InputObjectConfigFieldMap{}, - }) - gql_type_signal_input.AddFieldConfig("Type", &graphql.InputObjectFieldConfig{ - Type: graphql.String, - DefaultValue: "cancel", - }) - gql_type_signal_input.AddFieldConfig("Direction", &graphql.InputObjectFieldConfig{ - Type: graphql.String, - DefaultValue: "down", - }) - - return gql_type_signal_input -}, nil) - diff --git a/graph_test.go b/graph_test.go index 5ab8c88..54a7521 100644 --- a/graph_test.go +++ b/graph_test.go @@ -12,7 +12,7 @@ import ( type GraphTester testing.T const listner_timeout = 50 * time.Millisecond -func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, signal_type string, source Node, timeout time.Duration, str string) GraphSignal { +func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, signal_type string, timeout time.Duration, str string) GraphSignal { timeout_channel := time.After(timeout) for true { select { @@ -22,16 +22,7 @@ func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, si t.Fatal(str) } if signal.Type() == signal_type { - ctx.Log.Logf("test", "SIGNAL_TYPE_FOUND: %s - %s %+v\n", signal.Type(), signal.Source(), listener) - if source == nil { - if signal.Source() == ZeroID { - return signal - } - } else { - if signal.Source() == source.ID() { - return signal - } - } + return signal } case <-timeout_channel: pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) diff --git a/lockable.go b/lockable.go index 08b70b4..9bdee4c 100644 --- a/lockable.go +++ b/lockable.go @@ -215,52 +215,58 @@ func (lockable * SimpleLockable) CanUnlock(new_owner Lockable) error { return nil } -// Lockable.Signal sends the update to the owner, requirements, and dependencies before updating listeners -func (lockable * SimpleLockable) Signal(ctx *Context, signal GraphSignal, nodes NodeMap) error { - err := lockable.GraphNode.Signal(ctx, signal, nodes) +// Assumed that lockable is already locked for signal +func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal) error { + err := lockable.GraphNode.Signal(ctx, signal, locked) if err != nil { return err } - if signal.Direction() == Up { - // Child->Parent, lockable updates dependency lockables - owner_sent := false - UseMoreStates(ctx, NodeList(lockable.dependencies), nodes, func(nodes NodeMap) error { - for _, dependency := range(lockable.dependencies){ + + switch signal.Direction() { + case Up: + err = UseMoreStates(ctx, locked, lockable, NewLockMap( + NewLockInfo(lockable, []string{"dependencies", "owner"}), + RequestList(lockable.requirements, []string{"signal"}), + ), func(context *ReadContext) error { + owner_sent := false + for _, dependency := range(lockable.dependencies) { ctx.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID()) - dependency.Signal(ctx, signal, nodes) + dependency.Signal(ctx, signal, locked) if lockable.owner != nil { if dependency.ID() == lockable.owner.ID() { owner_sent = true } } } + if lockable.owner != nil && owner_sent == false { + if lockable.owner.ID() != lockable.ID() { + ctx.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID()) + return UseMoreStates(context, lockable, NewLockRequest(lockable.owner, []string{"signal"}), func(context *ReadContext) error { + return lockable.owner.Signal(ctx, signal, locked) + }) + } + } return nil }) - if lockable.owner != nil && owner_sent == false { - if lockable.owner.ID() != lockable.ID() { - ctx.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID()) - UseMoreStates(ctx, []Node{lockable.owner}, nodes, func(nodes NodeMap) error { - return lockable.owner.Signal(ctx, signal, nodes) - }) - } - } - } else if signal.Direction() == Down { - // Parent->Child, lockable updates lock holder - UseMoreStates(ctx, NodeList(lockable.requirements), nodes, func(nodes NodeMap) error { + case Down: + err = UseMoreStates(context, lockable, NewLockMap( + NewLockInfo(lockable, []string{"requirements"}, + RequestList(lockable.requirements, []string{"signal"})), + ), func(context *ReadContext) error { for _, requirement := range(lockable.requirements) { - err := requirement.Signal(ctx, signal, nodes) + err := requirement.Signal(ctx, signal, locked) if err != nil { return err } } return nil }) - - } else if signal.Direction() == Direct { - } else { - panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) + case Direct: + err = nil + default: + return fmt.Errorf("invalid signal direction %d", signal.Direction()) } - return nil + return err } // Removes requirement as a requirement from lockable @@ -286,7 +292,7 @@ func UnlinkLockables(ctx * Context, lockable Lockable, requirement Lockable) err // Link requirements as requirements to lockable // Requires lockable and requirements to be locked for write, nodes passed because requirement check recursively locks -func LinkLockables(ctx * Context, lockable Lockable, requirements []Lockable, nodes NodeMap) error { +func LinkLockables(context *WriteContext, princ Node, lockable Lockable, requirements []Lockable) error { if lockable == nil { return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements") } @@ -312,56 +318,61 @@ func LinkLockables(ctx * Context, lockable Lockable, requirements []Lockable, no found[requirement.ID()] = true } - // Check that all the requirements can be added - // If the lockable is already locked, need to lock this resource as well before we can add it - for _, requirement := range(requirements) { - for _, req := range(requirements) { - if req.ID() == requirement.ID() { - continue + return UpdateMoreStates(context, princ, NewInfoMap( + NewLockInfo(lockable, []string{"requirements"}), + RequestList(requirements, []string{"dependencies"}), + ), func(context *WriteContext) error { + // Check that all the requirements can be added + // If the lockable is already locked, need to lock this resource as well before we can add it + for _, requirement := range(requirements) { + for _, req := range(requirements) { + if req.ID() == requirement.ID() { + continue + } + if checkIfRequirement(context, req, requirement) == true { + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependenyc of %s so cannot add the same dependency", req.ID(), requirement.ID()) + } } - if checkIfRequirement(ctx, req, requirement, nodes) == true { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependenyc of %s so cannot add the same dependency", req.ID(), requirement.ID()) + if checkIfRequirement(context, lockable, requirement) == true { + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID(), lockable.ID()) } - } - if checkIfRequirement(ctx, lockable, requirement, nodes) == true { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID(), lockable.ID()) - } - if checkIfRequirement(ctx, requirement, lockable, nodes) == true { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID()) - } - if lockable.Owner() == nil { - // If the new owner isn't locked, we can add the requirement - } else if requirement.Owner() == nil { - // if the new requirement isn't already locked but the owner is, the requirement needs to be locked first - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID(), requirement.ID()) - } else { - // If the new requirement is already locked and the owner is already locked, their owners need to match - if requirement.Owner().ID() != lockable.Owner().ID() { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID(), lockable.ID()) + if checkIfRequirement(context, requirement, lockable) == true { + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID()) + } + if lockable.Owner() == nil { + // If the new owner isn't locked, we can add the requirement + } else if requirement.Owner() == nil { + // if the new requirement isn't already locked but the owner is, the requirement needs to be locked first + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID(), requirement.ID()) + } else { + // If the new requirement is already locked and the owner is already locked, their owners need to match + if requirement.Owner().ID() != lockable.Owner().ID() { + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID(), lockable.ID()) + } } } - } - // Update the states of the requirements - for _, requirement := range(requirements) { - requirement.AddDependency(lockable) - lockable.AddRequirement(requirement) - ctx.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID(), lockable.ID()) - } + // Update the states of the requirements + for _, requirement := range(requirements) { + requirement.AddDependency(lockable) + lockable.AddRequirement(requirement) + ctx.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID(), lockable.ID()) + } - // Return no error - return nil + // Return no error + return nil + }) } // Must be called withing update context -func checkIfRequirement(ctx * Context, r Lockable, cur Lockable, nodes NodeMap) bool { +func checkIfRequirement(context *WriteContext, r Lockable, cur Lockable) bool { for _, c := range(cur.Requirements()) { if c.ID() == r.ID() { return true } is_requirement := false - UpdateMoreStates(ctx, []Node{c}, nodes, func(nodes NodeMap) (error) { - is_requirement = checkIfRequirement(ctx, cur, c, nodes) + UpdateMoreStates(context, cur, NewLockMap(NewLockRequest(c, []string{"requirements"})), func(context *WriteContext) error { + is_requirement = checkIfRequirement(context, cur, c) return nil }) @@ -374,19 +385,18 @@ func checkIfRequirement(ctx * Context, r Lockable, cur Lockable, nodes NodeMap) } // Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error -// Requires that all the nodes in to_lock and new_owner are locked for write -func LockLockables(ctx * Context, to_lock []Lockable, new_owner Lockable, nodes NodeMap) error { +// Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call +func LockLockables(context *WriteContext, to_lock []Lockable, new_owner Lockable) error { if to_lock == nil { return fmt.Errorf("LOCKABLE_LOCK_ERR: no list provided") } - node_list := make([]Node, len(to_lock)) - for i, l := range(to_lock) { + for _, l := range(to_lock) { if l == nil { return fmt.Errorf("LOCKABLE_LOCK_ERR: Can not lock nil") } - node_list[i] = l } + if new_owner == nil { return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks") } @@ -396,132 +406,120 @@ func LockLockables(ctx * Context, to_lock []Lockable, new_owner Lockable, nodes return nil } - // First loop is to check that the states can be locked, and locks all requirements - for _, req := range(to_lock) { - ctx.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID()) + return UpdateMoreStates(ctx, locked, new_owner, RequestList(to_lock, []string{"lock"}), func(locked NodeLockMap) error { + // First loop is to check that the states can be locked, and locks all requirements + for _, req := range(to_lock) { + context.Context.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID()) - // Check custom lock conditions - err := req.CanLock(new_owner) - if err != nil { - return err - } + // Check custom lock conditions + err := req.CanLock(new_owner) + if err != nil { + return err + } - // If req is alreay locked, check that we can pass the lock - if req.Owner() != nil { - owner := req.Owner() - if owner.ID() == new_owner.ID() { - // If we already own the lock, nothing to do - continue - } else if owner.ID() == req.ID() { - if req.AllowedToTakeLock(new_owner, req) == false { - return fmt.Errorf("LOCKABLE_LOCK_ERR: %s is not allowed to take %s's lock from %s", new_owner.ID(), req.ID(), owner.ID()) - } - err := LockLockables(ctx, req.Requirements(), req, nodes) - if err != nil { - return err + // If req is alreay locked, check that we can pass the lock + if req.Owner() != nil { + owner := req.Owner() + if owner.ID() == new_owner.ID() { + continue + } else { + err := UpdateMoreStates(ctx, locked, new_owner, NewLockRequest(owner, []string{"take_lock"}), func(locked NodeLockMap)(error){ + return LockLockables(ctx, req.Requirements(), req, locked) + }) + if err != nil { + return err + } } } else { - err := UpdateMoreStates(ctx, []Node{owner}, nodes, func(nodes NodeMap)(error){ - if owner.AllowedToTakeLock(new_owner, req) == false { - return fmt.Errorf("LOCKABLE_LOCK_ERR: %s is not allowed to take %s's lock from %s", new_owner.ID(), req.ID(), owner.ID()) - } - err := LockLockables(ctx, req.Requirements(), req, nodes) - return err - }) + err := LockLockables(ctx, req.Requirements(), req, locked) if err != nil { return err } } - } else { - err := LockLockables(ctx, req.Requirements(), req, nodes) - if err != nil { - return err - } } - } - // At this point state modification will be started, so no errors can be returned - for _, req := range(to_lock) { - old_owner := req.Owner() - // If the lockable was previously unowned, update the state - if old_owner == nil { - ctx.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID(), req.ID()) - req.SetOwner(new_owner) - new_owner.RecordLock(req, old_owner) - // Otherwise if the new owner already owns it, no need to update state - } else if old_owner.ID() == new_owner.ID() { - ctx.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID(), req.ID()) - // Otherwise update the state - } else { - req.SetOwner(new_owner) - new_owner.RecordLock(req, old_owner) - ctx.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID(), req.ID(), old_owner.ID()) + // At this point state modification will be started, so no errors can be returned + for _, req := range(to_lock) { + old_owner := req.Owner() + // If the lockable was previously unowned, update the state + if old_owner == nil { + context.Context.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID(), req.ID()) + req.SetOwner(new_owner) + new_owner.RecordLock(req, old_owner) + // Otherwise if the new owner already owns it, no need to update state + } else if old_owner.ID() == new_owner.ID() { + context.Context.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID(), req.ID()) + // Otherwise update the state + } else { + req.SetOwner(new_owner) + new_owner.RecordLock(req, old_owner) + context.Context.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID(), req.ID(), old_owner.ID()) + } } - } - return nil + return nil + }) + } -// Unlock nodes in the to_unlock slice with old_owner, does not modify any states if returning an error -// Requires that all the nodes in to_unlock and old_owner are locked for write -func UnlockLockables(ctx * Context, to_unlock []Lockable, old_owner Lockable, nodes NodeMap) error { +func UnlockLockables(context *WriteContext, to_unlock []Lockable, old_owner Lockable) error { if to_unlock == nil { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided") } + for _, l := range(to_unlock) { if l == nil { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: Can not lock nil") + return fmt.Errorf("LOCKABLE_UNLOCK_ERR: Can not unlock nil") } } + if old_owner == nil { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks") } - // Called with no requirements to lock, success + // Called with no requirements to unlock, success if len(to_unlock) == 0 { return nil } - node_list := make([]Node, len(to_unlock)) - for i, l := range(to_unlock) { - node_list[i] = l - } + return UpdateMoreStates(ctx, locked, old_owner, RequestList(to_unlock, []string{"lock"}), func(locked NodeLockMap) error { + // First loop is to check that the states can be locked, and locks all requirements + for _, req := range(to_unlock) { + context.Context.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID()) - // First loop is to check that the states can be locked, and locks all requirements - for _, req := range(to_unlock) { - ctx.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID()) + // Check if the owner is correct + if req.Owner() != nil { + if req.Owner().ID() != old_owner.ID() { + return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID(), old_owner.ID()) + } + } else { + return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID()) + } - // Check if the owner is correct - if req.Owner() != nil { - if req.Owner().ID() != old_owner.ID() { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID(), old_owner.ID()) + // Check custom unlock conditions + err := req.CanUnlock(old_owner) + if err != nil { + return err } - } else { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID()) - } - // Check custom unlock conditions - err := req.CanUnlock(old_owner) - if err != nil { - return err + err = UnlockLockables(ctx, req.Requirements(), req, locked) + if err != nil { + return err + } } - err = UnlockLockables(ctx, req.Requirements(), req, nodes) - if err != nil { - return err + // At this point state modification will be started, so no errors can be returned + for _, req := range(to_unlock) { + new_owner := old_owner.RecordUnlock(req) + req.SetOwner(new_owner) + if new_owner == nil { + context.Context.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID(), req.ID()) + } else { + context.Context.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID(), req.ID(), new_owner.ID()) + } } - } - // At this point state modification will be started, so no errors can be returned - for _, req := range(to_unlock) { - new_owner := old_owner.RecordUnlock(req) - req.SetOwner(new_owner) - if new_owner == nil { - ctx.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID(), req.ID()) - } else { - ctx.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID(), req.ID(), new_owner.ID()) - } - } - return nil + return nil + }) } // Load function for SimpleLockable diff --git a/lockable_test.go b/lockable_test.go deleted file mode 100644 index 2be1608..0000000 --- a/lockable_test.go +++ /dev/null @@ -1,495 +0,0 @@ -package graphvent - -import ( - "testing" - "fmt" - "time" -) - -func TestNewSimpleLockable(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test lockable 2") - l2 := &l2_r - - err := UpdateStates(ctx, []Node{l1, l2}, func(nodes NodeMap) error { - return LinkLockables(ctx, l2, []Lockable{l1}, nodes) - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1, l2}, func(nodes NodeMap) error { - l1_deps := len(l1.Dependencies()) - if l1_deps != 1 { - return fmt.Errorf("l1 has wront amount of dependencies %d/1", l1_deps) - } - - l1_dep1 := l1.Dependencies()[0] - if l1_dep1.ID() != l2.ID() { - return fmt.Errorf("Wrong dependency for l1, %s instead of %s", l1_dep1.ID(), l2.ID()) - } - - l2_reqs := len(l2.Requirements()) - if l2_reqs != 1 { - return fmt.Errorf("l2 has wrong amount of requirements %d/1", l2_reqs) - } - - l2_req1 := l2.Requirements()[0] - if l2_req1.ID() != l1.ID() { - return fmt.Errorf("Wrong requirement for l2, %s instead of %s", l2_req1.ID(), l1.ID()) - } - return nil - }) - fatalErr(t, err) -} - -func TestRepeatedChildLockable(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test lockable 1") - l1 := &l1_r - - l2_r := NewSimpleLockable(RandID(), "Test lockable 2") - l2 := &l2_r - - err := UpdateStates(ctx, []Node{l1, l2}, func(nodes NodeMap) error { - return LinkLockables(ctx, l2, []Lockable{l1, l1}, nodes) - }) - - if err == nil { - t.Fatal("Added the same lockable as a requirement twice to the same lockable") - } -} - -func TestLockableSelfLock(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test lockable 1") - l1 := &l1_r - - err := UpdateStates(ctx, []Node{l1}, func(nodes NodeMap) error { - return LockLockables(ctx, []Lockable{l1}, l1, nodes) - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - owner_id := NodeID{} - if l1.owner != nil { - owner_id = l1.owner.ID() - } - if owner_id != l1.ID() { - return fmt.Errorf("l1 is owned by %s instead of self", owner_id) - } - return nil - }) - fatalErr(t, err) - - err = UpdateStates(ctx, []Node{l1}, func(nodes NodeMap) error { - return UnlockLockables(ctx, []Lockable{l1}, l1, nodes) - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - if l1.owner != nil { - return fmt.Errorf("l1 is not unowned after unlock: %s", l1.owner.ID()) - } - return nil - }) - - fatalErr(t, err) -} - -func TestLockableSelfLockTiered(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test lockable 2") - l2 := &l2_r - l3_r := NewSimpleLockable(RandID(), "Test lockable 3") - l3 := &l3_r - - err := UpdateStates(ctx, []Node{l3}, func(nodes NodeMap) error { - err := LinkLockables(ctx, l3, []Lockable{l1, l2}, nodes) - if err != nil { - return err - } - return LockLockables(ctx, []Lockable{l3}, l3, nodes) - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) (error) { - owner_1 := NodeID{} - if l1.owner != nil { - owner_1 = l1.owner.ID() - } - if owner_1 != l3.ID() { - return fmt.Errorf("l1 is owned by %s instead of l3", owner_1) - } - - owner_2 := NodeID{} - if l2.owner != nil { - owner_2 = l2.owner.ID() - } - if owner_2 != l3.ID() { - return fmt.Errorf("l2 is owned by %s instead of l3", owner_2) - } - return nil - }) - fatalErr(t, err) - - err = UpdateStates(ctx, []Node{l3}, func(nodes NodeMap) error { - return UnlockLockables(ctx, []Lockable{l3}, l3, nodes) - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) (error) { - owner_1 := l1.owner - if owner_1 != nil { - return fmt.Errorf("l1 is not unowned after unlocking: %s", owner_1.ID()) - } - - owner_2 := l2.owner - if owner_2 != nil { - return fmt.Errorf("l2 is not unowned after unlocking: %s", owner_2.ID()) - } - - owner_3 := l3.owner - if owner_3 != nil { - return fmt.Errorf("l3 is not unowned after unlocking: %s", owner_3.ID()) - } - return nil - }) - - fatalErr(t, err) -} - -func TestLockableLockOther(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test lockable 2") - l2 := &l2_r - - err := UpdateStates(ctx, []Node{l1, l2}, func(nodes NodeMap) (error) { - err := LockLockables(ctx, []Lockable{l1}, l2, nodes) - fatalErr(t, err) - return nil - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - owner_id := NodeID{} - if l1.owner != nil { - owner_id = l1.owner.ID() - } - if owner_id != l2.ID() { - return fmt.Errorf("l1 is owned by %s instead of l2", owner_id) - } - - return nil - }) - fatalErr(t, err) - - err = UpdateStates(ctx, []Node{l2}, func(nodes NodeMap) (error) { - err := UnlockLockables(ctx, []Lockable{l1}, l2, nodes) - fatalErr(t, err) - return nil - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - owner := l1.owner - if owner != nil { - return fmt.Errorf("l1 is owned by %s instead of l2", owner.ID()) - } - - return nil - }) - fatalErr(t, err) - -} - -func TestLockableLockSimpleConflict(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test lockable 2") - l2 := &l2_r - - err := UpdateStates(ctx, []Node{l1}, func(nodes NodeMap) error { - return LockLockables(ctx, []Lockable{l1}, l1, nodes) - }) - fatalErr(t, err) - - err = UpdateStates(ctx, []Node{l2}, func(nodes NodeMap) (error) { - err := LockLockables(ctx, []Lockable{l1}, l2, nodes) - if err == nil { - t.Fatal("l2 took l1's lock from itself") - } - - return nil - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - owner_id := NodeID{} - if l1.owner != nil { - owner_id = l1.owner.ID() - } - if owner_id != l1.ID() { - return fmt.Errorf("l1 is owned by %s instead of l1", owner_id) - } - - return nil - }) - fatalErr(t, err) - - err = UpdateStates(ctx, []Node{l1}, func(nodes NodeMap) error { - return UnlockLockables(ctx, []Lockable{l1}, l1, nodes) - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - owner := l1.owner - if owner != nil { - return fmt.Errorf("l1 is owned by %s instead of l1", owner.ID()) - } - - return nil - }) - fatalErr(t, err) - -} - -func TestLockableLockTieredConflict(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test lockable 2") - l2 := &l2_r - l3_r := NewSimpleLockable(RandID(), "Test lockable 3") - l3 := &l3_r - - err := UpdateStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) error { - err := LinkLockables(ctx, l2, []Lockable{l1}, nodes) - if err != nil { - return err - } - return LinkLockables(ctx, l3, []Lockable{l1}, nodes) - }) - fatalErr(t, err) - - err = UpdateStates(ctx, []Node{l2}, func(nodes NodeMap) error { - return LockLockables(ctx, []Lockable{l2}, l2, nodes) - }) - fatalErr(t, err) - - err = UpdateStates(ctx, []Node{l3}, func(nodes NodeMap) error { - return LockLockables(ctx, []Lockable{l3}, l3, nodes) - }) - if err == nil { - t.Fatal("Locked l3 which depends on l1 while l2 which depends on l1 is already locked") - } -} - -func TestLockableSimpleUpdate(t * testing.T) { - ctx := logTestContext(t, []string{}) - - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - - - update_channel := UpdateChannel(l1, 1, NodeID{}) - - go func() { - UseStates(ctx, []Node{l1}, func(nodes NodeMap) error { - return l1.Signal(ctx, NewDirectSignal(l1, "test_update"), nodes) - }) - }() - - (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Didn't receive test_update sent to l1") -} - -func TestLockableDownUpdate(t * testing.T) { - ctx := logTestContext(t, []string{}) - - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") - l2 := &l2_r - l3_r := NewSimpleLockable(RandID(), "Test Lockable 3") - l3 := &l3_r - err := UpdateStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) error { - err := LinkLockables(ctx, l2, []Lockable{l1}, nodes) - if err != nil { - return err - } - return LinkLockables(ctx, l3, []Lockable{l2}, nodes) - }) - fatalErr(t, err) - - update_channel := UpdateChannel(l1, 1, NodeID{}) - - go func() { - UseStates(ctx, []Node{l2}, func(nodes NodeMap) error { - return l2.Signal(ctx, NewDownSignal(l2, "test_update"), nodes) - }) - }() - - (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2") -} - -func TestLockableUpUpdate(t * testing.T) { - ctx := logTestContext(t, []string{}) - - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") - l2 := &l2_r - l3_r := NewSimpleLockable(RandID(), "Test Lockable 3") - l3 := &l3_r - err := UpdateStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) error { - err := LinkLockables(ctx, l2, []Lockable{l1}, nodes) - if err != nil { - return err - } - return LinkLockables(ctx, l3, []Lockable{l2}, nodes) - }) - fatalErr(t, err) - - update_channel := UpdateChannel(l3, 1, NodeID{}) - - go func() { - UseStates(ctx, []Node{l2}, func(nodes NodeMap) error { - return l2.Signal(ctx, NewSignal(l2, "test_update"), nodes) - }) - }() - - (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2") -} - -func TestOwnerNotUpdatedTwice(t * testing.T) { - ctx := logTestContext(t, []string{}) - - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") - l2 := &l2_r - - err := UpdateStates(ctx, []Node{l1, l2}, func(nodes NodeMap) error { - err := LinkLockables(ctx, l2, []Lockable{l1}, nodes) - if err != nil { - return err - } - return LockLockables(ctx, []Lockable{l2}, l2, nodes) - }) - fatalErr(t, err) - - update_channel := UpdateChannel(l2, 1, NodeID{}) - - go func() { - err := UseStates(ctx, []Node{l1}, func(nodes NodeMap) error { - return l1.Signal(ctx, NewSignal(l1, "test_update"), nodes) - }) - fatalErr(t, err) - }() - - (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Dicn't received test_update on l2 from l1") - (*GraphTester)(t).CheckForNone(update_channel, "Second update received on dependency") -} - -func TestLockableDependencyOverlap(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") - l2 := &l2_r - l3_r := NewSimpleLockable(RandID(), "Test Lockable 3") - l3 := &l3_r - - err := UpdateStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) error { - err := LinkLockables(ctx, l2, []Lockable{l1}, nodes) - if err != nil { - return err - } - - return LinkLockables(ctx, l3, []Lockable{l2, l1}, nodes) - }) - if err == nil { - t.Fatal("Should have thrown an error because of dependency overlap") - } -} - -func TestLockableDBLoad(t * testing.T){ - ctx := logTestContext(t, []string{}) - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") - l2 := &l2_r - l3_r := NewSimpleLockable(RandID(), "Test Lockable 3") - l3 := &l3_r - l4_r := NewSimpleLockable(RandID(), "Test Lockable 4") - l4 := &l4_r - l5_r := NewSimpleLockable(RandID(), "Test Lockable 5") - l5 := &l5_r - l6_r := NewSimpleLockable(RandID(), "Test Lockable 6") - l6 := &l6_r - err := UpdateStates(ctx, []Node{l1, l2, l3, l4, l5, l6}, func(nodes NodeMap) error { - err := LinkLockables(ctx, l3, []Lockable{l1, l2}, nodes) - if err != nil { - return err - } - - err = LinkLockables(ctx, l4, []Lockable{l3}, nodes) - if err != nil { - return err - } - - err = LinkLockables(ctx, l5, []Lockable{l4}, nodes) - if err != nil { - return err - } - return LockLockables(ctx, []Lockable{l3}, l6, nodes) - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l3}, func(nodes NodeMap) error { - ser, err := l3.Serialize() - ctx.Log.Logf("test", "\n%s\n\n", ser) - return err - }) - fatalErr(t, err) - - l3_loaded, err := LoadNode(ctx, l3.ID()) - fatalErr(t, err) - - // TODO: add more equivalence checks - err = UseStates(ctx, []Node{l3_loaded}, func(nodes NodeMap) error { - ser, err := l3_loaded.Serialize() - ctx.Log.Logf("test", "\n%s\n\n", ser) - return err - }) - fatalErr(t, err) -} - -func TestLockableUnlink(t * testing.T){ - ctx := logTestContext(t, []string{}) - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") - l2 := &l2_r - - err := UpdateStates(ctx, []Node{l1, l2}, func(nodes NodeMap) error { - return LinkLockables(ctx, l2, []Lockable{l1}, nodes) - }) - fatalErr(t, err) - - err = UnlinkLockables(ctx, l2, l1) - fatalErr(t, err) -} diff --git a/node.go b/node.go index 08b7cbd..8d93487 100644 --- a/node.go +++ b/node.go @@ -54,6 +54,24 @@ func RandID() NodeID { return NodeID(uuid.New()) } +// A Node represents a data that can be locked and held by other Nodes +type Node interface { + ID() NodeID + Type() NodeType + Serialize() ([]byte, error) + + Allows(resouce, action string, principal Node) error + AddPolicy(Policy) error + RemovePolicy(Policy) error + + Signal(context *ReadContext, signal GraphSignal) error + + + Requirements() []Node + Dependencies() []Node + Owner() Node +} + // A Node represents data that can be read by multiple goroutines and written to by one, with a unique ID attached, and a method to process updates(including propagating them to connected nodes) // RegisterChannel and UnregisterChannel are used to connect arbitrary listeners to the node type Node interface { @@ -74,7 +92,7 @@ type Node interface { RemovePolicy(Policy) error // Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate - Signal(ctx *Context, signal GraphSignal, nodes NodeMap) error + Signal(context *ReadContext, signal GraphSignal) error // Register a channel to receive updates sent to the node RegisterChannel(id NodeID, listener chan GraphSignal) // Unregister a channel from receiving updates sent to the node @@ -193,20 +211,20 @@ func (node * GraphNode) Type() NodeType { // Propagate the signal to registered listeners, if a listener isn't ready to receive the update // send it a notification that it was closed and then close it -func (node * GraphNode) Signal(ctx *Context, signal GraphSignal, nodes NodeMap) error { - ctx.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String()) +func (node * GraphNode) Signal(context *ReadContext, signal GraphSignal) error { + context.Context.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String()) node.listeners_lock.Lock() defer node.listeners_lock.Unlock() closed := []NodeID{} for id, listener := range node.listeners { - ctx.Log.Logf("signal", "UPDATE_LISTENER %s: %p", node.ID(), listener) + context.Context.Log.Logf("signal", "UPDATE_LISTENER %s: %p", node.ID(), listener) select { case listener <- signal: default: - ctx.Log.Logf("signal", "CLOSED_LISTENER %s: %p", node.ID(), listener) + context.Context.Log.Logf("signal", "CLOSED_LISTENER %s: %p", node.ID(), listener) go func(node Node, listener chan GraphSignal) { - listener <- NewDirectSignal(node, "listener_closed") + listener <- NewDirectSignal("listener_closed") close(listener) }(node, listener) closed = append(closed, id) @@ -293,18 +311,19 @@ func getNodeBytes(node Node) ([]byte, error) { } // Write multiple nodes to the database in a single transaction -func WriteNodes(ctx * Context, nodes NodeMap) error { - ctx.Log.Logf("db", "DB_WRITES: %d", len(nodes)) - if nodes == nil { +func WriteNodes(context *WriteContext) error { + if locked == nil { return fmt.Errorf("Cannot write nil map to DB") } + context.Context.Log.Logf("db", "DB_WRITES: %d", len(context.Locked)) - serialized_bytes := make([][]byte, len(nodes)) - serialized_ids := make([][]byte, len(nodes)) + serialized_bytes := make([][]byte, len(context.Locked)) + serialized_ids := make([][]byte, len(context.Locked)) i := 0 - for _, node := range(nodes) { + for _, lock := range(context.Locked) { + node := lock.Node node_bytes, err := getNodeBytes(node) - ctx.Log.Logf("db", "DB_WRITE: %+v", node) + context.Context.Log.Logf("db", "DB_WRITE: %+v", node) if err != nil { return err } @@ -317,7 +336,7 @@ func WriteNodes(ctx * Context, nodes NodeMap) error { i++ } - err := ctx.DB.Update(func(txn *badger.Txn) error { + err := context.Context.DB.Update(func(txn *badger.Txn) error { for i, id := range(serialized_ids) { err := txn.Set(id, serialized_bytes[i]) if err != nil { @@ -406,60 +425,165 @@ func LoadNodeRecurse(ctx * Context, id NodeID, nodes NodeMap) (Node, error) { return node, nil } -// Internal function to filter duplicate nodes from a list -func filterDuplicates(nodes []Node) []Node { - ret := []Node{} - found := map[NodeID]bool{} - for _, node := range(nodes) { - if node == nil { - return []Node{} - } +func NewLockInfo(node Node, resources []string) LockInfo { + return LockInfo{ + Node: node, + Resources: resources, + } +} - _, exists := found[node.ID()] - if exists == false { - found[node.ID()] = true - ret = append(ret, node) +type LockInfoList interface { + List() []LockInfo +} + +func NewLockMap(requests ...LockInfoList) LockMap { + reqs := LockMap{} + for _, req := range(requests) { + for _, info := range(req) { + reqs[req.Node.ID()] = info } } - return ret + return reqs } -// Convert any slice of types that implement Node to a []Node -func NodeList[K Node](list []K) []Node { - nodes := make([]Node, len(list)) +func RequestList[K Node](list []K, resources []string) LockList { + requests := make(LockList{}, len(list)) for i, node := range(list) { - nodes[i] = node + requests[i] = LockInfo{ + Node: node, + Resources: resources, + } } - return nodes + return requests } + type NodeMap map[NodeID]Node -type NodesFn func(nodes NodeMap) error -// Initiate a read context for nodes and call nodes_fn with init_nodes locked for read -func UseStates(ctx * Context, init_nodes []Node, nodes_fn NodesFn) error { - nodes := NodeMap{} - return UseMoreStates(ctx, init_nodes, nodes, nodes_fn) + +type LockInfo struct { + Node Node + Resources []string +} +func (info LockInfo) List() []LockInfo { + return []LockInfo{info} +} + +type LockMap map[NodeID]LockInfo +func (m LockMap) List() []LockInfo { + infos := make([]LockInfo, len(m)) + i := 0 + for _, info := range(m) { + infos[i] = info + i += 1 + } + + return infos } +type LockList []LockInfo +func (li LockList) List() []LockInfo { + return li +} -// Add nodes to an existing read context and call nodes_fn with new_nodes locked for read -func UseMoreStates(ctx * Context, new_nodes []Node, nodes NodeMap, nodes_fn NodesFn) error { - new_nodes = filterDuplicates(new_nodes) +type ReadContext struct { + Graph *Context + Locked LockMap +} +type ReadFn func(*ReadContext)(error) + +type WriteContext struct { + Graph *Context + Locked LockMap +} +type WriteFn func(*WriteContext)(error) +func del[K comparable](list []K, val K) []K { + idx := -1 + for i, v := range(list) { + if v == val { + idx = i + break + } + } + if idx == -1 { + return nil + } + + list[idx] = list[len(list)-1] + return list[:len(list)-1] +} + +// Start a read context for node under ctx for the resources specified in init_nodes, then run nodes_fn +func UseStates(ctx *Context, node Node, nodes LockMap, read_fn ReadFn) error { + context := &ReadContext{ + Context: ctx, + Locked: LockMap{}, + } + return UseMoreStates(context, node, nodes, read_fn) +} + +// Add nodes to an existing read context and call nodes_fn with new_nodes locked for read +// Check that the node has read permissions for the nodes, then add them to the read context and call nodes_fn with the nodes locked for read +func UseMoreStates(context *ReadContext, node Node, new_nodes LockMap, read_fn ReadFn) error { locked_nodes := []Node{} - for _, node := range(new_nodes) { - _, locked := nodes[node.ID()] - if locked == false { - node.RLock() - nodes[node.ID()] = node - locked_nodes = append(locked_nodes, node) + new_permissions := LockMap{} + for _, request := range(new_nodes) { + id := request.Node.ID() + new_permissions[id] = LockInfo{Node: request.Node, Resources: []string{}} + for _, resource := range(request.Resources) { + // If the permission for this resource is already granted, continue + current_permissions, exists := context.Locked[id] + if exists == true { + already_granted := false + for _, r := range(current_permissions.Resources) { + if r == resource { + already_granted = true + break + } + } + if already_granted == true { + continue + } + } + + err := request.Node.Allowed("read", resource, node) + if err != nil { + return err + } + + tmp := new_permissions[id] + tmp.Resources = append(tmp.Resources, resource) + new_permissions[id] = tmp + } + + + req_perms, exists := new_permissions[id] + if exists == true { + cur_perms, already_locked := context.Locked[id] + if already_locked == false { + request.Node.RLock() + context.Locked[id] = req_perms + locked_nodes = append(locked_nodes, request.Node) + } else { + cur_perms.Resources = append(cur_perms.Resources, req_perms.Resources...) + } } } - err := nodes_fn(nodes) + err := read_fn(context) + + for _, request := range(new_permissions) { + cur_perms := context.Locked[request.Node.ID()] + new_perms := cur_perms.Resources + for _, resource := range(cur_perms.Resources) { + new_perms = del(new_perms, resource) + } + cur_perms.Resources = new_perms + context.Locked[request.Node.ID()].Resources = new_perms + } for _, node := range(locked_nodes) { - delete(nodes, node.ID()) + delete(context.Locked, node.ID()) node.RUnlock() } @@ -467,30 +591,68 @@ func UseMoreStates(ctx * Context, new_nodes []Node, nodes NodeMap, nodes_fn Node } // Initiate a write context for nodes and call nodes_fn with nodes locked for read -func UpdateStates(ctx * Context, nodes []Node, nodes_fn NodesFn) error { - locked_nodes := NodeMap{} - err := UpdateMoreStates(ctx, nodes, locked_nodes, nodes_fn) +func UpdateStates(ctx *Context, node Node, nodes LockMap, write_fn WriteFn) error { + context := &WriteContext{ + Context: ctx, + Locked: LockMap{}, + } + err := UpdateMoreStates(context, node, nodes, nodes_fn) if err == nil { - err = WriteNodes(ctx, locked_nodes) + err = WriteNodes(context) } - for _, node := range(locked_nodes) { - node.Unlock() + for _, lock := range(context.Locked) { + lock.Node.Unlock() } + return err } // Add nodes to an existing write context and call nodes_fn with nodes locked for read -func UpdateMoreStates(ctx * Context, nodes []Node, locked_nodes NodeMap, nodes_fn NodesFn) error { - for _, node := range(nodes) { - _, locked := locked_nodes[node.ID()] - if locked == false { - node.Lock() - locked_nodes[node.ID()] = node +func UpdateMoreStates(ctx *Context, locked LockMap, node Node, new_nodes LockMap, write_fn WriteFn) error { + new_permissions := LockMap{} + for _, request := range(new_nodes) { + id := request.Node.ID() + new_permissions[id] = LockInfo{Node: request.Node, Resources: []string{}} + for _, resource := range(request.Resources) { + current_permissions, exists := locked[id] + if exists == true { + already_granted := false + for _, r := range(current_permissions.Resources) { + if r == resource { + already_granted = true + break + } + } + if already_granted == true { + continue + } + } + + err := request.Node.Allowed("write", resource, node) + if err != nil { + return err + } + + tmp := new_permissions[id] + tmp.Resources = append(tmp.Resources, resource) + new_permissions[id] = tmp + } + + req_perms, exists := new_permissions[id] + if exists == true { + cur_perms, already_locked := locked[id] + if already_locked == false { + request.Node.Lock() + locked[id] = req_perms + } else { + cur_perms.Resources = append(cur_perms.Resources, req_perms.Resources...) + locked[id] = cur_perms + } } } - return nodes_fn(locked_nodes) + return write_fn(locked) } // Create a new channel with a buffer the size of buffer, and register it to node with the id diff --git a/policy.go b/policy.go index c7ff66e..7fb3daa 100644 --- a/policy.go +++ b/policy.go @@ -8,11 +8,11 @@ import ( type Policy interface { Node // Returns true if the principal is allowed to perform the action on the resource - Allows(action string, resource string, principal Node) bool + Allows(resource string, action string, principal Node) bool } type NodeActions map[string][]string -func (actions NodeActions) Allows(action string, resource string) bool { +func (actions NodeActions) Allows(resource string, action string) bool { for _, a := range(actions[""]) { if a == action || a == "*" { return true @@ -108,13 +108,13 @@ func LoadPerNodePolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Nod return &policy, nil } -func (policy *PerNodePolicy) Allows(action string, resource string, principal Node) bool { +func (policy *PerNodePolicy) Allows(resource string, action string, principal Node) bool { node_actions, exists := policy.Actions[principal.ID()] if exists == false { return false } - if node_actions.Allows(action, resource) == true { + if node_actions.Allows(resource, action) == true { return true } @@ -171,8 +171,8 @@ func LoadSimplePolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node return &policy, nil } -func (policy *SimplePolicy) Allows(action string, resource string, principal Node) bool { - return policy.Actions.Allows(action, resource) +func (policy *SimplePolicy) Allows(resource string, action string, principal Node) bool { + return policy.Actions.Allows(resource, action) } type PerTagPolicy struct { @@ -235,7 +235,7 @@ func LoadPerTagPolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node return &policy, nil } -func (policy *PerTagPolicy) Allows(action string, resource string, principal Node) bool { +func (policy *PerTagPolicy) Allows(resource string, action string, principal Node) bool { user, ok := principal.(*User) if ok == false { return false @@ -244,7 +244,7 @@ func (policy *PerTagPolicy) Allows(action string, resource string, principal Nod for _, tag := range(user.Tags) { tag_actions, exists := policy.Actions[tag] if exists == true { - if tag_actions.Allows(action, resource) == true { + if tag_actions.Allows(resource, action) == true { return true } } diff --git a/signal.go b/signal.go index 3358ad0..71cce2c 100644 --- a/signal.go +++ b/signal.go @@ -15,7 +15,6 @@ const ( type GraphSignal interface { // How to propogate the signal Direction() SignalDirection - Source() NodeID Type() string String() string } @@ -23,7 +22,6 @@ type GraphSignal interface { // BaseSignal is the most basic type of signal, it has no additional data type BaseSignal struct { FDirection SignalDirection `json:"direction"` - FSource NodeID `json:"source"` FType string `json:"type"` } @@ -39,58 +37,57 @@ func (signal BaseSignal) Direction() SignalDirection { return signal.FDirection } -func (signal BaseSignal) Source() NodeID { - return signal.FSource -} - func (signal BaseSignal) Type() string { return signal.FType } -func NewBaseSignal(source Node, _type string, direction SignalDirection) BaseSignal { - var source_id NodeID = NodeID{} - if source != nil { - source_id = source.ID() - } - +func NewBaseSignal(_type string, direction SignalDirection) BaseSignal { signal := BaseSignal{ FDirection: direction, - FSource: source_id, FType: _type, } return signal } -func NewDownSignal(source Node, _type string) BaseSignal { - return NewBaseSignal(source, _type, Down) +func NewDownSignal(_type string) BaseSignal { + return NewBaseSignal(_type, Down) +} + +func NewSignal(_type string) BaseSignal { + return NewBaseSignal(_type, Up) } -func NewSignal(source Node, _type string) BaseSignal { - return NewBaseSignal(source, _type, Up) +func NewDirectSignal(_type string) BaseSignal { + return NewBaseSignal(_type, Direct) } -func NewDirectSignal(source Node, _type string) BaseSignal { - return NewBaseSignal(source, _type, Direct) +var AbortSignal = NewBaseSignal("abort", Down) +var CancelSignal = NewBaseSignal("cancel", Down) + +type IDSignal struct { + BaseSignal + ID NodeID `json:"id"` } -func AbortSignal(source Node) BaseSignal { - return NewBaseSignal(source, "abort", Down) +func NewIDSignal(_type string, direction SignalDirection, id NodeID) IDSignal { + return IDSignal{ + BaseSignal: NewBaseSignal(_type, direction), + ID: id, + } } -func CancelSignal(source Node) BaseSignal { - return NewBaseSignal(source, "cancel", Down) +func NewStatusSignal(_type string, source NodeID) IDSignal { + return NewIDSignal(_type, Up, source) } type StartChildSignal struct { - BaseSignal - ChildID NodeID `json:"child_id"` + IDSignal Action string `json:"action"` } -func NewStartChildSignal(source Node, child_id NodeID, action string) StartChildSignal { +func NewStartChildSignal(child_id NodeID, action string) StartChildSignal { return StartChildSignal{ - BaseSignal: NewBaseSignal(source, "start_child", Direct), - ChildID: child_id, + IDSignal: NewIDSignal("start_child", Direct, child_id), Action: action, } } diff --git a/thread.go b/thread.go index 0e2630d..0de353e 100644 --- a/thread.go +++ b/thread.go @@ -10,32 +10,48 @@ import ( "github.com/google/uuid" ) -// SimpleThread.Signal updates the parent and children, and sends the signal to an internal channel -func (thread * SimpleThread) Signal(ctx * Context, signal GraphSignal, nodes NodeMap) error { - err := thread.SimpleLockable.Signal(ctx, signal, nodes) +// Assumed that thread is already locked for signal +func (thread *SimpleThread) Signal(context *ReadContext, signal GraphSignal) error { + err := thread.SimpleLockable.Signal(context, signal) if err != nil { return err } - if signal.Direction() == Up { - // Child->Parent, thread updates parent and connected requirement - if thread.parent != nil { - UseMoreStates(ctx, []Node{thread.parent}, nodes, func(nodes NodeMap) error { - thread.parent.Signal(ctx, signal, nodes) + + switch signal.Direction() { + case Up: + err = UseMoreStates(ctx, locked, thread, NewLockMap( + NewLockRequest(thread, []string{"parent"}), + ), func(context *ReadContext) error { + if thread.parent != nil { + return UseMoreStates(ctx, locked, thread, NewLockRequest(thread.parent, []string{"signal"}), func(context *ReadContext) error { + return thread.parent.Signal(context, signal) + }) + } else { return nil - }) - } - } else if signal.Direction() == Down { - // Parent->Child, updates children and dependencies - UseMoreStates(ctx, NodeList(thread.children), nodes, func(nodes NodeMap) error { + } + }) + case Down: + err = UseMoreStates(ctx, locked, thread, NewLockMap( + NewLockRequest(thread, []string{"children"}), + RequestList(thread.childre, []string{"signal"}), + ), func(context *ReadContext) error { for _, child := range(thread.children) { - child.Signal(ctx, signal, nodes) + err := child.Signal(context, signal) + if err != nil { + return err + } } return nil }) - } else if signal.Direction() == Direct { - } else { - panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) + case Direct: + err = nil + default: + return fmt.Errorf("Invalid signal direction %d", signal.Direction()) + } + if err != nil { + return err } + thread.signal <- signal return nil } @@ -156,14 +172,16 @@ func (thread * SimpleThread) AddChild(child Thread, info ThreadInfo) error { return nil } -func checkIfChild(ctx * Context, target Thread, cur Thread, nodes NodeMap) bool { +func checkIfChild(context *WriteContext, target Thread, cur Thread) bool { for _, child := range(cur.Children()) { if child.ID() == target.ID() { return true } is_child := false - UpdateMoreStates(ctx, []Node{child}, nodes, func(nodes NodeMap) error { - is_child = checkIfChild(ctx, target, child, nodes) + UpdateMoreStates(ctx, locked, cur, NewLockMap( + NewLockRequest(child, []string{"children"}), + ), func(locked NodeLockMap) error { + is_child = checkIfChild(context, target, child) return nil }) if is_child { @@ -174,8 +192,7 @@ func checkIfChild(ctx * Context, target Thread, cur Thread, nodes NodeMap) bool return false } -// Requires thread and childs thread to be locked for write -func LinkThreads(ctx * Context, thread Thread, child Thread, info ThreadInfo, nodes NodeMap) error { +func LinkThreads(context *WriteContext, princ Node, thread Thread, child Thread, info ThreadInfo) error { if ctx == nil || thread == nil || child == nil { return fmt.Errorf("invalid input") } @@ -184,29 +201,34 @@ func LinkThreads(ctx * Context, thread Thread, child Thread, info ThreadInfo, no return fmt.Errorf("Will not link %s as a child of itself", thread.ID()) } - if child.Parent() != nil { - return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID()) - } + return UpdateMoreStates(context, princ, NewNodeMap( + NewLockInfo(child, []string{"parent", "children"}), + NewLockInfo(thread, []string{"parent", "children"}), + ), func(context *WriteContext) { + if child.Parent() != nil { + return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID()) + } - if checkIfChild(ctx, thread, child, nodes) == true { - return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID(), child.ID()) - } + if checkIfChild(context, thread, child) == true { + return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID(), child.ID()) + } - if checkIfChild(ctx, child, thread, nodes) == true { - return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID()) - } + if checkIfChild(context, child, thread) == true { + return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID()) + } - err := thread.AddChild(child, info) - if err != nil { - return fmt.Errorf("EVENT_LINK_ERR: error adding %s as child to %s: %e", child.ID(), thread.ID(), err) - } - child.SetParent(thread) + err := thread.AddChild(child, info) + if err != nil { + return fmt.Errorf("EVENT_LINK_ERR: error adding %s as child to %s: %e", child.ID(), thread.ID(), err) + } + child.SetParent(thread) - if err != nil { - return err - } + if err != nil { + return err + } - return nil + return nil + }) } type ThreadAction func(* Context, Thread)(string, error) @@ -432,8 +454,8 @@ func NewSimpleThread(id NodeID, name string, state_name string, info_type reflec } } -// Requires that thread is already locked for read in UseStates -func FindChild(ctx * Context, thread Thread, id NodeID, nodes NodeMap) Thread { +// Requires the read permission of threads children +func FindChild(context *ReadContext, princ Node, thread Thread, id NodeID) Thread { if thread == nil { panic("cannot recurse through nil") } @@ -443,8 +465,8 @@ func FindChild(ctx * Context, thread Thread, id NodeID, nodes NodeMap) Thread { for _, child := range thread.Children() { var result Thread = nil - UseMoreStates(ctx, []Node{child}, nodes, func(nodes NodeMap) error { - result = FindChild(ctx, child, id, nodes) + UseMoreStates(context, princ, NewLockRequest(child, []string{"children"}), func(locked NodeLockMap) error { + result = FindChild(ctx, princ, child, id, locked) return nil }) if result != nil { @@ -499,12 +521,12 @@ func ThreadLoop(ctx * Context, thread Thread, first_action string) error { return err } - err = UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { + err = UpdateStates(ctx, thread, NewLockRequest(thread, []string{"state"}), func(locked NodeLockMap) error { err := thread.SetState("finished") if err != nil { return err } - return UnlockLockables(ctx, []Lockable{thread}, thread, nodes) + return UnlockLockables(ctx, []Lockable{thread}, thread, locked) }) if err != nil { @@ -563,23 +585,25 @@ func (thread * SimpleThread) AllowedToTakeLock(new_owner Lockable, lockable Lock } func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) error { - return UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { - child := thread.Child(signal.ChildID) + return UpdateStates(ctx, thread, NewLockRequest(thread, []string{"children"}), func(locked NodeLockMap) error { + child := thread.Child(signal.ID) if child == nil { - return fmt.Errorf("%s is not a child of %s", signal.ChildID, thread.ID()) + return fmt.Errorf("%s is not a child of %s", signal.ID, thread.ID()) } + return UpdateMoreStates(ctx, locked, thread, NewLockRequest(child, []string{"start"}), func(locked NodeLockMap) error { - info := thread.ChildInfo(signal.ChildID).(*ParentThreadInfo) - info.Start = true - ChildGo(ctx, thread, child, signal.Action) + info := thread.ChildInfo(signal.ID).(*ParentThreadInfo) + info.Start = true + ChildGo(ctx, thread, child, signal.Action) - return nil + return nil + }) }) } func ThreadRestore(ctx * Context, thread Thread) { - UpdateStates(ctx, []Node{thread}, func(nodes NodeMap)(error) { - return UpdateMoreStates(ctx, NodeList(thread.Children()), nodes, func(nodes NodeMap) error { + UpdateStates(ctx, thread, NewLockRequest(thread, []string{"children"}), func(locked NodeLockMap)(error) { + return UpdateMoreStates(ctx, locked, thread, RequestList(thread.Children(), []string{"start"}), func(locked NodeLockMap) error { for _, child := range(thread.Children()) { should_run := (thread.ChildInfo(child.ID())).(ParentInfo).Parent() ctx.Log.Logf("thread", "THREAD_RESTORE: %s -> %s: %+v", thread.ID(), child.ID(), should_run) @@ -594,13 +618,13 @@ func ThreadRestore(ctx * Context, thread Thread) { } func ThreadStart(ctx * Context, thread Thread) error { - return UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { + return UpdateStates(ctx, thread, NewLockRequest(thread, []string{"start", "lock"}), func(locked NodeLockMap) error { owner_id := NodeID{} if thread.Owner() != nil { owner_id = thread.Owner().ID() } if owner_id != thread.ID() { - err := LockLockables(ctx, []Lockable{thread}, thread, nodes) + err := LockLockables(ctx, []Lockable{thread}, thread, locked) if err != nil { return err } @@ -628,11 +652,7 @@ func ThreadWait(ctx * Context, thread Thread) (string, error) { for { select { case signal := <- thread.SignalChannel(): - if signal.Source() == thread.ID() { - ctx.Log.Logf("thread", "THREAD_SIGNAL_INTERNAL") - } else { - ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal) - } + ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal) signal_fn, exists := thread.Handler(signal.Type()) if exists == true { ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID(), signal.Type()) @@ -642,7 +662,7 @@ func ThreadWait(ctx * Context, thread Thread) (string, error) { } case <- thread.Timeout(): timeout_action := "" - err := UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { + err := UpdateStates(ctx, thread, NewLockMap(NewLockInfo(thread, []string{"timeout"})), func(context *WriteContext) error { timeout_action = thread.TimeoutAction() thread.ClearTimeout() return nil @@ -656,36 +676,25 @@ func ThreadWait(ctx * Context, thread Thread) (string, error) { } } -type ThreadAbortedError NodeID - -func (e ThreadAbortedError) Is(target error) bool { - error_type := reflect.TypeOf(ThreadAbortedError(NodeID{})) - target_type := reflect.TypeOf(target) - return error_type == target_type -} -func (e ThreadAbortedError) Error() string { - return fmt.Sprintf("Aborted by %s", (uuid.UUID)(e).String()) -} -func NewThreadAbortedError(aborter NodeID) ThreadAbortedError { - return ThreadAbortedError(aborter) -} +var ThreadAbortedError = errors.New("Thread aborted by signal") // Default thread abort is to return a ThreadAbortedError func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) { - UseStates(ctx, []Node{thread}, func(nodes NodeMap) error { - thread.Signal(ctx, NewSignal(thread, "thread_aborted"), nodes) - return nil + err := UseStates(ctx, thread, NewLockRequest(thread, []string{"signal"}), func(locked NodeLockMap) error { + return thread.Signal(ctx, NewStatusSignal("aborted", thread.ID()), locked) }) - return "", NewThreadAbortedError(signal.Source()) + if err != nil { + return "", err + } + return "", ThreadAbortedError } // Default thread cancel is to finish the thread func ThreadCancel(ctx * Context, thread Thread, signal GraphSignal) (string, error) { - UseStates(ctx, []Node{thread}, func(nodes NodeMap) error { - thread.Signal(ctx, NewSignal(thread, "thread_cancelled"), nodes) - return nil + err := UseStates(ctx, thread, NewLockRequest(thread, []string{"signal"}), func(locked NodeLockMap) error { + return thread.Signal(ctx, NewSignal("cancelled"), locked) }) - return "", nil + return "", err } func NewThreadActions() ThreadActions{ diff --git a/thread_test.go b/thread_test.go deleted file mode 100644 index 0e96c0a..0000000 --- a/thread_test.go +++ /dev/null @@ -1,122 +0,0 @@ -package graphvent - -import ( - "testing" - "time" - "fmt" -) - -func TestNewThread(t * testing.T) { - ctx := testContext(t) - - t1_r := NewSimpleThread(RandID(), "Test thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) - t1 := &t1_r - - go func(thread Thread) { - time.Sleep(10*time.Millisecond) - UseStates(ctx, []Node{t1}, func(nodes NodeMap) error { - return t1.Signal(ctx, CancelSignal(nil), nodes) - }) - }(t1) - - err := ThreadLoop(ctx, t1, "start") - fatalErr(t, err) - - err = UseStates(ctx, []Node{t1}, func(nodes NodeMap) (error) { - owner := t1.owner - if owner != nil { - return fmt.Errorf("Wrong owner %+v", owner) - } - return nil - }) -} - -func TestThreadWithRequirement(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - - t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) - t1 := &t1_r - - err := UpdateStates(ctx, []Node{l1, t1}, func(nodes NodeMap) error { - return LinkLockables(ctx, t1, []Lockable{l1}, nodes) - }) - fatalErr(t, err) - - go func (thread Thread) { - time.Sleep(10*time.Millisecond) - UseStates(ctx, []Node{t1}, func(nodes NodeMap) error { - return t1.Signal(ctx, CancelSignal(nil), nodes) - }) - }(t1) - fatalErr(t, err) - - err = ThreadLoop(ctx, t1, "start") - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - owner := l1.owner - if owner != nil { - return fmt.Errorf("Wrong owner %+v", owner) - } - return nil - }) - fatalErr(t, err) -} - -func TestThreadDBLoad(t * testing.T) { - ctx := logTestContext(t, []string{}) - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) - t1 := &t1_r - - err := UpdateStates(ctx, []Node{t1, l1}, func(nodes NodeMap) error { - return LinkLockables(ctx, t1, []Lockable{l1}, nodes) - }) - - err = UseStates(ctx, []Node{t1}, func(nodes NodeMap) error { - return t1.Signal(ctx, CancelSignal(nil), nodes) - }) - fatalErr(t, err) - - err = ThreadLoop(ctx, t1, "start") - fatalErr(t, err) - - err = UseStates(ctx, []Node{t1}, func(nodes NodeMap) error { - ser, err := t1.Serialize() - ctx.Log.Logf("test", "\n%s\n\n", ser) - return err - }) - - t1_loaded, err := LoadNode(ctx, t1.ID()) - fatalErr(t, err) - - err = UseStates(ctx, []Node{t1_loaded}, func(nodes NodeMap) error { - ser, err := t1_loaded.Serialize() - ctx.Log.Logf("test", "\n%s\n\n", ser) - return err - }) -} - -func TestThreadUnlink(t * testing.T) { - ctx := logTestContext(t, []string{}) - t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) - t1 := &t1_r - t2_r := NewSimpleThread(RandID(), "Test Thread 2", "init", nil, BaseThreadActions, BaseThreadHandlers) - t2 := &t2_r - - - err := UpdateStates(ctx, []Node{t1, t2}, func(nodes NodeMap) error { - err := LinkThreads(ctx, t1, t2, nil, nodes) - if err != nil { - return err - } - - return UnlinkThreads(ctx, t1, t2) - }) - fatalErr(t, err) -} -