diff --git a/context.go b/context.go index 2a130b3..d6fd5fa 100644 --- a/context.go +++ b/context.go @@ -212,7 +212,7 @@ func NewContext(db * badger.DB, log Logger) * Context { ctx.GQL.Subscription.AddFieldConfig("Update", GQLSubscriptionUpdate) ctx.GQL.Subscription.AddFieldConfig("Self", GQLSubscriptionSelf) - ctx.GQL.Mutation.AddFieldConfig("sendUpdate", GQLMutationSendUpdate) + ctx.GQL.Mutation.AddFieldConfig("abort", GQLMutationAbort) ctx.GQL.Mutation.AddFieldConfig("startChild", GQLMutationStartChild) err = ctx.RebuildSchema() diff --git a/gql.go b/gql.go index 319dc97..a38cf15 100644 --- a/gql.go +++ b/gql.go @@ -204,7 +204,16 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt ctx.Log.Logf("gql", "AUTHORIZING NEW USER %s - %s", key_id, shared) new_user := NewUser(fmt.Sprintf("GQL_USER %s", key_id.String()), time.Now(), remote_id, shared, []string{"gql"}) - err := UpdateStates(ctx, []Node{server, &new_user}, func(nodes NodeMap) error { + err := UpdateStates(ctx, server, NewLockMap(LockList{ + LockInfo{ + Node: server, + Resources: []string{"users"}, + }, + LockInfo{ + Node: &new_user, + Resources: []string{""}, + }, + }), func(context *WriteContext) error { server.Users[key_id] = &new_user return nil }) @@ -864,30 +873,59 @@ var gql_actions ThreadActions = ThreadActions{ }(server) - UseStates(ctx, []Node{server}, func(nodes NodeMap)(error){ + err = UpdateStates(ctx, server, NewLockMap( + NewLockInfo(server, []string{"http_server"}), + ), func(context *WriteContext) error { server.tcp_listener = listener server.http_server = http_server - return server.Signal(ctx, NewSignal(server, "server_started"), nodes) + return nil + }) + + if err != nil { + return "", err + } + + err = UseStates(ctx, server, NewLockMap( + NewLockInfo(server, []string{"signal"}), + ), func(context *ReadContext) error { + return server.Signal(context, NewSignal("server_started")) }) + if err != nil { + return "", err + } + return "wait", nil }, } var gql_handlers ThreadHandlers = ThreadHandlers{ - "child_added": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) { + "child_linked": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) { ctx.Log.Logf("gql", "GQL_THREAD_CHILD_ADDED: %+v", signal) - UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { - should_run, exists := thread.ChildInfo(signal.Source()).(*ParentThreadInfo) + err := UpdateStates(ctx, thread, NewLockMap( + NewLockInfo(thread, []string{"children"}), + ), func(context *WriteContext) error { + sig, ok := signal.(IDSignal) + if ok == false { + ctx.Log.Logf("gql", "GQL_THREAD_NODE_LINKED_BAD_CAST") + return nil + } + should_run, exists := thread.ChildInfo(sig.ID).(*ParentThreadInfo) if exists == false { - ctx.Log.Logf("gql", "GQL_THREAD_CHILD_ADDED: tried to start %s whis is not a child") + ctx.Log.Logf("gql", "GQL_THREAD_NODE_LINKED: %s is not a child of %s", sig.ID) return nil } if should_run.Start == true { - ChildGo(ctx, thread, thread.Child(signal.Source()), should_run.StartAction) + ChildGo(ctx, thread, thread.Child(sig.ID), should_run.StartAction) } return nil }) + + if err != nil { + + } else { + + } return "wait", nil }, "start_child": func(ctx *Context, thread Thread, signal GraphSignal) (string, error) { @@ -902,7 +940,7 @@ var gql_handlers ThreadHandlers = ThreadHandlers{ if err != nil { ctx.Log.Logf("gql", "GQL_START_CHILD_ERR: %s", err) } else { - ctx.Log.Logf("gql", "GQL_START_CHILD: %s", sig.ChildID.String()) + ctx.Log.Logf("gql", "GQL_START_CHILD: %s", sig.ID.String()) } return "wait", nil diff --git a/gql_mutation.go b/gql_mutation.go index 24f0974..b85f261 100644 --- a/gql_mutation.go +++ b/gql_mutation.go @@ -4,16 +4,13 @@ import ( "github.com/graphql-go/graphql" ) -var GQLMutationSendUpdate = NewField(func()*graphql.Field { - gql_mutation_send_update := &graphql.Field{ +var GQLMutationAbort = NewField(func()*graphql.Field { + gql_mutation_abort := &graphql.Field{ Type: GQLTypeSignal.Type, Args: graphql.FieldConfigArgument{ "id": &graphql.ArgumentConfig{ Type: graphql.String, }, - "signal": &graphql.ArgumentConfig{ - Type: GQLTypeSignalInput.Type, - }, }, Resolve: func(p graphql.ResolveParams) (interface{}, error) { ctx, err := PrepResolve(p) @@ -21,50 +18,37 @@ var GQLMutationSendUpdate = NewField(func()*graphql.Field { return nil, err } - err = ctx.Server.Allowed("signal", "self", ctx.User) - if err != nil { - return nil, err - } - - signal_map, err := ExtractParam[map[string]interface{}](p, "signal") + err = ctx.Server.Allowed("signal", "", ctx.User) if err != nil { return nil, err } - var signal GraphSignal = nil - if signal_map["Direction"] == "up" { - signal = NewSignal(ctx.Server, signal_map["Type"].(string)) - } else if signal_map["Direction"] == "down" { - signal = NewDownSignal(ctx.Server, signal_map["Type"].(string)) - } else if signal_map["Direction"] == "direct" { - signal = NewDirectSignal(ctx.Server, signal_map["Type"].(string)) - } else { - return nil, fmt.Errorf("Bad direction: %d", signal_map["Direction"]) - } - id, err := ExtractID(p, "id") if err != nil { return nil, err } var node Node = nil - err = UseStates(ctx.Context, []Node{ctx.Server}, func(nodes NodeMap) (error){ - node = FindChild(ctx.Context, ctx.Server, id, nodes) + err = UseStates(ctx.Context, ctx.User, NewLockMap( + NewLockInfo(ctx.Server, []string{"children"}), + ), func(context *ReadContext) (error){ + node = FindChild(ctx.Context, ctx.User, ctx.Server, id, locked) if node == nil { return fmt.Errorf("Failed to find ID: %s as child of server thread", id) } - node.Signal(ctx.Context, signal, nodes) - return nil + return UseMoreStates(ctx.Context, locked, ctx.User, NewLockInfo(node, []string{"signal"}), func(locked NodeLockMap) error { + return node.Signal(ctx.Context, AbortSignal, locked) + }) }) if err != nil { return nil, err } - return signal, nil + return AbortSignal, nil }, } - return gql_mutation_send_update + return gql_mutation_abort }) var GQLMutationStartChild = NewField(func()*graphql.Field{ @@ -88,11 +72,6 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{ return nil, err } - err = ctx.Server.Allowed("start_child", "self", ctx.User) - if err != nil { - return nil, err - } - parent_id, err := ExtractID(p, "parent_id") if err != nil { return nil, err @@ -109,14 +88,22 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{ } var signal GraphSignal - err = UseStates(ctx.Context, []Node{ctx.Server}, func(nodes NodeMap) (error){ - node := FindChild(ctx.Context, ctx.Server, parent_id, nodes) + err = UseStates(ctx.Context, ctx.User, NewLockMap( + NewLockInfo(ctx.Server, []string{"children"}), + ), func(context *ReadContext) error { + node := FindChild(ctx.Context, ctx.User, ctx.Server, parent_id, locked) if node == nil { return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id) } - return UseMoreStates(ctx.Context, []Node{node}, nodes, func(NodeMap) error { - signal = NewStartChildSignal(ctx.Server, child_id, action) - return node.Signal(ctx.Context, signal, nodes) + + err := node.Allowed("signal", "", ctx.User) + if err != nil { + return err + } + + return UseMoreStates(ctx.Context, locked, ctx.User, NewLockInfo(node, []string{"start_child", "signal"}), func(locked NodeLockMap) error { + signal = NewStartChildSignal(child_id, action) + return node.Signal(ctx.Context, signal, locked) }) }) if err != nil { diff --git a/gql_query.go b/gql_query.go index ca8ea27..4593e10 100644 --- a/gql_query.go +++ b/gql_query.go @@ -11,7 +11,7 @@ var GQLQuerySelf = &graphql.Field{ return nil, err } - err = ctx.Server.Allowed("enumerate", "self", ctx.User) + err = ctx.Server.Allowed("read", "", ctx.User) if err != nil { return nil, err } @@ -28,7 +28,7 @@ var GQLQueryUser = &graphql.Field{ return nil, err } - err = ctx.User.Allowed("enumerate", "self", ctx.User) + err = ctx.User.Allowed("read", "", ctx.User) if err != nil { return nil, err } diff --git a/gql_resolvers.go b/gql_resolvers.go index 4c57376..a8d897f 100644 --- a/gql_resolvers.go +++ b/gql_resolvers.go @@ -53,8 +53,8 @@ func GQLNodeID(p graphql.ResolveParams) (interface{}, error) { return nil, fmt.Errorf("Failed to cast source to Node") } - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) error { - return node.Allowed("read", "id", ctx.User) + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"id"}), func(locked NodeLockMap) error { + return nil }) if err != nil { return nil, err @@ -76,9 +76,9 @@ func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) { listen := "" - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) error { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"listen"}), func(locked NodeLockMap) error { listen = node.Listen - return node.Allowed("read", "listen", ctx.User) + return nil }) if err != nil { @@ -100,9 +100,9 @@ func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) { } var parent Thread = nil - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) (error) { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"parent"}), func(locked NodeLockMap) error { parent = node.Parent() - return node.Allowed("read", "parent", ctx.User) + return nil }) if err != nil { @@ -124,9 +124,9 @@ func GQLThreadState(p graphql.ResolveParams) (interface{}, error) { } var state string - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) (error) { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"state"}), func(locked NodeLockMap) error { state = node.State() - return node.Allowed("read", "state", ctx.User) + return nil }) if err != nil { @@ -148,9 +148,9 @@ func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) { } var children []Thread = nil - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) (error) { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"children"}), func(locked NodeLockMap) error { children = node.Children() - return node.Allowed("read", "children", ctx.User) + return nil }) if err != nil { @@ -172,9 +172,9 @@ func GQLLockableName(p graphql.ResolveParams) (interface{}, error) { } name := "" - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) error { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"name"}), func(locked NodeLockMap) error { name = node.Name() - return node.Allowed("read", "name", ctx.User) + return nil }) if err != nil { @@ -196,9 +196,9 @@ func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) { } var requirements []Lockable = nil - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) (error) { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"requirements"}), func(locked NodeLockMap) error { requirements = node.Requirements() - return node.Allowed("read", "requirements", ctx.User) + return nil }) if err != nil { @@ -220,9 +220,9 @@ func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) { } var dependencies []Lockable = nil - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) (error) { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"dependencies"}), func(locked NodeLockMap) error { dependencies = node.Dependencies() - return node.Allowed("read", "dependencies", ctx.User) + return nil }) if err != nil { @@ -244,9 +244,9 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) { } var owner Node = nil - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) (error) { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"owner"}), func(locked NodeLockMap) error { owner = node.Owner() - return node.Allowed("read", "owner", ctx.User) + return nil }) if err != nil { @@ -268,14 +268,14 @@ func GQLThreadUsers(p graphql.ResolveParams) (interface{}, error) { } var users []*User - err = UseStates(ctx.Context, []Node{node, ctx.User}, func(nodes NodeMap) error { + err = UseStates(ctx.Context, ctx.User, NewLockRequest(node, []string{"users"}), func(locked NodeLockMap) error { users = make([]*User, len(node.Users)) i := 0 for _, user := range(node.Users) { users[i] = user i += 1 } - return node.Allowed("read", "users", ctx.User) + return nil }) if err != nil { @@ -298,12 +298,6 @@ func GQLSignalType(p graphql.ResolveParams) (interface{}, error) { }) } -func GQLSignalSource(p graphql.ResolveParams) (interface{}, error) { - return GQLSignalFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error){ - return signal.Source(), nil - }) -} - func GQLSignalDirection(p graphql.ResolveParams) (interface{}, error) { return GQLSignalFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error){ direction := signal.Direction() diff --git a/gql_test.go b/gql_test.go index 7a92aaf..385b07b 100644 --- a/gql_test.go +++ b/gql_test.go @@ -31,22 +31,24 @@ func TestGQLThread(t * testing.T) { t2_r := NewSimpleThread(RandID(), "Test thread 2", "init", nil, BaseThreadActions, BaseThreadHandlers) t2 := &t2_r - err = UpdateStates(ctx, []Node{gql_t, t1, t2}, func(nodes NodeMap) error { - i1 := NewParentThreadInfo(true, "start", "restore") - err := LinkThreads(ctx, gql_t, t1, &i1, nodes) - if err != nil { - return err - } - - i2 := NewParentThreadInfo(false, "start", "restore") - return LinkThreads(ctx, gql_t, t2, &i2, nodes) + err = UpdateStates(ctx, gql_t, RequestList([]Node{t1, t2}, []string{"parent"}), func(locked NodeLockMap) error { + return UpdateMoreStates(ctx, locked, gql_t, NewLockRequest(gql_t, []string{"children"}), func(locked NodeLockMap) error { + i1 := NewParentThreadInfo(true, "start", "restore") + err := LinkThreads(ctx, gql_t, t1, &i1, locked) + if err != nil { + return err + } + + i2 := NewParentThreadInfo(false, "start", "restore") + return LinkThreads(ctx, gql_t, t2, &i2, locked) + }) }) fatalErr(t, err) go func(thread Thread){ time.Sleep(10*time.Millisecond) - err := UseStates(ctx, []Node{thread}, func(nodes NodeMap) error { - return thread.Signal(ctx, CancelSignal(nil), nodes) + err := UseStates(ctx, thread, NewLockRequest(thread, []string{"signal"}), func(locked NodeLockMap) error { + return thread.Signal(ctx, CancelSignal, locked) }) fatalErr(t, err) }(gql_t) @@ -81,39 +83,49 @@ func TestGQLDBLoad(t * testing.T) { gql := &gql_r info := NewParentThreadInfo(true, "start", "restore") - err = UpdateStates(ctx, []Node{gql, t1, l1, u1, p1}, func(nodes NodeMap) error { - err := gql.AddPolicy(p1) - if err != nil { - return err - } - gql.Users[KeyID(&u1_key.PublicKey)] = u1 - err = LinkLockables(ctx, gql, []Lockable{l1}, nodes) - if err != nil { - return err - } - return LinkThreads(ctx, gql, t1, &info, nodes) + err = UpdateStates(ctx, gql, NewLockRequest(gql, []string{"policies", "users", "requirements", "children"}), func(locked NodeLockMap) error { + return UpdateMoreStates(ctx, locked, gql, RequestList([]Node{u1, p1}, []string{}), func(locked NodeLockMap) error { + err := gql.AddPolicy(p1) + if err != nil { + return err + } + + gql.Users[KeyID(&u1_key.PublicKey)] = u1 + + return UpdateMoreStates(ctx, locked, gql, NewLockRequest(t1, []string{"parent"}), func(locked NodeLockMap) error { + err := LinkThreads(ctx, gql, t1, &info, locked) + if err != nil { + return err + } + return UpdateMoreStates(ctx, locked, gql, NewLockRequest(l1, []string{"dependencies"}), func(locked NodeLockMap) error { + return LinkLockables(ctx, gql, gql, []Lockable{l1}, locked) + }) + }) + }) }) fatalErr(t, err) - err = UseStates(ctx, []Node{gql}, func(nodes NodeMap) error { - err := gql.Signal(ctx, NewSignal(t1, "child_added"), nodes) + err = UseStates(ctx, gql, NewLockRequest(gql, []string{"signal"}), func(locked NodeLockMap) error { + err := gql.Signal(ctx, NewStatusSignal("child_linked", t1.ID()), locked) if err != nil { return nil } - return gql.Signal(ctx, AbortSignal(nil), nodes) + return gql.Signal(ctx, CancelSignal, locked) }) fatalErr(t, err) err = ThreadLoop(ctx, gql, "start") - if errors.Is(err, NewThreadAbortedError(NodeID{})) { + if errors.Is(err, ThreadAbortedError) { ctx.Log.Logf("test", "Main thread aborted by signal: %s", err) - } else { + } else if err != nil{ fatalErr(t, err) + } else { + ctx.Log.Logf("test", "Main thread cancelled by signal") } - (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_aborted", t1, 100*time.Millisecond, "Didn't receive thread_abort from t1 on t1") + (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_aborted", 100*time.Millisecond, "Didn't receive thread_abort from t1 on t1") - err = UseStates(ctx, []Node{gql, u1}, func(nodes NodeMap) error { + err = UseStates(ctx, gql, RequestList([]Node{gql, u1}, nil), func(locked NodeLockMap) error { ser1, err := gql.Serialize() ser2, err := u1.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser1) @@ -126,29 +138,29 @@ func TestGQLDBLoad(t * testing.T) { var t1_loaded *SimpleThread = nil var update_channel_2 chan GraphSignal - err = UseStates(ctx, []Node{gql_loaded}, func(nodes NodeMap) error { + err = UseStates(ctx, gql, NewLockRequest(gql_loaded, []string{"users", "children"}), func(locked NodeLockMap) error { ser, err := gql_loaded.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser) u_loaded := gql_loaded.(*GQLThread).Users[u1.ID()] child := gql_loaded.(Thread).Children()[0].(*SimpleThread) t1_loaded = child update_channel_2 = UpdateChannel(t1_loaded, 10, NodeID{}) - err = UseMoreStates(ctx, []Node{u_loaded}, nodes, func(nodes NodeMap) error { + err = UseMoreStates(ctx, locked, gql, NewLockRequest(u_loaded, nil), func(locked NodeLockMap) error { ser, err := u_loaded.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser) return err }) - gql_loaded.Signal(ctx, AbortSignal(nil), nodes) + gql_loaded.Signal(ctx, AbortSignal, locked) return err }) err = ThreadLoop(ctx, gql_loaded.(Thread), "restore") - if errors.Is(err, NewThreadAbortedError(NodeID{})) { + if errors.Is(err, ThreadAbortedError) { ctx.Log.Logf("test", "Main thread aborted by signal: %s", err) } else { fatalErr(t, err) } - (*GraphTester)(t).WaitForValue(ctx, update_channel_2, "thread_aborted", t1_loaded, 100*time.Millisecond, "Didn't received thread_aborted on t1_loaded from t1_loaded") + (*GraphTester)(t).WaitForValue(ctx, update_channel_2, "thread_aborted", 100*time.Millisecond, "Didn't received thread_aborted on t1_loaded from t1_loaded") } @@ -157,20 +169,21 @@ func TestGQLAuth(t * testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) - p1_r := NewPerTagPolicy(RandID(), map[string]NodeActions{"gql": NewNodeActions(nil, []string{"*"})}) + p1_r := NewPerTagPolicy(RandID(), map[string]NodeActions{"gql": NewNodeActions(nil, []string{"read"})}) p1 := &p1_r gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) gql_t := &gql_t_r - err = UpdateStates(ctx, []Node{gql_t, p1}, func(nodes NodeMap) error { + // p1 not written to DB, TODO: update write to follow links maybe + err = UpdateStates(ctx, gql_t, NewLockRequest(gql_t, []string{"policies"}), func(locked NodeLockMap) error { return gql_t.AddPolicy(p1) }) done := make(chan error, 1) var update_channel chan GraphSignal - err = UseStates(ctx, []Node{gql_t}, func(nodes NodeMap) error { + err = UseStates(ctx, gql_t, NewLockRequest(gql_t, nil), func(locked NodeLockMap) error { update_channel = UpdateChannel(gql_t, 10, NodeID{}) return nil }) @@ -184,14 +197,14 @@ func TestGQLAuth(t * testing.T) { case <-done: ctx.Log.Logf("test", "DONE") } - err := UseStates(ctx, []Node{gql_t}, func(nodes NodeMap) error { - return thread.Signal(ctx, CancelSignal(nil), nodes) + err := UseStates(ctx, gql_t, NewLockRequest(gql_t, []string{"signal}"}), func(locked NodeLockMap) error { + return thread.Signal(ctx, CancelSignal, locked) }) fatalErr(t, err) }(done, gql_t) go func(thread Thread){ - (*GraphTester)(t).WaitForValue(ctx, update_channel, "server_started", gql_t, 100*time.Millisecond, "Server didn't start") + (*GraphTester)(t).WaitForValue(ctx, update_channel, "server_started", 100*time.Millisecond, "Server didn't start") port := gql_t.tcp_listener.Addr().(*net.TCPAddr).Port ctx.Log.Logf("test", "GQL_PORT: %d", port) diff --git a/gql_types.go b/gql_types.go index f8bb852..ecbf1db 100644 --- a/gql_types.go +++ b/gql_types.go @@ -162,7 +162,7 @@ var GQLTypeGraphNode = NewSingleton(func() *graphql.Object { var GQLTypeSignal = NewSingleton(func() *graphql.Object { gql_type_signal := graphql.NewObject(graphql.ObjectConfig{ - Name: "SignalOut", + Name: "Signal", IsTypeOf: func(p graphql.IsTypeOfParams) bool { _, ok := p.Value.(GraphSignal) return ok @@ -174,10 +174,6 @@ var GQLTypeSignal = NewSingleton(func() *graphql.Object { Type: graphql.String, Resolve: GQLSignalType, }) - gql_type_signal.AddFieldConfig("Source", &graphql.Field{ - Type: graphql.String, - Resolve: GQLSignalSource, - }) gql_type_signal.AddFieldConfig("Direction", &graphql.Field{ Type: graphql.String, Resolve: GQLSignalDirection, @@ -189,20 +185,3 @@ var GQLTypeSignal = NewSingleton(func() *graphql.Object { return gql_type_signal }, nil) -var GQLTypeSignalInput = NewSingleton(func()*graphql.InputObject { - gql_type_signal_input := graphql.NewInputObject(graphql.InputObjectConfig{ - Name: "SignalIn", - Fields: graphql.InputObjectConfigFieldMap{}, - }) - gql_type_signal_input.AddFieldConfig("Type", &graphql.InputObjectFieldConfig{ - Type: graphql.String, - DefaultValue: "cancel", - }) - gql_type_signal_input.AddFieldConfig("Direction", &graphql.InputObjectFieldConfig{ - Type: graphql.String, - DefaultValue: "down", - }) - - return gql_type_signal_input -}, nil) - diff --git a/graph_test.go b/graph_test.go index 5ab8c88..54a7521 100644 --- a/graph_test.go +++ b/graph_test.go @@ -12,7 +12,7 @@ import ( type GraphTester testing.T const listner_timeout = 50 * time.Millisecond -func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, signal_type string, source Node, timeout time.Duration, str string) GraphSignal { +func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, signal_type string, timeout time.Duration, str string) GraphSignal { timeout_channel := time.After(timeout) for true { select { @@ -22,16 +22,7 @@ func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, si t.Fatal(str) } if signal.Type() == signal_type { - ctx.Log.Logf("test", "SIGNAL_TYPE_FOUND: %s - %s %+v\n", signal.Type(), signal.Source(), listener) - if source == nil { - if signal.Source() == ZeroID { - return signal - } - } else { - if signal.Source() == source.ID() { - return signal - } - } + return signal } case <-timeout_channel: pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) diff --git a/lockable.go b/lockable.go index 08b70b4..9bdee4c 100644 --- a/lockable.go +++ b/lockable.go @@ -215,52 +215,58 @@ func (lockable * SimpleLockable) CanUnlock(new_owner Lockable) error { return nil } -// Lockable.Signal sends the update to the owner, requirements, and dependencies before updating listeners -func (lockable * SimpleLockable) Signal(ctx *Context, signal GraphSignal, nodes NodeMap) error { - err := lockable.GraphNode.Signal(ctx, signal, nodes) +// Assumed that lockable is already locked for signal +func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal) error { + err := lockable.GraphNode.Signal(ctx, signal, locked) if err != nil { return err } - if signal.Direction() == Up { - // Child->Parent, lockable updates dependency lockables - owner_sent := false - UseMoreStates(ctx, NodeList(lockable.dependencies), nodes, func(nodes NodeMap) error { - for _, dependency := range(lockable.dependencies){ + + switch signal.Direction() { + case Up: + err = UseMoreStates(ctx, locked, lockable, NewLockMap( + NewLockInfo(lockable, []string{"dependencies", "owner"}), + RequestList(lockable.requirements, []string{"signal"}), + ), func(context *ReadContext) error { + owner_sent := false + for _, dependency := range(lockable.dependencies) { ctx.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID()) - dependency.Signal(ctx, signal, nodes) + dependency.Signal(ctx, signal, locked) if lockable.owner != nil { if dependency.ID() == lockable.owner.ID() { owner_sent = true } } } + if lockable.owner != nil && owner_sent == false { + if lockable.owner.ID() != lockable.ID() { + ctx.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID()) + return UseMoreStates(context, lockable, NewLockRequest(lockable.owner, []string{"signal"}), func(context *ReadContext) error { + return lockable.owner.Signal(ctx, signal, locked) + }) + } + } return nil }) - if lockable.owner != nil && owner_sent == false { - if lockable.owner.ID() != lockable.ID() { - ctx.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID()) - UseMoreStates(ctx, []Node{lockable.owner}, nodes, func(nodes NodeMap) error { - return lockable.owner.Signal(ctx, signal, nodes) - }) - } - } - } else if signal.Direction() == Down { - // Parent->Child, lockable updates lock holder - UseMoreStates(ctx, NodeList(lockable.requirements), nodes, func(nodes NodeMap) error { + case Down: + err = UseMoreStates(context, lockable, NewLockMap( + NewLockInfo(lockable, []string{"requirements"}, + RequestList(lockable.requirements, []string{"signal"})), + ), func(context *ReadContext) error { for _, requirement := range(lockable.requirements) { - err := requirement.Signal(ctx, signal, nodes) + err := requirement.Signal(ctx, signal, locked) if err != nil { return err } } return nil }) - - } else if signal.Direction() == Direct { - } else { - panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) + case Direct: + err = nil + default: + return fmt.Errorf("invalid signal direction %d", signal.Direction()) } - return nil + return err } // Removes requirement as a requirement from lockable @@ -286,7 +292,7 @@ func UnlinkLockables(ctx * Context, lockable Lockable, requirement Lockable) err // Link requirements as requirements to lockable // Requires lockable and requirements to be locked for write, nodes passed because requirement check recursively locks -func LinkLockables(ctx * Context, lockable Lockable, requirements []Lockable, nodes NodeMap) error { +func LinkLockables(context *WriteContext, princ Node, lockable Lockable, requirements []Lockable) error { if lockable == nil { return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements") } @@ -312,56 +318,61 @@ func LinkLockables(ctx * Context, lockable Lockable, requirements []Lockable, no found[requirement.ID()] = true } - // Check that all the requirements can be added - // If the lockable is already locked, need to lock this resource as well before we can add it - for _, requirement := range(requirements) { - for _, req := range(requirements) { - if req.ID() == requirement.ID() { - continue + return UpdateMoreStates(context, princ, NewInfoMap( + NewLockInfo(lockable, []string{"requirements"}), + RequestList(requirements, []string{"dependencies"}), + ), func(context *WriteContext) error { + // Check that all the requirements can be added + // If the lockable is already locked, need to lock this resource as well before we can add it + for _, requirement := range(requirements) { + for _, req := range(requirements) { + if req.ID() == requirement.ID() { + continue + } + if checkIfRequirement(context, req, requirement) == true { + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependenyc of %s so cannot add the same dependency", req.ID(), requirement.ID()) + } } - if checkIfRequirement(ctx, req, requirement, nodes) == true { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependenyc of %s so cannot add the same dependency", req.ID(), requirement.ID()) + if checkIfRequirement(context, lockable, requirement) == true { + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID(), lockable.ID()) } - } - if checkIfRequirement(ctx, lockable, requirement, nodes) == true { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID(), lockable.ID()) - } - if checkIfRequirement(ctx, requirement, lockable, nodes) == true { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID()) - } - if lockable.Owner() == nil { - // If the new owner isn't locked, we can add the requirement - } else if requirement.Owner() == nil { - // if the new requirement isn't already locked but the owner is, the requirement needs to be locked first - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID(), requirement.ID()) - } else { - // If the new requirement is already locked and the owner is already locked, their owners need to match - if requirement.Owner().ID() != lockable.Owner().ID() { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID(), lockable.ID()) + if checkIfRequirement(context, requirement, lockable) == true { + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID()) + } + if lockable.Owner() == nil { + // If the new owner isn't locked, we can add the requirement + } else if requirement.Owner() == nil { + // if the new requirement isn't already locked but the owner is, the requirement needs to be locked first + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID(), requirement.ID()) + } else { + // If the new requirement is already locked and the owner is already locked, their owners need to match + if requirement.Owner().ID() != lockable.Owner().ID() { + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID(), lockable.ID()) + } } } - } - // Update the states of the requirements - for _, requirement := range(requirements) { - requirement.AddDependency(lockable) - lockable.AddRequirement(requirement) - ctx.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID(), lockable.ID()) - } + // Update the states of the requirements + for _, requirement := range(requirements) { + requirement.AddDependency(lockable) + lockable.AddRequirement(requirement) + ctx.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID(), lockable.ID()) + } - // Return no error - return nil + // Return no error + return nil + }) } // Must be called withing update context -func checkIfRequirement(ctx * Context, r Lockable, cur Lockable, nodes NodeMap) bool { +func checkIfRequirement(context *WriteContext, r Lockable, cur Lockable) bool { for _, c := range(cur.Requirements()) { if c.ID() == r.ID() { return true } is_requirement := false - UpdateMoreStates(ctx, []Node{c}, nodes, func(nodes NodeMap) (error) { - is_requirement = checkIfRequirement(ctx, cur, c, nodes) + UpdateMoreStates(context, cur, NewLockMap(NewLockRequest(c, []string{"requirements"})), func(context *WriteContext) error { + is_requirement = checkIfRequirement(context, cur, c) return nil }) @@ -374,19 +385,18 @@ func checkIfRequirement(ctx * Context, r Lockable, cur Lockable, nodes NodeMap) } // Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error -// Requires that all the nodes in to_lock and new_owner are locked for write -func LockLockables(ctx * Context, to_lock []Lockable, new_owner Lockable, nodes NodeMap) error { +// Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call +func LockLockables(context *WriteContext, to_lock []Lockable, new_owner Lockable) error { if to_lock == nil { return fmt.Errorf("LOCKABLE_LOCK_ERR: no list provided") } - node_list := make([]Node, len(to_lock)) - for i, l := range(to_lock) { + for _, l := range(to_lock) { if l == nil { return fmt.Errorf("LOCKABLE_LOCK_ERR: Can not lock nil") } - node_list[i] = l } + if new_owner == nil { return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks") } @@ -396,132 +406,120 @@ func LockLockables(ctx * Context, to_lock []Lockable, new_owner Lockable, nodes return nil } - // First loop is to check that the states can be locked, and locks all requirements - for _, req := range(to_lock) { - ctx.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID()) + return UpdateMoreStates(ctx, locked, new_owner, RequestList(to_lock, []string{"lock"}), func(locked NodeLockMap) error { + // First loop is to check that the states can be locked, and locks all requirements + for _, req := range(to_lock) { + context.Context.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID()) - // Check custom lock conditions - err := req.CanLock(new_owner) - if err != nil { - return err - } + // Check custom lock conditions + err := req.CanLock(new_owner) + if err != nil { + return err + } - // If req is alreay locked, check that we can pass the lock - if req.Owner() != nil { - owner := req.Owner() - if owner.ID() == new_owner.ID() { - // If we already own the lock, nothing to do - continue - } else if owner.ID() == req.ID() { - if req.AllowedToTakeLock(new_owner, req) == false { - return fmt.Errorf("LOCKABLE_LOCK_ERR: %s is not allowed to take %s's lock from %s", new_owner.ID(), req.ID(), owner.ID()) - } - err := LockLockables(ctx, req.Requirements(), req, nodes) - if err != nil { - return err + // If req is alreay locked, check that we can pass the lock + if req.Owner() != nil { + owner := req.Owner() + if owner.ID() == new_owner.ID() { + continue + } else { + err := UpdateMoreStates(ctx, locked, new_owner, NewLockRequest(owner, []string{"take_lock"}), func(locked NodeLockMap)(error){ + return LockLockables(ctx, req.Requirements(), req, locked) + }) + if err != nil { + return err + } } } else { - err := UpdateMoreStates(ctx, []Node{owner}, nodes, func(nodes NodeMap)(error){ - if owner.AllowedToTakeLock(new_owner, req) == false { - return fmt.Errorf("LOCKABLE_LOCK_ERR: %s is not allowed to take %s's lock from %s", new_owner.ID(), req.ID(), owner.ID()) - } - err := LockLockables(ctx, req.Requirements(), req, nodes) - return err - }) + err := LockLockables(ctx, req.Requirements(), req, locked) if err != nil { return err } } - } else { - err := LockLockables(ctx, req.Requirements(), req, nodes) - if err != nil { - return err - } } - } - // At this point state modification will be started, so no errors can be returned - for _, req := range(to_lock) { - old_owner := req.Owner() - // If the lockable was previously unowned, update the state - if old_owner == nil { - ctx.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID(), req.ID()) - req.SetOwner(new_owner) - new_owner.RecordLock(req, old_owner) - // Otherwise if the new owner already owns it, no need to update state - } else if old_owner.ID() == new_owner.ID() { - ctx.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID(), req.ID()) - // Otherwise update the state - } else { - req.SetOwner(new_owner) - new_owner.RecordLock(req, old_owner) - ctx.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID(), req.ID(), old_owner.ID()) + // At this point state modification will be started, so no errors can be returned + for _, req := range(to_lock) { + old_owner := req.Owner() + // If the lockable was previously unowned, update the state + if old_owner == nil { + context.Context.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID(), req.ID()) + req.SetOwner(new_owner) + new_owner.RecordLock(req, old_owner) + // Otherwise if the new owner already owns it, no need to update state + } else if old_owner.ID() == new_owner.ID() { + context.Context.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID(), req.ID()) + // Otherwise update the state + } else { + req.SetOwner(new_owner) + new_owner.RecordLock(req, old_owner) + context.Context.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID(), req.ID(), old_owner.ID()) + } } - } - return nil + return nil + }) + } -// Unlock nodes in the to_unlock slice with old_owner, does not modify any states if returning an error -// Requires that all the nodes in to_unlock and old_owner are locked for write -func UnlockLockables(ctx * Context, to_unlock []Lockable, old_owner Lockable, nodes NodeMap) error { +func UnlockLockables(context *WriteContext, to_unlock []Lockable, old_owner Lockable) error { if to_unlock == nil { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided") } + for _, l := range(to_unlock) { if l == nil { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: Can not lock nil") + return fmt.Errorf("LOCKABLE_UNLOCK_ERR: Can not unlock nil") } } + if old_owner == nil { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks") } - // Called with no requirements to lock, success + // Called with no requirements to unlock, success if len(to_unlock) == 0 { return nil } - node_list := make([]Node, len(to_unlock)) - for i, l := range(to_unlock) { - node_list[i] = l - } + return UpdateMoreStates(ctx, locked, old_owner, RequestList(to_unlock, []string{"lock"}), func(locked NodeLockMap) error { + // First loop is to check that the states can be locked, and locks all requirements + for _, req := range(to_unlock) { + context.Context.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID()) - // First loop is to check that the states can be locked, and locks all requirements - for _, req := range(to_unlock) { - ctx.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID()) + // Check if the owner is correct + if req.Owner() != nil { + if req.Owner().ID() != old_owner.ID() { + return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID(), old_owner.ID()) + } + } else { + return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID()) + } - // Check if the owner is correct - if req.Owner() != nil { - if req.Owner().ID() != old_owner.ID() { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID(), old_owner.ID()) + // Check custom unlock conditions + err := req.CanUnlock(old_owner) + if err != nil { + return err } - } else { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID()) - } - // Check custom unlock conditions - err := req.CanUnlock(old_owner) - if err != nil { - return err + err = UnlockLockables(ctx, req.Requirements(), req, locked) + if err != nil { + return err + } } - err = UnlockLockables(ctx, req.Requirements(), req, nodes) - if err != nil { - return err + // At this point state modification will be started, so no errors can be returned + for _, req := range(to_unlock) { + new_owner := old_owner.RecordUnlock(req) + req.SetOwner(new_owner) + if new_owner == nil { + context.Context.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID(), req.ID()) + } else { + context.Context.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID(), req.ID(), new_owner.ID()) + } } - } - // At this point state modification will be started, so no errors can be returned - for _, req := range(to_unlock) { - new_owner := old_owner.RecordUnlock(req) - req.SetOwner(new_owner) - if new_owner == nil { - ctx.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID(), req.ID()) - } else { - ctx.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID(), req.ID(), new_owner.ID()) - } - } - return nil + return nil + }) } // Load function for SimpleLockable diff --git a/lockable_test.go b/lockable_test.go deleted file mode 100644 index 2be1608..0000000 --- a/lockable_test.go +++ /dev/null @@ -1,495 +0,0 @@ -package graphvent - -import ( - "testing" - "fmt" - "time" -) - -func TestNewSimpleLockable(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test lockable 2") - l2 := &l2_r - - err := UpdateStates(ctx, []Node{l1, l2}, func(nodes NodeMap) error { - return LinkLockables(ctx, l2, []Lockable{l1}, nodes) - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1, l2}, func(nodes NodeMap) error { - l1_deps := len(l1.Dependencies()) - if l1_deps != 1 { - return fmt.Errorf("l1 has wront amount of dependencies %d/1", l1_deps) - } - - l1_dep1 := l1.Dependencies()[0] - if l1_dep1.ID() != l2.ID() { - return fmt.Errorf("Wrong dependency for l1, %s instead of %s", l1_dep1.ID(), l2.ID()) - } - - l2_reqs := len(l2.Requirements()) - if l2_reqs != 1 { - return fmt.Errorf("l2 has wrong amount of requirements %d/1", l2_reqs) - } - - l2_req1 := l2.Requirements()[0] - if l2_req1.ID() != l1.ID() { - return fmt.Errorf("Wrong requirement for l2, %s instead of %s", l2_req1.ID(), l1.ID()) - } - return nil - }) - fatalErr(t, err) -} - -func TestRepeatedChildLockable(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test lockable 1") - l1 := &l1_r - - l2_r := NewSimpleLockable(RandID(), "Test lockable 2") - l2 := &l2_r - - err := UpdateStates(ctx, []Node{l1, l2}, func(nodes NodeMap) error { - return LinkLockables(ctx, l2, []Lockable{l1, l1}, nodes) - }) - - if err == nil { - t.Fatal("Added the same lockable as a requirement twice to the same lockable") - } -} - -func TestLockableSelfLock(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test lockable 1") - l1 := &l1_r - - err := UpdateStates(ctx, []Node{l1}, func(nodes NodeMap) error { - return LockLockables(ctx, []Lockable{l1}, l1, nodes) - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - owner_id := NodeID{} - if l1.owner != nil { - owner_id = l1.owner.ID() - } - if owner_id != l1.ID() { - return fmt.Errorf("l1 is owned by %s instead of self", owner_id) - } - return nil - }) - fatalErr(t, err) - - err = UpdateStates(ctx, []Node{l1}, func(nodes NodeMap) error { - return UnlockLockables(ctx, []Lockable{l1}, l1, nodes) - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - if l1.owner != nil { - return fmt.Errorf("l1 is not unowned after unlock: %s", l1.owner.ID()) - } - return nil - }) - - fatalErr(t, err) -} - -func TestLockableSelfLockTiered(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test lockable 2") - l2 := &l2_r - l3_r := NewSimpleLockable(RandID(), "Test lockable 3") - l3 := &l3_r - - err := UpdateStates(ctx, []Node{l3}, func(nodes NodeMap) error { - err := LinkLockables(ctx, l3, []Lockable{l1, l2}, nodes) - if err != nil { - return err - } - return LockLockables(ctx, []Lockable{l3}, l3, nodes) - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) (error) { - owner_1 := NodeID{} - if l1.owner != nil { - owner_1 = l1.owner.ID() - } - if owner_1 != l3.ID() { - return fmt.Errorf("l1 is owned by %s instead of l3", owner_1) - } - - owner_2 := NodeID{} - if l2.owner != nil { - owner_2 = l2.owner.ID() - } - if owner_2 != l3.ID() { - return fmt.Errorf("l2 is owned by %s instead of l3", owner_2) - } - return nil - }) - fatalErr(t, err) - - err = UpdateStates(ctx, []Node{l3}, func(nodes NodeMap) error { - return UnlockLockables(ctx, []Lockable{l3}, l3, nodes) - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) (error) { - owner_1 := l1.owner - if owner_1 != nil { - return fmt.Errorf("l1 is not unowned after unlocking: %s", owner_1.ID()) - } - - owner_2 := l2.owner - if owner_2 != nil { - return fmt.Errorf("l2 is not unowned after unlocking: %s", owner_2.ID()) - } - - owner_3 := l3.owner - if owner_3 != nil { - return fmt.Errorf("l3 is not unowned after unlocking: %s", owner_3.ID()) - } - return nil - }) - - fatalErr(t, err) -} - -func TestLockableLockOther(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test lockable 2") - l2 := &l2_r - - err := UpdateStates(ctx, []Node{l1, l2}, func(nodes NodeMap) (error) { - err := LockLockables(ctx, []Lockable{l1}, l2, nodes) - fatalErr(t, err) - return nil - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - owner_id := NodeID{} - if l1.owner != nil { - owner_id = l1.owner.ID() - } - if owner_id != l2.ID() { - return fmt.Errorf("l1 is owned by %s instead of l2", owner_id) - } - - return nil - }) - fatalErr(t, err) - - err = UpdateStates(ctx, []Node{l2}, func(nodes NodeMap) (error) { - err := UnlockLockables(ctx, []Lockable{l1}, l2, nodes) - fatalErr(t, err) - return nil - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - owner := l1.owner - if owner != nil { - return fmt.Errorf("l1 is owned by %s instead of l2", owner.ID()) - } - - return nil - }) - fatalErr(t, err) - -} - -func TestLockableLockSimpleConflict(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test lockable 2") - l2 := &l2_r - - err := UpdateStates(ctx, []Node{l1}, func(nodes NodeMap) error { - return LockLockables(ctx, []Lockable{l1}, l1, nodes) - }) - fatalErr(t, err) - - err = UpdateStates(ctx, []Node{l2}, func(nodes NodeMap) (error) { - err := LockLockables(ctx, []Lockable{l1}, l2, nodes) - if err == nil { - t.Fatal("l2 took l1's lock from itself") - } - - return nil - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - owner_id := NodeID{} - if l1.owner != nil { - owner_id = l1.owner.ID() - } - if owner_id != l1.ID() { - return fmt.Errorf("l1 is owned by %s instead of l1", owner_id) - } - - return nil - }) - fatalErr(t, err) - - err = UpdateStates(ctx, []Node{l1}, func(nodes NodeMap) error { - return UnlockLockables(ctx, []Lockable{l1}, l1, nodes) - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - owner := l1.owner - if owner != nil { - return fmt.Errorf("l1 is owned by %s instead of l1", owner.ID()) - } - - return nil - }) - fatalErr(t, err) - -} - -func TestLockableLockTieredConflict(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test lockable 2") - l2 := &l2_r - l3_r := NewSimpleLockable(RandID(), "Test lockable 3") - l3 := &l3_r - - err := UpdateStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) error { - err := LinkLockables(ctx, l2, []Lockable{l1}, nodes) - if err != nil { - return err - } - return LinkLockables(ctx, l3, []Lockable{l1}, nodes) - }) - fatalErr(t, err) - - err = UpdateStates(ctx, []Node{l2}, func(nodes NodeMap) error { - return LockLockables(ctx, []Lockable{l2}, l2, nodes) - }) - fatalErr(t, err) - - err = UpdateStates(ctx, []Node{l3}, func(nodes NodeMap) error { - return LockLockables(ctx, []Lockable{l3}, l3, nodes) - }) - if err == nil { - t.Fatal("Locked l3 which depends on l1 while l2 which depends on l1 is already locked") - } -} - -func TestLockableSimpleUpdate(t * testing.T) { - ctx := logTestContext(t, []string{}) - - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - - - update_channel := UpdateChannel(l1, 1, NodeID{}) - - go func() { - UseStates(ctx, []Node{l1}, func(nodes NodeMap) error { - return l1.Signal(ctx, NewDirectSignal(l1, "test_update"), nodes) - }) - }() - - (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Didn't receive test_update sent to l1") -} - -func TestLockableDownUpdate(t * testing.T) { - ctx := logTestContext(t, []string{}) - - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") - l2 := &l2_r - l3_r := NewSimpleLockable(RandID(), "Test Lockable 3") - l3 := &l3_r - err := UpdateStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) error { - err := LinkLockables(ctx, l2, []Lockable{l1}, nodes) - if err != nil { - return err - } - return LinkLockables(ctx, l3, []Lockable{l2}, nodes) - }) - fatalErr(t, err) - - update_channel := UpdateChannel(l1, 1, NodeID{}) - - go func() { - UseStates(ctx, []Node{l2}, func(nodes NodeMap) error { - return l2.Signal(ctx, NewDownSignal(l2, "test_update"), nodes) - }) - }() - - (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2") -} - -func TestLockableUpUpdate(t * testing.T) { - ctx := logTestContext(t, []string{}) - - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") - l2 := &l2_r - l3_r := NewSimpleLockable(RandID(), "Test Lockable 3") - l3 := &l3_r - err := UpdateStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) error { - err := LinkLockables(ctx, l2, []Lockable{l1}, nodes) - if err != nil { - return err - } - return LinkLockables(ctx, l3, []Lockable{l2}, nodes) - }) - fatalErr(t, err) - - update_channel := UpdateChannel(l3, 1, NodeID{}) - - go func() { - UseStates(ctx, []Node{l2}, func(nodes NodeMap) error { - return l2.Signal(ctx, NewSignal(l2, "test_update"), nodes) - }) - }() - - (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2") -} - -func TestOwnerNotUpdatedTwice(t * testing.T) { - ctx := logTestContext(t, []string{}) - - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") - l2 := &l2_r - - err := UpdateStates(ctx, []Node{l1, l2}, func(nodes NodeMap) error { - err := LinkLockables(ctx, l2, []Lockable{l1}, nodes) - if err != nil { - return err - } - return LockLockables(ctx, []Lockable{l2}, l2, nodes) - }) - fatalErr(t, err) - - update_channel := UpdateChannel(l2, 1, NodeID{}) - - go func() { - err := UseStates(ctx, []Node{l1}, func(nodes NodeMap) error { - return l1.Signal(ctx, NewSignal(l1, "test_update"), nodes) - }) - fatalErr(t, err) - }() - - (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Dicn't received test_update on l2 from l1") - (*GraphTester)(t).CheckForNone(update_channel, "Second update received on dependency") -} - -func TestLockableDependencyOverlap(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") - l2 := &l2_r - l3_r := NewSimpleLockable(RandID(), "Test Lockable 3") - l3 := &l3_r - - err := UpdateStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) error { - err := LinkLockables(ctx, l2, []Lockable{l1}, nodes) - if err != nil { - return err - } - - return LinkLockables(ctx, l3, []Lockable{l2, l1}, nodes) - }) - if err == nil { - t.Fatal("Should have thrown an error because of dependency overlap") - } -} - -func TestLockableDBLoad(t * testing.T){ - ctx := logTestContext(t, []string{}) - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") - l2 := &l2_r - l3_r := NewSimpleLockable(RandID(), "Test Lockable 3") - l3 := &l3_r - l4_r := NewSimpleLockable(RandID(), "Test Lockable 4") - l4 := &l4_r - l5_r := NewSimpleLockable(RandID(), "Test Lockable 5") - l5 := &l5_r - l6_r := NewSimpleLockable(RandID(), "Test Lockable 6") - l6 := &l6_r - err := UpdateStates(ctx, []Node{l1, l2, l3, l4, l5, l6}, func(nodes NodeMap) error { - err := LinkLockables(ctx, l3, []Lockable{l1, l2}, nodes) - if err != nil { - return err - } - - err = LinkLockables(ctx, l4, []Lockable{l3}, nodes) - if err != nil { - return err - } - - err = LinkLockables(ctx, l5, []Lockable{l4}, nodes) - if err != nil { - return err - } - return LockLockables(ctx, []Lockable{l3}, l6, nodes) - }) - fatalErr(t, err) - - err = UseStates(ctx, []Node{l3}, func(nodes NodeMap) error { - ser, err := l3.Serialize() - ctx.Log.Logf("test", "\n%s\n\n", ser) - return err - }) - fatalErr(t, err) - - l3_loaded, err := LoadNode(ctx, l3.ID()) - fatalErr(t, err) - - // TODO: add more equivalence checks - err = UseStates(ctx, []Node{l3_loaded}, func(nodes NodeMap) error { - ser, err := l3_loaded.Serialize() - ctx.Log.Logf("test", "\n%s\n\n", ser) - return err - }) - fatalErr(t, err) -} - -func TestLockableUnlink(t * testing.T){ - ctx := logTestContext(t, []string{}) - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") - l2 := &l2_r - - err := UpdateStates(ctx, []Node{l1, l2}, func(nodes NodeMap) error { - return LinkLockables(ctx, l2, []Lockable{l1}, nodes) - }) - fatalErr(t, err) - - err = UnlinkLockables(ctx, l2, l1) - fatalErr(t, err) -} diff --git a/node.go b/node.go index 08b7cbd..8d93487 100644 --- a/node.go +++ b/node.go @@ -54,6 +54,24 @@ func RandID() NodeID { return NodeID(uuid.New()) } +// A Node represents a data that can be locked and held by other Nodes +type Node interface { + ID() NodeID + Type() NodeType + Serialize() ([]byte, error) + + Allows(resouce, action string, principal Node) error + AddPolicy(Policy) error + RemovePolicy(Policy) error + + Signal(context *ReadContext, signal GraphSignal) error + + + Requirements() []Node + Dependencies() []Node + Owner() Node +} + // A Node represents data that can be read by multiple goroutines and written to by one, with a unique ID attached, and a method to process updates(including propagating them to connected nodes) // RegisterChannel and UnregisterChannel are used to connect arbitrary listeners to the node type Node interface { @@ -74,7 +92,7 @@ type Node interface { RemovePolicy(Policy) error // Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate - Signal(ctx *Context, signal GraphSignal, nodes NodeMap) error + Signal(context *ReadContext, signal GraphSignal) error // Register a channel to receive updates sent to the node RegisterChannel(id NodeID, listener chan GraphSignal) // Unregister a channel from receiving updates sent to the node @@ -193,20 +211,20 @@ func (node * GraphNode) Type() NodeType { // Propagate the signal to registered listeners, if a listener isn't ready to receive the update // send it a notification that it was closed and then close it -func (node * GraphNode) Signal(ctx *Context, signal GraphSignal, nodes NodeMap) error { - ctx.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String()) +func (node * GraphNode) Signal(context *ReadContext, signal GraphSignal) error { + context.Context.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String()) node.listeners_lock.Lock() defer node.listeners_lock.Unlock() closed := []NodeID{} for id, listener := range node.listeners { - ctx.Log.Logf("signal", "UPDATE_LISTENER %s: %p", node.ID(), listener) + context.Context.Log.Logf("signal", "UPDATE_LISTENER %s: %p", node.ID(), listener) select { case listener <- signal: default: - ctx.Log.Logf("signal", "CLOSED_LISTENER %s: %p", node.ID(), listener) + context.Context.Log.Logf("signal", "CLOSED_LISTENER %s: %p", node.ID(), listener) go func(node Node, listener chan GraphSignal) { - listener <- NewDirectSignal(node, "listener_closed") + listener <- NewDirectSignal("listener_closed") close(listener) }(node, listener) closed = append(closed, id) @@ -293,18 +311,19 @@ func getNodeBytes(node Node) ([]byte, error) { } // Write multiple nodes to the database in a single transaction -func WriteNodes(ctx * Context, nodes NodeMap) error { - ctx.Log.Logf("db", "DB_WRITES: %d", len(nodes)) - if nodes == nil { +func WriteNodes(context *WriteContext) error { + if locked == nil { return fmt.Errorf("Cannot write nil map to DB") } + context.Context.Log.Logf("db", "DB_WRITES: %d", len(context.Locked)) - serialized_bytes := make([][]byte, len(nodes)) - serialized_ids := make([][]byte, len(nodes)) + serialized_bytes := make([][]byte, len(context.Locked)) + serialized_ids := make([][]byte, len(context.Locked)) i := 0 - for _, node := range(nodes) { + for _, lock := range(context.Locked) { + node := lock.Node node_bytes, err := getNodeBytes(node) - ctx.Log.Logf("db", "DB_WRITE: %+v", node) + context.Context.Log.Logf("db", "DB_WRITE: %+v", node) if err != nil { return err } @@ -317,7 +336,7 @@ func WriteNodes(ctx * Context, nodes NodeMap) error { i++ } - err := ctx.DB.Update(func(txn *badger.Txn) error { + err := context.Context.DB.Update(func(txn *badger.Txn) error { for i, id := range(serialized_ids) { err := txn.Set(id, serialized_bytes[i]) if err != nil { @@ -406,60 +425,165 @@ func LoadNodeRecurse(ctx * Context, id NodeID, nodes NodeMap) (Node, error) { return node, nil } -// Internal function to filter duplicate nodes from a list -func filterDuplicates(nodes []Node) []Node { - ret := []Node{} - found := map[NodeID]bool{} - for _, node := range(nodes) { - if node == nil { - return []Node{} - } +func NewLockInfo(node Node, resources []string) LockInfo { + return LockInfo{ + Node: node, + Resources: resources, + } +} - _, exists := found[node.ID()] - if exists == false { - found[node.ID()] = true - ret = append(ret, node) +type LockInfoList interface { + List() []LockInfo +} + +func NewLockMap(requests ...LockInfoList) LockMap { + reqs := LockMap{} + for _, req := range(requests) { + for _, info := range(req) { + reqs[req.Node.ID()] = info } } - return ret + return reqs } -// Convert any slice of types that implement Node to a []Node -func NodeList[K Node](list []K) []Node { - nodes := make([]Node, len(list)) +func RequestList[K Node](list []K, resources []string) LockList { + requests := make(LockList{}, len(list)) for i, node := range(list) { - nodes[i] = node + requests[i] = LockInfo{ + Node: node, + Resources: resources, + } } - return nodes + return requests } + type NodeMap map[NodeID]Node -type NodesFn func(nodes NodeMap) error -// Initiate a read context for nodes and call nodes_fn with init_nodes locked for read -func UseStates(ctx * Context, init_nodes []Node, nodes_fn NodesFn) error { - nodes := NodeMap{} - return UseMoreStates(ctx, init_nodes, nodes, nodes_fn) + +type LockInfo struct { + Node Node + Resources []string +} +func (info LockInfo) List() []LockInfo { + return []LockInfo{info} +} + +type LockMap map[NodeID]LockInfo +func (m LockMap) List() []LockInfo { + infos := make([]LockInfo, len(m)) + i := 0 + for _, info := range(m) { + infos[i] = info + i += 1 + } + + return infos } +type LockList []LockInfo +func (li LockList) List() []LockInfo { + return li +} -// Add nodes to an existing read context and call nodes_fn with new_nodes locked for read -func UseMoreStates(ctx * Context, new_nodes []Node, nodes NodeMap, nodes_fn NodesFn) error { - new_nodes = filterDuplicates(new_nodes) +type ReadContext struct { + Graph *Context + Locked LockMap +} +type ReadFn func(*ReadContext)(error) + +type WriteContext struct { + Graph *Context + Locked LockMap +} +type WriteFn func(*WriteContext)(error) +func del[K comparable](list []K, val K) []K { + idx := -1 + for i, v := range(list) { + if v == val { + idx = i + break + } + } + if idx == -1 { + return nil + } + + list[idx] = list[len(list)-1] + return list[:len(list)-1] +} + +// Start a read context for node under ctx for the resources specified in init_nodes, then run nodes_fn +func UseStates(ctx *Context, node Node, nodes LockMap, read_fn ReadFn) error { + context := &ReadContext{ + Context: ctx, + Locked: LockMap{}, + } + return UseMoreStates(context, node, nodes, read_fn) +} + +// Add nodes to an existing read context and call nodes_fn with new_nodes locked for read +// Check that the node has read permissions for the nodes, then add them to the read context and call nodes_fn with the nodes locked for read +func UseMoreStates(context *ReadContext, node Node, new_nodes LockMap, read_fn ReadFn) error { locked_nodes := []Node{} - for _, node := range(new_nodes) { - _, locked := nodes[node.ID()] - if locked == false { - node.RLock() - nodes[node.ID()] = node - locked_nodes = append(locked_nodes, node) + new_permissions := LockMap{} + for _, request := range(new_nodes) { + id := request.Node.ID() + new_permissions[id] = LockInfo{Node: request.Node, Resources: []string{}} + for _, resource := range(request.Resources) { + // If the permission for this resource is already granted, continue + current_permissions, exists := context.Locked[id] + if exists == true { + already_granted := false + for _, r := range(current_permissions.Resources) { + if r == resource { + already_granted = true + break + } + } + if already_granted == true { + continue + } + } + + err := request.Node.Allowed("read", resource, node) + if err != nil { + return err + } + + tmp := new_permissions[id] + tmp.Resources = append(tmp.Resources, resource) + new_permissions[id] = tmp + } + + + req_perms, exists := new_permissions[id] + if exists == true { + cur_perms, already_locked := context.Locked[id] + if already_locked == false { + request.Node.RLock() + context.Locked[id] = req_perms + locked_nodes = append(locked_nodes, request.Node) + } else { + cur_perms.Resources = append(cur_perms.Resources, req_perms.Resources...) + } } } - err := nodes_fn(nodes) + err := read_fn(context) + + for _, request := range(new_permissions) { + cur_perms := context.Locked[request.Node.ID()] + new_perms := cur_perms.Resources + for _, resource := range(cur_perms.Resources) { + new_perms = del(new_perms, resource) + } + cur_perms.Resources = new_perms + context.Locked[request.Node.ID()].Resources = new_perms + } for _, node := range(locked_nodes) { - delete(nodes, node.ID()) + delete(context.Locked, node.ID()) node.RUnlock() } @@ -467,30 +591,68 @@ func UseMoreStates(ctx * Context, new_nodes []Node, nodes NodeMap, nodes_fn Node } // Initiate a write context for nodes and call nodes_fn with nodes locked for read -func UpdateStates(ctx * Context, nodes []Node, nodes_fn NodesFn) error { - locked_nodes := NodeMap{} - err := UpdateMoreStates(ctx, nodes, locked_nodes, nodes_fn) +func UpdateStates(ctx *Context, node Node, nodes LockMap, write_fn WriteFn) error { + context := &WriteContext{ + Context: ctx, + Locked: LockMap{}, + } + err := UpdateMoreStates(context, node, nodes, nodes_fn) if err == nil { - err = WriteNodes(ctx, locked_nodes) + err = WriteNodes(context) } - for _, node := range(locked_nodes) { - node.Unlock() + for _, lock := range(context.Locked) { + lock.Node.Unlock() } + return err } // Add nodes to an existing write context and call nodes_fn with nodes locked for read -func UpdateMoreStates(ctx * Context, nodes []Node, locked_nodes NodeMap, nodes_fn NodesFn) error { - for _, node := range(nodes) { - _, locked := locked_nodes[node.ID()] - if locked == false { - node.Lock() - locked_nodes[node.ID()] = node +func UpdateMoreStates(ctx *Context, locked LockMap, node Node, new_nodes LockMap, write_fn WriteFn) error { + new_permissions := LockMap{} + for _, request := range(new_nodes) { + id := request.Node.ID() + new_permissions[id] = LockInfo{Node: request.Node, Resources: []string{}} + for _, resource := range(request.Resources) { + current_permissions, exists := locked[id] + if exists == true { + already_granted := false + for _, r := range(current_permissions.Resources) { + if r == resource { + already_granted = true + break + } + } + if already_granted == true { + continue + } + } + + err := request.Node.Allowed("write", resource, node) + if err != nil { + return err + } + + tmp := new_permissions[id] + tmp.Resources = append(tmp.Resources, resource) + new_permissions[id] = tmp + } + + req_perms, exists := new_permissions[id] + if exists == true { + cur_perms, already_locked := locked[id] + if already_locked == false { + request.Node.Lock() + locked[id] = req_perms + } else { + cur_perms.Resources = append(cur_perms.Resources, req_perms.Resources...) + locked[id] = cur_perms + } } } - return nodes_fn(locked_nodes) + return write_fn(locked) } // Create a new channel with a buffer the size of buffer, and register it to node with the id diff --git a/policy.go b/policy.go index c7ff66e..7fb3daa 100644 --- a/policy.go +++ b/policy.go @@ -8,11 +8,11 @@ import ( type Policy interface { Node // Returns true if the principal is allowed to perform the action on the resource - Allows(action string, resource string, principal Node) bool + Allows(resource string, action string, principal Node) bool } type NodeActions map[string][]string -func (actions NodeActions) Allows(action string, resource string) bool { +func (actions NodeActions) Allows(resource string, action string) bool { for _, a := range(actions[""]) { if a == action || a == "*" { return true @@ -108,13 +108,13 @@ func LoadPerNodePolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Nod return &policy, nil } -func (policy *PerNodePolicy) Allows(action string, resource string, principal Node) bool { +func (policy *PerNodePolicy) Allows(resource string, action string, principal Node) bool { node_actions, exists := policy.Actions[principal.ID()] if exists == false { return false } - if node_actions.Allows(action, resource) == true { + if node_actions.Allows(resource, action) == true { return true } @@ -171,8 +171,8 @@ func LoadSimplePolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node return &policy, nil } -func (policy *SimplePolicy) Allows(action string, resource string, principal Node) bool { - return policy.Actions.Allows(action, resource) +func (policy *SimplePolicy) Allows(resource string, action string, principal Node) bool { + return policy.Actions.Allows(resource, action) } type PerTagPolicy struct { @@ -235,7 +235,7 @@ func LoadPerTagPolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node return &policy, nil } -func (policy *PerTagPolicy) Allows(action string, resource string, principal Node) bool { +func (policy *PerTagPolicy) Allows(resource string, action string, principal Node) bool { user, ok := principal.(*User) if ok == false { return false @@ -244,7 +244,7 @@ func (policy *PerTagPolicy) Allows(action string, resource string, principal Nod for _, tag := range(user.Tags) { tag_actions, exists := policy.Actions[tag] if exists == true { - if tag_actions.Allows(action, resource) == true { + if tag_actions.Allows(resource, action) == true { return true } } diff --git a/signal.go b/signal.go index 3358ad0..71cce2c 100644 --- a/signal.go +++ b/signal.go @@ -15,7 +15,6 @@ const ( type GraphSignal interface { // How to propogate the signal Direction() SignalDirection - Source() NodeID Type() string String() string } @@ -23,7 +22,6 @@ type GraphSignal interface { // BaseSignal is the most basic type of signal, it has no additional data type BaseSignal struct { FDirection SignalDirection `json:"direction"` - FSource NodeID `json:"source"` FType string `json:"type"` } @@ -39,58 +37,57 @@ func (signal BaseSignal) Direction() SignalDirection { return signal.FDirection } -func (signal BaseSignal) Source() NodeID { - return signal.FSource -} - func (signal BaseSignal) Type() string { return signal.FType } -func NewBaseSignal(source Node, _type string, direction SignalDirection) BaseSignal { - var source_id NodeID = NodeID{} - if source != nil { - source_id = source.ID() - } - +func NewBaseSignal(_type string, direction SignalDirection) BaseSignal { signal := BaseSignal{ FDirection: direction, - FSource: source_id, FType: _type, } return signal } -func NewDownSignal(source Node, _type string) BaseSignal { - return NewBaseSignal(source, _type, Down) +func NewDownSignal(_type string) BaseSignal { + return NewBaseSignal(_type, Down) +} + +func NewSignal(_type string) BaseSignal { + return NewBaseSignal(_type, Up) } -func NewSignal(source Node, _type string) BaseSignal { - return NewBaseSignal(source, _type, Up) +func NewDirectSignal(_type string) BaseSignal { + return NewBaseSignal(_type, Direct) } -func NewDirectSignal(source Node, _type string) BaseSignal { - return NewBaseSignal(source, _type, Direct) +var AbortSignal = NewBaseSignal("abort", Down) +var CancelSignal = NewBaseSignal("cancel", Down) + +type IDSignal struct { + BaseSignal + ID NodeID `json:"id"` } -func AbortSignal(source Node) BaseSignal { - return NewBaseSignal(source, "abort", Down) +func NewIDSignal(_type string, direction SignalDirection, id NodeID) IDSignal { + return IDSignal{ + BaseSignal: NewBaseSignal(_type, direction), + ID: id, + } } -func CancelSignal(source Node) BaseSignal { - return NewBaseSignal(source, "cancel", Down) +func NewStatusSignal(_type string, source NodeID) IDSignal { + return NewIDSignal(_type, Up, source) } type StartChildSignal struct { - BaseSignal - ChildID NodeID `json:"child_id"` + IDSignal Action string `json:"action"` } -func NewStartChildSignal(source Node, child_id NodeID, action string) StartChildSignal { +func NewStartChildSignal(child_id NodeID, action string) StartChildSignal { return StartChildSignal{ - BaseSignal: NewBaseSignal(source, "start_child", Direct), - ChildID: child_id, + IDSignal: NewIDSignal("start_child", Direct, child_id), Action: action, } } diff --git a/thread.go b/thread.go index 0e2630d..0de353e 100644 --- a/thread.go +++ b/thread.go @@ -10,32 +10,48 @@ import ( "github.com/google/uuid" ) -// SimpleThread.Signal updates the parent and children, and sends the signal to an internal channel -func (thread * SimpleThread) Signal(ctx * Context, signal GraphSignal, nodes NodeMap) error { - err := thread.SimpleLockable.Signal(ctx, signal, nodes) +// Assumed that thread is already locked for signal +func (thread *SimpleThread) Signal(context *ReadContext, signal GraphSignal) error { + err := thread.SimpleLockable.Signal(context, signal) if err != nil { return err } - if signal.Direction() == Up { - // Child->Parent, thread updates parent and connected requirement - if thread.parent != nil { - UseMoreStates(ctx, []Node{thread.parent}, nodes, func(nodes NodeMap) error { - thread.parent.Signal(ctx, signal, nodes) + + switch signal.Direction() { + case Up: + err = UseMoreStates(ctx, locked, thread, NewLockMap( + NewLockRequest(thread, []string{"parent"}), + ), func(context *ReadContext) error { + if thread.parent != nil { + return UseMoreStates(ctx, locked, thread, NewLockRequest(thread.parent, []string{"signal"}), func(context *ReadContext) error { + return thread.parent.Signal(context, signal) + }) + } else { return nil - }) - } - } else if signal.Direction() == Down { - // Parent->Child, updates children and dependencies - UseMoreStates(ctx, NodeList(thread.children), nodes, func(nodes NodeMap) error { + } + }) + case Down: + err = UseMoreStates(ctx, locked, thread, NewLockMap( + NewLockRequest(thread, []string{"children"}), + RequestList(thread.childre, []string{"signal"}), + ), func(context *ReadContext) error { for _, child := range(thread.children) { - child.Signal(ctx, signal, nodes) + err := child.Signal(context, signal) + if err != nil { + return err + } } return nil }) - } else if signal.Direction() == Direct { - } else { - panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) + case Direct: + err = nil + default: + return fmt.Errorf("Invalid signal direction %d", signal.Direction()) + } + if err != nil { + return err } + thread.signal <- signal return nil } @@ -156,14 +172,16 @@ func (thread * SimpleThread) AddChild(child Thread, info ThreadInfo) error { return nil } -func checkIfChild(ctx * Context, target Thread, cur Thread, nodes NodeMap) bool { +func checkIfChild(context *WriteContext, target Thread, cur Thread) bool { for _, child := range(cur.Children()) { if child.ID() == target.ID() { return true } is_child := false - UpdateMoreStates(ctx, []Node{child}, nodes, func(nodes NodeMap) error { - is_child = checkIfChild(ctx, target, child, nodes) + UpdateMoreStates(ctx, locked, cur, NewLockMap( + NewLockRequest(child, []string{"children"}), + ), func(locked NodeLockMap) error { + is_child = checkIfChild(context, target, child) return nil }) if is_child { @@ -174,8 +192,7 @@ func checkIfChild(ctx * Context, target Thread, cur Thread, nodes NodeMap) bool return false } -// Requires thread and childs thread to be locked for write -func LinkThreads(ctx * Context, thread Thread, child Thread, info ThreadInfo, nodes NodeMap) error { +func LinkThreads(context *WriteContext, princ Node, thread Thread, child Thread, info ThreadInfo) error { if ctx == nil || thread == nil || child == nil { return fmt.Errorf("invalid input") } @@ -184,29 +201,34 @@ func LinkThreads(ctx * Context, thread Thread, child Thread, info ThreadInfo, no return fmt.Errorf("Will not link %s as a child of itself", thread.ID()) } - if child.Parent() != nil { - return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID()) - } + return UpdateMoreStates(context, princ, NewNodeMap( + NewLockInfo(child, []string{"parent", "children"}), + NewLockInfo(thread, []string{"parent", "children"}), + ), func(context *WriteContext) { + if child.Parent() != nil { + return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID()) + } - if checkIfChild(ctx, thread, child, nodes) == true { - return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID(), child.ID()) - } + if checkIfChild(context, thread, child) == true { + return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID(), child.ID()) + } - if checkIfChild(ctx, child, thread, nodes) == true { - return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID()) - } + if checkIfChild(context, child, thread) == true { + return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID()) + } - err := thread.AddChild(child, info) - if err != nil { - return fmt.Errorf("EVENT_LINK_ERR: error adding %s as child to %s: %e", child.ID(), thread.ID(), err) - } - child.SetParent(thread) + err := thread.AddChild(child, info) + if err != nil { + return fmt.Errorf("EVENT_LINK_ERR: error adding %s as child to %s: %e", child.ID(), thread.ID(), err) + } + child.SetParent(thread) - if err != nil { - return err - } + if err != nil { + return err + } - return nil + return nil + }) } type ThreadAction func(* Context, Thread)(string, error) @@ -432,8 +454,8 @@ func NewSimpleThread(id NodeID, name string, state_name string, info_type reflec } } -// Requires that thread is already locked for read in UseStates -func FindChild(ctx * Context, thread Thread, id NodeID, nodes NodeMap) Thread { +// Requires the read permission of threads children +func FindChild(context *ReadContext, princ Node, thread Thread, id NodeID) Thread { if thread == nil { panic("cannot recurse through nil") } @@ -443,8 +465,8 @@ func FindChild(ctx * Context, thread Thread, id NodeID, nodes NodeMap) Thread { for _, child := range thread.Children() { var result Thread = nil - UseMoreStates(ctx, []Node{child}, nodes, func(nodes NodeMap) error { - result = FindChild(ctx, child, id, nodes) + UseMoreStates(context, princ, NewLockRequest(child, []string{"children"}), func(locked NodeLockMap) error { + result = FindChild(ctx, princ, child, id, locked) return nil }) if result != nil { @@ -499,12 +521,12 @@ func ThreadLoop(ctx * Context, thread Thread, first_action string) error { return err } - err = UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { + err = UpdateStates(ctx, thread, NewLockRequest(thread, []string{"state"}), func(locked NodeLockMap) error { err := thread.SetState("finished") if err != nil { return err } - return UnlockLockables(ctx, []Lockable{thread}, thread, nodes) + return UnlockLockables(ctx, []Lockable{thread}, thread, locked) }) if err != nil { @@ -563,23 +585,25 @@ func (thread * SimpleThread) AllowedToTakeLock(new_owner Lockable, lockable Lock } func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) error { - return UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { - child := thread.Child(signal.ChildID) + return UpdateStates(ctx, thread, NewLockRequest(thread, []string{"children"}), func(locked NodeLockMap) error { + child := thread.Child(signal.ID) if child == nil { - return fmt.Errorf("%s is not a child of %s", signal.ChildID, thread.ID()) + return fmt.Errorf("%s is not a child of %s", signal.ID, thread.ID()) } + return UpdateMoreStates(ctx, locked, thread, NewLockRequest(child, []string{"start"}), func(locked NodeLockMap) error { - info := thread.ChildInfo(signal.ChildID).(*ParentThreadInfo) - info.Start = true - ChildGo(ctx, thread, child, signal.Action) + info := thread.ChildInfo(signal.ID).(*ParentThreadInfo) + info.Start = true + ChildGo(ctx, thread, child, signal.Action) - return nil + return nil + }) }) } func ThreadRestore(ctx * Context, thread Thread) { - UpdateStates(ctx, []Node{thread}, func(nodes NodeMap)(error) { - return UpdateMoreStates(ctx, NodeList(thread.Children()), nodes, func(nodes NodeMap) error { + UpdateStates(ctx, thread, NewLockRequest(thread, []string{"children"}), func(locked NodeLockMap)(error) { + return UpdateMoreStates(ctx, locked, thread, RequestList(thread.Children(), []string{"start"}), func(locked NodeLockMap) error { for _, child := range(thread.Children()) { should_run := (thread.ChildInfo(child.ID())).(ParentInfo).Parent() ctx.Log.Logf("thread", "THREAD_RESTORE: %s -> %s: %+v", thread.ID(), child.ID(), should_run) @@ -594,13 +618,13 @@ func ThreadRestore(ctx * Context, thread Thread) { } func ThreadStart(ctx * Context, thread Thread) error { - return UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { + return UpdateStates(ctx, thread, NewLockRequest(thread, []string{"start", "lock"}), func(locked NodeLockMap) error { owner_id := NodeID{} if thread.Owner() != nil { owner_id = thread.Owner().ID() } if owner_id != thread.ID() { - err := LockLockables(ctx, []Lockable{thread}, thread, nodes) + err := LockLockables(ctx, []Lockable{thread}, thread, locked) if err != nil { return err } @@ -628,11 +652,7 @@ func ThreadWait(ctx * Context, thread Thread) (string, error) { for { select { case signal := <- thread.SignalChannel(): - if signal.Source() == thread.ID() { - ctx.Log.Logf("thread", "THREAD_SIGNAL_INTERNAL") - } else { - ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal) - } + ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal) signal_fn, exists := thread.Handler(signal.Type()) if exists == true { ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID(), signal.Type()) @@ -642,7 +662,7 @@ func ThreadWait(ctx * Context, thread Thread) (string, error) { } case <- thread.Timeout(): timeout_action := "" - err := UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { + err := UpdateStates(ctx, thread, NewLockMap(NewLockInfo(thread, []string{"timeout"})), func(context *WriteContext) error { timeout_action = thread.TimeoutAction() thread.ClearTimeout() return nil @@ -656,36 +676,25 @@ func ThreadWait(ctx * Context, thread Thread) (string, error) { } } -type ThreadAbortedError NodeID - -func (e ThreadAbortedError) Is(target error) bool { - error_type := reflect.TypeOf(ThreadAbortedError(NodeID{})) - target_type := reflect.TypeOf(target) - return error_type == target_type -} -func (e ThreadAbortedError) Error() string { - return fmt.Sprintf("Aborted by %s", (uuid.UUID)(e).String()) -} -func NewThreadAbortedError(aborter NodeID) ThreadAbortedError { - return ThreadAbortedError(aborter) -} +var ThreadAbortedError = errors.New("Thread aborted by signal") // Default thread abort is to return a ThreadAbortedError func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) { - UseStates(ctx, []Node{thread}, func(nodes NodeMap) error { - thread.Signal(ctx, NewSignal(thread, "thread_aborted"), nodes) - return nil + err := UseStates(ctx, thread, NewLockRequest(thread, []string{"signal"}), func(locked NodeLockMap) error { + return thread.Signal(ctx, NewStatusSignal("aborted", thread.ID()), locked) }) - return "", NewThreadAbortedError(signal.Source()) + if err != nil { + return "", err + } + return "", ThreadAbortedError } // Default thread cancel is to finish the thread func ThreadCancel(ctx * Context, thread Thread, signal GraphSignal) (string, error) { - UseStates(ctx, []Node{thread}, func(nodes NodeMap) error { - thread.Signal(ctx, NewSignal(thread, "thread_cancelled"), nodes) - return nil + err := UseStates(ctx, thread, NewLockRequest(thread, []string{"signal"}), func(locked NodeLockMap) error { + return thread.Signal(ctx, NewSignal("cancelled"), locked) }) - return "", nil + return "", err } func NewThreadActions() ThreadActions{ diff --git a/thread_test.go b/thread_test.go deleted file mode 100644 index 0e96c0a..0000000 --- a/thread_test.go +++ /dev/null @@ -1,122 +0,0 @@ -package graphvent - -import ( - "testing" - "time" - "fmt" -) - -func TestNewThread(t * testing.T) { - ctx := testContext(t) - - t1_r := NewSimpleThread(RandID(), "Test thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) - t1 := &t1_r - - go func(thread Thread) { - time.Sleep(10*time.Millisecond) - UseStates(ctx, []Node{t1}, func(nodes NodeMap) error { - return t1.Signal(ctx, CancelSignal(nil), nodes) - }) - }(t1) - - err := ThreadLoop(ctx, t1, "start") - fatalErr(t, err) - - err = UseStates(ctx, []Node{t1}, func(nodes NodeMap) (error) { - owner := t1.owner - if owner != nil { - return fmt.Errorf("Wrong owner %+v", owner) - } - return nil - }) -} - -func TestThreadWithRequirement(t * testing.T) { - ctx := testContext(t) - - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - - t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) - t1 := &t1_r - - err := UpdateStates(ctx, []Node{l1, t1}, func(nodes NodeMap) error { - return LinkLockables(ctx, t1, []Lockable{l1}, nodes) - }) - fatalErr(t, err) - - go func (thread Thread) { - time.Sleep(10*time.Millisecond) - UseStates(ctx, []Node{t1}, func(nodes NodeMap) error { - return t1.Signal(ctx, CancelSignal(nil), nodes) - }) - }(t1) - fatalErr(t, err) - - err = ThreadLoop(ctx, t1, "start") - fatalErr(t, err) - - err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - owner := l1.owner - if owner != nil { - return fmt.Errorf("Wrong owner %+v", owner) - } - return nil - }) - fatalErr(t, err) -} - -func TestThreadDBLoad(t * testing.T) { - ctx := logTestContext(t, []string{}) - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r - t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) - t1 := &t1_r - - err := UpdateStates(ctx, []Node{t1, l1}, func(nodes NodeMap) error { - return LinkLockables(ctx, t1, []Lockable{l1}, nodes) - }) - - err = UseStates(ctx, []Node{t1}, func(nodes NodeMap) error { - return t1.Signal(ctx, CancelSignal(nil), nodes) - }) - fatalErr(t, err) - - err = ThreadLoop(ctx, t1, "start") - fatalErr(t, err) - - err = UseStates(ctx, []Node{t1}, func(nodes NodeMap) error { - ser, err := t1.Serialize() - ctx.Log.Logf("test", "\n%s\n\n", ser) - return err - }) - - t1_loaded, err := LoadNode(ctx, t1.ID()) - fatalErr(t, err) - - err = UseStates(ctx, []Node{t1_loaded}, func(nodes NodeMap) error { - ser, err := t1_loaded.Serialize() - ctx.Log.Logf("test", "\n%s\n\n", ser) - return err - }) -} - -func TestThreadUnlink(t * testing.T) { - ctx := logTestContext(t, []string{}) - t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) - t1 := &t1_r - t2_r := NewSimpleThread(RandID(), "Test Thread 2", "init", nil, BaseThreadActions, BaseThreadHandlers) - t2 := &t2_r - - - err := UpdateStates(ctx, []Node{t1, t2}, func(nodes NodeMap) error { - err := LinkThreads(ctx, t1, t2, nil, nodes) - if err != nil { - return err - } - - return UnlinkThreads(ctx, t1, t2) - }) - fatalErr(t, err) -} -