From 3ad969a5cac2af324ab073ef0f2b1d96abb0a609 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Thu, 27 Jul 2023 15:27:14 -0600 Subject: [PATCH] Switched from thread being the callback engine to node being the callback engine --- context.go | 13 +- gql.go | 137 +++++---------- gql_interfaces.go | 21 --- gql_mutation.go | 43 +---- gql_query.go | 2 +- gql_resolvers.go | 200 +--------------------- gql_subscribe.go | 2 +- gql_test.go | 87 ++-------- gql_types.go | 28 +-- graph_test.go | 1 + lockable.go | 428 ++++------------------------------------------ node.go | 414 ++++++++++++++++++++------------------------ node_test.go | 10 +- policy.go | 89 +--------- signal.go | 4 +- user.go | 8 +- 16 files changed, 316 insertions(+), 1171 deletions(-) diff --git a/context.go b/context.go index 09117ad..cb68c37 100644 --- a/context.go +++ b/context.go @@ -135,18 +135,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - thread_ctx := NewThreadExtContext() - err = ctx.RegisterExtension(ThreadExtType, LoadThreadExt, thread_ctx) - if err != nil { - return nil, err - } - - err = thread_ctx.RegisterThreadType(GQLThreadType, gql_actions, gql_handlers) - if err != nil { - return nil, err - } - - err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType, ThreadExtType}) + err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType}) if err != nil { return nil, err } diff --git a/gql.go b/gql.go index a2ee777..b338cd5 100644 --- a/gql.go +++ b/gql.go @@ -30,7 +30,6 @@ import ( "encoding/pem" ) -const GQLThreadType = ThreadType("GQL") const GQLNodeType = NodeType("GQL") type AuthReqJSON struct { @@ -760,7 +759,7 @@ func NewGQLExtContext() *GQLExtContext { Fields: graphql.Fields{}, }) - mutation.AddFieldConfig("abort", GQLMutationAbort) + mutation.AddFieldConfig("stop", GQLMutationStop) mutation.AddFieldConfig("startChild", GQLMutationStartChild) subscription := graphql.NewObject(graphql.ObjectConfig{ @@ -790,10 +789,6 @@ func NewGQLExtContext() *GQLExtContext { if err != nil { panic(err) } - err = context.AddInterface(GQLInterfaceThread) - if err != nil { - panic(err) - } schema, err := BuildSchema(&context) if err != nil { @@ -829,7 +824,7 @@ func (ext *GQLExt) NewSubscriptionChannel(buffer int) chan Signal { return new_listener } -func (ext *GQLExt) Process(context *StateContext, node *Node, signal Signal) error { +func (ext *GQLExt) Process(context *Context, princ_id NodeID, node *Node, signal Signal) { ext.SubscribeLock.Lock() defer ext.SubscribeLock.Unlock() @@ -846,7 +841,7 @@ func (ext *GQLExt) Process(context *StateContext, node *Node, signal Signal) err } } ext.SubscribeListeners = active_listeners - return nil + return } const GQLExtType = ExtType("gql_thread") @@ -963,104 +958,58 @@ func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_ } } -var gql_actions ThreadActions = ThreadActions{ - "wait": ThreadWait, - "restore": func(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) { - return "start_server", ThreadRestore(ctx, thread, thread_ext, false) - }, - "start": func(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) { - _, err := ThreadStart(ctx, thread, thread_ext) - if err != nil { - return "", err - } - return "start_server", ThreadRestore(ctx, thread, thread_ext, true) - }, - "start_server": func(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) { - gql_ext, err := GetExt[*GQLExt](thread) - if err != nil { - return "", err - } - - mux := http.NewServeMux() - mux.HandleFunc("/auth", AuthHandler(ctx, thread, gql_ext)) - mux.HandleFunc("/gql", GQLHandler(ctx, thread, gql_ext)) - mux.HandleFunc("/gqlws", GQLWSHandler(ctx, thread, gql_ext)) +func StartGQLServer(ctx *Context, node *Node, gql_ext *GQLExt) error { + mux := http.NewServeMux() + mux.HandleFunc("/auth", AuthHandler(ctx, node, gql_ext)) + mux.HandleFunc("/gql", GQLHandler(ctx, node, gql_ext)) + mux.HandleFunc("/gqlws", GQLWSHandler(ctx, node, gql_ext)) - // Server a graphiql interface(TODO make configurable whether to start this) - mux.HandleFunc("/graphiql", GraphiQLHandler()) + // Server a graphiql interface(TODO make configurable whether to start this) + mux.HandleFunc("/graphiql", GraphiQLHandler()) - // Server the ./site directory to /site (TODO make configurable with better defaults) - fs := http.FileServer(http.Dir("./site")) - mux.Handle("/site/", http.StripPrefix("/site", fs)) - - http_server := &http.Server{ - Addr: gql_ext.Listen, - Handler: mux, - } - - l, err := net.Listen("tcp", http_server.Addr) - if err != nil { - return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr) - } - - cert, err := tls.X509KeyPair(gql_ext.tls_cert, gql_ext.tls_key) - if err != nil { - return "", err - } + // Server the ./site directory to /site (TODO make configurable with better defaults) + fs := http.FileServer(http.Dir("./site")) + mux.Handle("/site/", http.StripPrefix("/site", fs)) - config := tls.Config{ - Certificates: []tls.Certificate{cert}, - NextProtos: []string{"http/1.1"}, - } - - listener := tls.NewListener(l, &config) + http_server := &http.Server{ + Addr: gql_ext.Listen, + Handler: mux, + } - gql_ext.http_done.Add(1) - go func(qql_ext *GQLExt) { - defer gql_ext.http_done.Done() + l, err := net.Listen("tcp", http_server.Addr) + if err != nil { + return fmt.Errorf("Failed to start listener for server on %s", http_server.Addr) + } - err := http_server.Serve(listener) - if err != http.ErrServerClosed { - panic(fmt.Sprintf("Failed to start gql server: %s", err)) - } - }(gql_ext) + cert, err := tls.X509KeyPair(gql_ext.tls_cert, gql_ext.tls_key) + if err != nil { + return err + } + config := tls.Config{ + Certificates: []tls.Certificate{cert}, + NextProtos: []string{"http/1.1"}, + } - context := NewWriteContext(ctx) - err = UpdateStates(context, thread, NewACLInfo(thread, []string{"http_server"}), func(context *StateContext) error { - gql_ext.tcp_listener = listener - gql_ext.http_server = http_server - return nil - }) + listener := tls.NewListener(l, &config) - if err != nil { - return "", err - } + gql_ext.http_done.Add(1) + go func(qql_ext *GQLExt) { + defer gql_ext.http_done.Done() - context = NewReadContext(ctx) - err = thread.Process(context, thread.ID, NewStatusSignal("server_started", thread.ID)) - if err != nil { - return "", err + err := http_server.Serve(listener) + if err != http.ErrServerClosed { + panic(fmt.Sprintf("Failed to start gql server: %s", err)) } + }(gql_ext) - return "wait", nil - }, - "finish": func(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) { - gql_ext, err := GetExt[*GQLExt](thread) - if err != nil { - return "", err - } - gql_ext.http_server.Shutdown(context.TODO()) - gql_ext.http_done.Wait() - return ThreadFinish(ctx, thread, thread_ext) - }, + gql_ext.tcp_listener = listener + gql_ext.http_server = http_server + return nil } -var gql_handlers ThreadHandlers = ThreadHandlers{ - "child_linked": ThreadChildLinked, - "start_child": ThreadStartChild, - "abort": ThreadAbort, - "stop": ThreadStop, +func StopGQLServer(gql_ext *GQLExt) { + gql_ext.http_server.Shutdown(context.TODO()) + gql_ext.http_done.Wait() } - diff --git a/gql_interfaces.go b/gql_interfaces.go index 9bd9168..38dffdf 100644 --- a/gql_interfaces.go +++ b/gql_interfaces.go @@ -55,22 +55,6 @@ func addLockableInterfaceFields(gql *GQLInterface, gql_lockable *GQLInterface) { }) } -func AddThreadInterfaceFields(gql *GQLInterface) { - addThreadInterfaceFields(gql, GQLInterfaceThread) -} - -func addThreadInterfaceFields(gql *GQLInterface, gql_thread *GQLInterface) { - AddNodeInterfaceFields(gql) - - gql.Interface.AddFieldConfig("Children", &graphql.Field{ - Type: gql_thread.List, - }) - - gql.Interface.AddFieldConfig("Parent", &graphql.Field{ - Type: gql_thread.Interface, - }) -} - func NodeHasExtensions(node *Node, extensions []ExtType) bool { if node == nil { return false @@ -136,8 +120,3 @@ var GQLInterfaceLockable = NewGQLInterface("Lockable", "DefaultLockable", []*gra addLockableFields(gql.Default, gql.Interface, gql.List) }) -var GQLInterfaceThread = NewGQLInterface("Thread", "DefaultThread", []*graphql.Interface{GQLInterfaceNode.Interface, }, []ExtType{ThreadExtType, LockableExtType}, func(gql *GQLInterface){ - addThreadInterfaceFields(gql, gql) -}, func(gql *GQLInterface) { - addThreadFields(gql.Default, gql.Interface, gql.List) -}) diff --git a/gql_mutation.go b/gql_mutation.go index d961604..f30b010 100644 --- a/gql_mutation.go +++ b/gql_mutation.go @@ -1,11 +1,10 @@ package graphvent import ( - "fmt" "github.com/graphql-go/graphql" ) -var GQLMutationAbort = NewField(func()*graphql.Field { - gql_mutation_abort := &graphql.Field{ +var GQLMutationStop = NewField(func()*graphql.Field { + gql_mutation_stop := &graphql.Field{ Type: GQLTypeSignal.Type, Args: graphql.FieldConfigArgument{ "id": &graphql.ArgumentConfig{ @@ -13,39 +12,11 @@ var GQLMutationAbort = NewField(func()*graphql.Field { }, }, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - _, ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - id, err := ExtractID(p, "id") - if err != nil { - return nil, err - } - - var node *Node = nil - context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewACLMap( - NewACLInfo(ctx.Server, []string{"children"}), - ), func(context *StateContext) (error){ - node, err = FindChild(context, ctx.User, ctx.Server, id) - if err != nil { - return err - } - if node == nil { - return fmt.Errorf("Failed to find ID: %s as child of server thread", id) - } - return node.Process(context, ctx.User.ID, AbortSignal) - }) - if err != nil { - return nil, err - } - - return AbortSignal, nil + return StopSignal, nil }, } - return gql_mutation_abort + return gql_mutation_stop }) var GQLMutationStartChild = NewField(func()*graphql.Field{ @@ -64,7 +35,7 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{ }, }, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - _, ctx, err := PrepResolve(p) + /*_, ctx, err := PrepResolve(p) if err != nil { return nil, err } @@ -102,10 +73,10 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{ }) if err != nil { return nil, err - } + }*/ // TODO: wait for the result of the signal to send back instead of just the signal - return signal, nil + return nil, nil }, } diff --git a/gql_query.go b/gql_query.go index 0531e6b..e9b7667 100644 --- a/gql_query.go +++ b/gql_query.go @@ -4,7 +4,7 @@ import ( ) var GQLQuerySelf = &graphql.Field{ - Type: GQLInterfaceThread.Default, + Type: GQLInterfaceNode.Default, Resolve: func(p graphql.ResolveParams) (interface{}, error) { _, ctx, err := PrepResolve(p) if err != nil { diff --git a/gql_resolvers.go b/gql_resolvers.go index 2c8322a..aad4e43 100644 --- a/gql_resolvers.go +++ b/gql_resolvers.go @@ -89,218 +89,36 @@ func GQLNodeTypeHash(p graphql.ResolveParams) (interface{}, error) { } func GQLNodeListen(p graphql.ResolveParams) (interface{}, error) { - node, ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - gql_ext, err := GetExt[*GQLExt](node) - if err != nil { - return nil, err - } - - listen := "" - context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewACLInfo(node, []string{"listen"}), func(context *StateContext) error { - listen = gql_ext.Listen - return nil - }) - - if err != nil { - return nil, err - } - - return listen, nil + // TODO figure out how nodes can read eachother + return "", nil } func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) { - node, ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - thread_ext, err := GetExt[*ThreadExt](node) - if err != nil { - return nil, err - } - - var parent *Node = nil - context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error { - parent = thread_ext.Parent - return nil - }) - - if err != nil { - return nil, err - } - - return parent, nil + return nil, nil } func GQLThreadState(p graphql.ResolveParams) (interface{}, error) { - node, ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - thread_ext, err := GetExt[*ThreadExt](node) - if err != nil { - return nil, err - } - - var state string - context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewACLInfo(node, []string{"state"}), func(context *StateContext) error { - state = thread_ext.State - return nil - }) - - if err != nil { - return nil, err - } - - return state, nil + return "", nil } func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) { - node, ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - thread_ext, err := GetExt[*ThreadExt](node) - if err != nil { - return nil, err - } - - var children []*Node = nil - context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewACLInfo(node, []string{"children"}), func(context *StateContext) error { - children = thread_ext.ChildList() - return nil - }) - - if err != nil { - return nil, err - } - - return children, nil + return nil, nil } func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) { - node, ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - lockable_ext, err := GetExt[*LockableExt](node) - if err != nil { - return nil, err - } - - var requirements []*Node = nil - context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error { - requirements = make([]*Node, len(lockable_ext.Requirements)) - i := 0 - for _, req := range(lockable_ext.Requirements) { - requirements[i] = req - i += 1 - } - return nil - }) - - if err != nil { - return nil, err - } - - return requirements, nil + return nil, nil } func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) { - node, ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - lockable_ext, err := GetExt[*LockableExt](node) - if err != nil { - return nil, err - } - - var dependencies []*Node = nil - context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewACLInfo(node, []string{"dependencies"}), func(context *StateContext) error { - dependencies = make([]*Node, len(lockable_ext.Dependencies)) - i := 0 - for _, dep := range(lockable_ext.Dependencies) { - dependencies[i] = dep - i += 1 - } - return nil - }) - - if err != nil { - return nil, err - } - - return dependencies, nil + return nil, nil } func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) { - node, ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - lockable_ext, err := GetExt[*LockableExt](node) - if err != nil { - return nil, err - } - - var owner *Node = nil - context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewACLInfo(node, []string{"owner"}), func(context *StateContext) error { - owner = lockable_ext.Owner - return nil - }) - - if err != nil { - return nil, err - } - - return owner, nil + return nil, nil } func GQLGroupMembers(p graphql.ResolveParams) (interface{}, error) { - node, ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - group_ext, err := GetExt[*GroupExt](node) - if err != nil { - return nil, err - } - - var members []*Node - context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewACLInfo(node, []string{"users"}), func(context *StateContext) error { - members = make([]*Node, len(group_ext.Members)) - i := 0 - for _, member := range(group_ext.Members) { - members[i] = member - i += 1 - } - return nil - }) - - if err != nil { - return nil, err - } - - return members, nil + return nil, nil } func GQLSignalFn(p graphql.ResolveParams, fn func(Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { diff --git a/gql_subscribe.go b/gql_subscribe.go index a211df5..63440d6 100644 --- a/gql_subscribe.go +++ b/gql_subscribe.go @@ -46,7 +46,7 @@ func GQLSubscribeFn(p graphql.ResolveParams, send_nil bool, fn func(*Context, *N var GQLSubscriptionSelf = NewField(func()*graphql.Field{ gql_subscription_self := &graphql.Field{ - Type: GQLInterfaceThread.Default, + Type: GQLInterfaceNode.Default, Resolve: func(p graphql.ResolveParams) (interface{}, error) { return p.Source, nil }, diff --git a/gql_test.go b/gql_test.go index 71df628..19602d3 100644 --- a/gql_test.go +++ b/gql_test.go @@ -3,114 +3,51 @@ package graphvent import ( "testing" "time" - "errors" "crypto/rand" "crypto/ecdh" "crypto/ecdsa" "crypto/elliptic" ) -func TestGQL(t *testing.T) { - - -} - func TestGQLDB(t * testing.T) { - ctx := logTestContext(t, []string{"thread", "test", "signal", "policy", "db"}) + ctx := logTestContext(t, []string{"loop", "node", "thread", "test", "signal", "policy", "db"}) TestUserNodeType := NodeType("TEST_USER") err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) fatalErr(t, err) - u1 := NewNode(ctx, RandID(), TestUserNodeType) + u1 := NewNode(ctx, RandID(), TestUserNodeType, nil) ctx.Log.Logf("test", "U1_ID: %s", u1.ID) - TestThreadNodeType := NodeType("TEST_THREAD") - err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLExtType, ThreadExtType}) - fatalErr(t, err) - - t1_p1 := NewParentOfPolicy(Actions{"signal.abort", "signal.stop", "state.write"}) - t1_p2 := NewPerNodePolicy(NodeActions{ - u1.ID: Actions{"parent.write"}, - }) - t1_thread, err := NewThreadExt(ctx, BaseThreadType, nil,nil, "init", nil) - fatalErr(t, err) - t1 := NewNode(ctx, - RandID(), - TestThreadNodeType, - NewACLExt(&t1_p1, &t1_p2), - t1_thread) - ctx.Log.Logf("test", "T1_ID: %s", t1.ID) - key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) - gql_p1 := NewChildOfPolicy(Actions{"signal.status"}) - gql_p2 := NewPerNodePolicy(NodeActions{ - u1.ID: Actions{"children.write", "dependencies.write"}, - }) - - gql_thread, err := NewThreadExt(ctx, GQLThreadType, nil, nil, "init", nil) - fatalErr(t, err) - gql_ext := NewGQLExt(":0", ecdh.P256(), key, nil, nil) listener_ext := NewListenerExt(10) - gql := NewNode(ctx, RandID(), GQLNodeType, - gql_thread, + gql := NewNode(ctx, RandID(), GQLNodeType, nil, gql_ext, listener_ext, - NewACLExt(&gql_p1, &gql_p2), + NewACLExt(), NewGroupExt(nil)) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) - info := ParentInfo{true, "start", "start"} - context := NewWriteContext(ctx) - err = UpdateStates(context, u1, ACLMap{}, func(context *StateContext) error { - return LinkThreads(context, u1, gql, ChildInfo{t1, map[InfoType]Info{ - ParentInfoType: &info, - }}) - }) - fatalErr(t, err) - - context = NewReadContext(ctx) - err = gql.Process(context, gql.ID, NewStatusSignal("child_linked", t1.ID)) - fatalErr(t, err) - context = NewReadContext(ctx) - err = gql.Process(context, gql.ID, AbortSignal) + err = gql.Signal(ctx, gql.ID, StopSignal) fatalErr(t, err) - err = ThreadLoop(ctx, gql, "start") - if errors.Is(err, ThreadAbortedError) == false { - fatalErr(t, err) - } + (*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on listener") - (*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "aborted", 100*time.Millisecond, "Didn't receive aborted on listener") - - context = NewReadContext(ctx) - err = UseStates(context, gql, ACLList([]*Node{gql, u1}, nil), func(context *StateContext) error { - ser1, err := gql.Serialize() - ser2, err := u1.Serialize() - ctx.Log.Logf("test", "\n%s\n\n", ser1) - ctx.Log.Logf("test", "\n%s\n\n", ser2) - return err - }) + ser1, err := gql.Serialize() + ser2, err := u1.Serialize() + ctx.Log.Logf("test", "\n%s\n\n", ser1) + ctx.Log.Logf("test", "\n%s\n\n", ser2) // Clear all loaded nodes from the context so it loads them from the database ctx.Nodes = NodeMap{} gql_loaded, err := LoadNode(ctx, gql.ID) fatalErr(t, err) - context = NewReadContext(ctx) - err = UseStates(context, gql_loaded, NewACLInfo(gql_loaded, []string{"users", "children", "requirements"}), func(context *StateContext) error { - var err error - listener_ext, err = GetExt[*ListenerExt](gql_loaded) - if err != nil { - return err - } - return gql_loaded.Process(context, gql_loaded.ID, StopSignal) - }) + listener_ext, err = GetExt[*ListenerExt](gql_loaded) fatalErr(t, err) - - err = ThreadLoop(ctx, gql_loaded, "restore") + err = gql_loaded.Signal(ctx, gql_loaded.ID, StopSignal) fatalErr(t, err) (*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2") diff --git a/gql_types.go b/gql_types.go index 529bb7d..3b4fc96 100644 --- a/gql_types.go +++ b/gql_types.go @@ -38,35 +38,11 @@ func addLockableFields(object *graphql.Object, lockable_interface *graphql.Inter }) } -func AddThreadFields(object *graphql.Object) { - addThreadFields(object, GQLInterfaceThread.Interface, GQLInterfaceThread.List) -} - -func addThreadFields(object *graphql.Object, thread_interface *graphql.Interface, thread_list *graphql.List) { - AddNodeFields(object) - - object.AddFieldConfig("State", &graphql.Field{ - Type: graphql.String, - Resolve: GQLThreadState, - }) - - object.AddFieldConfig("Children", &graphql.Field{ - Type: thread_list, - Resolve: GQLThreadChildren, - }) - - object.AddFieldConfig("Parent", &graphql.Field{ - Type: thread_interface, - Resolve: GQLThreadParent, - }) -} - var GQLNodeInterfaces = []*graphql.Interface{GQLInterfaceNode.Interface} var GQLLockableInterfaces = append(GQLNodeInterfaces, GQLInterfaceLockable.Interface) -var GQLThreadInterfaces = append(GQLNodeInterfaces, GQLInterfaceThread.Interface) -var GQLTypeGQLNode = NewGQLNodeType(GQLNodeType, GQLThreadInterfaces, func(gql *GQLType) { - AddThreadFields(gql.Type) +var GQLTypeGQLNode = NewGQLNodeType(GQLNodeType, GQLNodeInterfaces, func(gql *GQLType) { + AddNodeFields(gql.Type) gql.Type.AddFieldConfig("Listen", &graphql.Field{ Type: graphql.String, diff --git a/graph_test.go b/graph_test.go index c0ce083..d7a8641 100644 --- a/graph_test.go +++ b/graph_test.go @@ -60,6 +60,7 @@ func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) { listener := NewNode(ctx, RandID(), SimpleListenerNodeType, + nil, listener_extension, NewACLExt(&policy), NewLockableExt(nil, nil, nil, nil)) diff --git a/lockable.go b/lockable.go index 02fa8d0..d046b59 100644 --- a/lockable.go +++ b/lockable.go @@ -1,7 +1,6 @@ package graphvent import ( - "fmt" "encoding/json" ) @@ -32,14 +31,14 @@ func (listener *ListenerExt) Type() ExtType { return ListenerExtType } -func (ext *ListenerExt) Process(context *StateContext, node *Node, signal Signal) error { - context.Graph.Log.Logf("signal", "LISTENER_PROCESS: %s - %+v", node.ID, signal) +func (ext *ListenerExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) { + ctx.Log.Logf("signal", "LISTENER_PROCESS: %s - %+v", node.ID, signal) select { case ext.Chan <- signal: default: - return fmt.Errorf("LISTENER_OVERFLOW - %+v", signal) + ctx.Log.Logf("listener", "LISTENER_OVERFLOW: %s", node.ID) } - return nil + return } func (ext *ListenerExt) Serialize() ([]byte, error) { @@ -47,10 +46,10 @@ func (ext *ListenerExt) Serialize() ([]byte, error) { } type LockableExt struct { - Owner *Node - Requirements map[NodeID]*Node - Dependencies map[NodeID]*Node - LocksHeld map[NodeID]*Node + Owner *NodeID `json:"owner"` + Requirements []NodeID `json:"requirements"` + Dependencies []NodeID `json:"dependencies"` + LocksHeld map[NodeID]*NodeID `json:"locks_held"` } const LockableExtType = ExtType("LOCKABLE") @@ -58,33 +57,13 @@ func (ext *LockableExt) Type() ExtType { return LockableExtType } -type LockableExtJSON struct { - Owner string `json:"owner"` - Requirements []string `json:"requirements"` - Dependencies []string `json:"dependencies"` - LocksHeld map[string]string `json:"locks_held"` -} - func (ext *LockableExt) Serialize() ([]byte, error) { - return json.MarshalIndent(&LockableExtJSON{ - Owner: SaveNode(ext.Owner), - Requirements: SaveNodeList(ext.Requirements), - Dependencies: SaveNodeList(ext.Dependencies), - LocksHeld: SaveNodeMap(ext.LocksHeld), - }, "", " ") + return json.MarshalIndent(ext, "", " ") } -func NewLockableExt(owner *Node, requirements NodeMap, dependencies NodeMap, locks_held NodeMap) *LockableExt { - if requirements == nil { - requirements = NodeMap{} - } - - if dependencies == nil { - dependencies = NodeMap{} - } - +func NewLockableExt(owner *NodeID, requirements []NodeID, dependencies []NodeID, locks_held map[NodeID]*NodeID) *LockableExt { if locks_held == nil { - locks_held = NodeMap{} + locks_held = map[NodeID]*NodeID{} } return &LockableExt{ @@ -96,412 +75,71 @@ func NewLockableExt(owner *Node, requirements NodeMap, dependencies NodeMap, loc } func LoadLockableExt(ctx *Context, data []byte) (Extension, error) { - var j LockableExtJSON - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } - - ctx.Log.Logf("db", "DB_LOADING_LOCKABLE_EXT_JSON: %+v", j) - - owner, err := RestoreNode(ctx, j.Owner) + var ext LockableExt + err := json.Unmarshal(data, &ext) if err != nil { return nil, err } - requirements, err := RestoreNodeList(ctx, j.Requirements) - if err != nil { - return nil, err - } + ctx.Log.Logf("db", "DB_LOADING_LOCKABLE_EXT_JSON: %+v", ext) - dependencies, err := RestoreNodeList(ctx, j.Dependencies) - if err != nil { - return nil, err - } - - locks_held, err := RestoreNodeMap(ctx, j.LocksHeld) - if err != nil { - return nil, err - } - - - return NewLockableExt(owner, requirements, dependencies, locks_held), nil + return &ext, nil } -func (ext *LockableExt) Process(context *StateContext, node *Node, signal Signal) error { - context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID) +func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { + ctx.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID) switch signal.Direction() { case Up: owner_sent := false for _, dependency := range(ext.Dependencies) { - context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", node.ID, dependency.ID) - err := dependency.Process(context, node.ID, signal) + err := node.Signal(ctx, dependency, signal) if err != nil { - return err + ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, dependency, err) } + if ext.Owner != nil { - if dependency.ID == ext.Owner.ID { + if dependency == *ext.Owner { owner_sent = true } } } + if ext.Owner != nil && owner_sent == false { - if ext.Owner.ID != node.ID { - context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", node.ID, ext.Owner.ID) - err := ext.Owner.Process(context, node.ID, signal) + if *ext.Owner != node.ID { + err := node.Signal(ctx, *ext.Owner, signal) if err != nil { - return err + ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, *ext.Owner, err) } } } case Down: for _, requirement := range(ext.Requirements) { - err := requirement.Process(context, node.ID, signal) + err := node.Signal(ctx, requirement, signal) if err != nil { - return err + ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, requirement, err) } } case Direct: default: - return fmt.Errorf("invalid signal direction %d", signal.Direction()) } - return nil } -func (ext *LockableExt) RecordUnlock(node *Node) *Node { - last_owner, exists := ext.LocksHeld[node.ID] +func (ext *LockableExt) RecordUnlock(node NodeID) *NodeID { + last_owner, exists := ext.LocksHeld[node] if exists == false { panic("Attempted to take a get the original lock holder of a lockable we don't own") } - delete(ext.LocksHeld, node.ID) + delete(ext.LocksHeld, node) return last_owner } -func (ext *LockableExt) RecordLock(node *Node, last_owner *Node) { - _, exists := ext.LocksHeld[node.ID] +func (ext *LockableExt) RecordLock(node NodeID, last_owner *NodeID) { + _, exists := ext.LocksHeld[node] if exists == true { panic("Attempted to lock a lockable we're already holding(lock cycle)") } - ext.LocksHeld[node.ID] = last_owner -} - -// Removes requirement as a requirement from lockable -func UnlinkLockables(context *StateContext, princ *Node, lockable *Node, requirement *Node) error { - lockable_ext, err := GetExt[*LockableExt](lockable) - if err != nil { - return err - } - requirement_ext, err := GetExt[*LockableExt](requirement) - if err != nil { - return err - } - return UpdateStates(context, princ, ACLMap{ - lockable.ID: ACLInfo{Node: lockable, Resources: []string{"requirements"}}, - requirement.ID: ACLInfo{Node: requirement, Resources: []string{"dependencies"}}, - }, func(context *StateContext) error { - var found *Node = nil - for _, req := range(lockable_ext.Requirements) { - if requirement.ID == req.ID { - found = req - break - } - } - - if found == nil { - return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID, lockable.ID) - } - - delete(requirement_ext.Dependencies, lockable.ID) - delete(lockable_ext.Requirements, requirement.ID) - - return nil - }) -} - -// Link requirements as requirements to lockable -func LinkLockables(context *StateContext, princ *Node, lockable *Node, requirements []*Node) error { - if lockable == nil { - return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements") - } - - if len(requirements) == 0 { - return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link no lockables in call") - } - - lockable_ext, err := GetExt[*LockableExt](lockable) - if err != nil { - return err - } - - req_exts := map[NodeID]*LockableExt{} - for _, requirement := range(requirements) { - if requirement == nil { - return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link nil to a Lockable as a requirement") - } - - if lockable.ID == requirement.ID { - return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s to itself", lockable.ID) - } - - _, exists := req_exts[requirement.ID] - if exists == true { - return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s twice", requirement.ID) - } - ext, err := GetExt[*LockableExt](requirement) - if err != nil { - return err - } - req_exts[requirement.ID] = ext - } - - return UpdateStates(context, princ, NewACLMap( - NewACLInfo(lockable, []string{"requirements"}), - ACLList(requirements, []string{"dependencies"}), - ), func(context *StateContext) error { - // Check that all the requirements can be added - // If the lockable is already locked, need to lock this resource as well before we can add it - for _, requirement := range(requirements) { - requirement_ext := req_exts[requirement.ID] - for _, req := range(requirements) { - if req.ID == requirement.ID { - continue - } - - is_req, err := checkIfRequirement(context, req.ID, requirement_ext) - if err != nil { - return err - } else if is_req { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot add the same dependency", req.ID, requirement.ID) - - } - } - - is_req, err := checkIfRequirement(context, lockable.ID, requirement_ext) - if err != nil { - return err - } else if is_req { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID, lockable.ID) - } - - is_req, err = checkIfRequirement(context, requirement.ID, lockable_ext) - if err != nil { - return err - } else if is_req { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID, requirement.ID) - } - - if lockable_ext.Owner == nil { - // If the new owner isn't locked, we can add the requirement - } else if requirement_ext.Owner == nil { - // if the new requirement isn't already locked but the owner is, the requirement needs to be locked first - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID, requirement.ID) - } else { - // If the new requirement is already locked and the owner is already locked, their owners need to match - if requirement_ext.Owner.ID != lockable_ext.Owner.ID { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID, lockable.ID) - } - } - } - // Update the states of the requirements - for _, requirement := range(requirements) { - requirement_ext := req_exts[requirement.ID] - requirement_ext.Dependencies[lockable.ID] = lockable - lockable_ext.Requirements[lockable.ID] = requirement - context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID, lockable.ID) - } - - // Return no error - return nil - }) -} - -func checkIfRequirement(context *StateContext, id NodeID, cur *LockableExt) (bool, error) { - for _, req := range(cur.Requirements) { - if req.ID == id { - return true, nil - } - - req_ext, err := GetExt[*LockableExt](req) - if err != nil { - return false, err - } - - var is_req bool - err = UpdateStates(context, req, NewACLInfo(req, []string{"requirements"}), func(context *StateContext) error { - is_req, err = checkIfRequirement(context, id, req_ext) - return err - }) - if err != nil { - return false, err - } - if is_req == true { - return true, nil - } - } - - return false, nil -} - -// Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error -// Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call -func LockLockables(context *StateContext, to_lock NodeMap, new_owner *Node) error { - if to_lock == nil { - return fmt.Errorf("LOCKABLE_LOCK_ERR: no map provided") - } - - req_exts := map[NodeID]*LockableExt{} - for _, l := range(to_lock) { - var err error - if l == nil { - return fmt.Errorf("LOCKABLE_LOCK_ERR: Can not lock nil") - } - - req_exts[l.ID], err = GetExt[*LockableExt](l) - if err != nil { - return err - } - } - - if new_owner == nil { - return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks") - } - - new_owner_ext, err := GetExt[*LockableExt](new_owner) - if err != nil { - return err - } - - // Called with no requirements to lock, success - if len(to_lock) == 0 { - return nil - } - - return UpdateStates(context, new_owner, NewACLMap( - ACLListM(to_lock, []string{"lock"}), - NewACLInfo(new_owner, nil), - ), func(context *StateContext) error { - // First loop is to check that the states can be locked, and locks all requirements - for _, req := range(to_lock) { - req_ext := req_exts[req.ID] - context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID, new_owner.ID) - - // If req is alreay locked, check that we can pass the lock - if req_ext.Owner != nil { - owner := req_ext.Owner - if owner.ID == new_owner.ID { - continue - } else { - err := UpdateStates(context, new_owner, NewACLInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){ - return LockLockables(context, req_ext.Requirements, req) - }) - if err != nil { - return err - } - } - } else { - err := LockLockables(context, req_ext.Requirements, req) - if err != nil { - return err - } - } - } - - // At this point state modification will be started, so no errors can be returned - for _, req := range(to_lock) { - req_ext := req_exts[req.ID] - old_owner := req_ext.Owner - // If the lockable was previously unowned, update the state - if old_owner == nil { - context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID, req.ID) - req_ext.Owner = new_owner - new_owner_ext.RecordLock(req, old_owner) - // Otherwise if the new owner already owns it, no need to update state - } else if old_owner.ID == new_owner.ID { - context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID, req.ID) - // Otherwise update the state - } else { - req_ext.Owner = new_owner - new_owner_ext.RecordLock(req, old_owner) - context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID, req.ID, old_owner.ID) - } - } - return nil - }) - -} - -func UnlockLockables(context *StateContext, to_unlock NodeMap, old_owner *Node) error { - if to_unlock == nil { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided") - } - - req_exts := map[NodeID]*LockableExt{} - for _, l := range(to_unlock) { - if l == nil { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: Can not unlock nil") - } - - var err error - req_exts[l.ID], err = GetExt[*LockableExt](l) - if err != nil { - return err - } - } - - if old_owner == nil { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks") - } - - old_owner_ext, err := GetExt[*LockableExt](old_owner) - if err != nil { - return err - } - - - // Called with no requirements to unlock, success - if len(to_unlock) == 0 { - return nil - } - - return UpdateStates(context, old_owner, NewACLMap( - ACLListM(to_unlock, []string{"lock"}), - NewACLInfo(old_owner, nil), - ), func(context *StateContext) error { - // First loop is to check that the states can be locked, and locks all requirements - for _, req := range(to_unlock) { - req_ext := req_exts[req.ID] - context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID, old_owner.ID) - - // Check if the owner is correct - if req_ext.Owner != nil { - if req_ext.Owner.ID != old_owner.ID { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID, old_owner.ID) - } - } else { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID) - } - - err := UnlockLockables(context, req_ext.Requirements, req) - if err != nil { - return err - } - } - - // At this point state modification will be started, so no errors can be returned - for _, req := range(to_unlock) { - req_ext := req_exts[req.ID] - new_owner := old_owner_ext.RecordUnlock(req) - req_ext.Owner = new_owner - if new_owner == nil { - context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID, req.ID) - } else { - context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID, req.ID, new_owner.ID) - } - } - - return nil - }) + ext.LocksHeld[node] = last_owner } func SaveNode(node *Node) string { diff --git a/node.go b/node.go index 97503b8..36365db 100644 --- a/node.go +++ b/node.go @@ -2,9 +2,11 @@ package graphvent import ( "sync" + "time" "reflect" "github.com/google/uuid" badger "github.com/dgraph-io/badger/v3" + "runtime" "fmt" "encoding/binary" "encoding/json" @@ -20,6 +22,17 @@ func (id NodeID) MarshalJSON() ([]byte, error) { return json.Marshal(&str) } +func (id *NodeID) UnmarshalJSON(bytes []byte) error { + var id_str string + err := json.Unmarshal(bytes, &id_str) + if err != nil { + return err + } + + *id, err = ParseID(id_str) + return err +} + var ZeroUUID = uuid.UUID{} var ZeroID = NodeID(ZeroUUID) @@ -62,20 +75,132 @@ type Serializable[I comparable] interface { Serialize() ([]byte, error) } -// NodeExtensions are additional data that can be attached to nodes, and used in node functions type Extension interface { Serializable[ExtType] - // Send a signal to this extension to process, - // this typically triggers signals to be sent to nodes linked in the extension - Process(context *StateContext, node *Node, signal Signal) error + Process(context *Context, source NodeID, node *Node, signal Signal) } +type QueuedSignal struct { + Signal Signal + Time time.Time +} + +const NODE_MSG_CHAN_DEFAULT = 1024 // Nodes represent an addressible group of extensions type Node struct { ID NodeID Type NodeType Lock sync.RWMutex Extensions map[ExtType]Extension + + MsgChan chan Msg + TimeoutChan <-chan time.Time + + LoopLock sync.Mutex + Active bool + + SignalQueue []QueuedSignal + NextSignal *QueuedSignal +} + +func (node *Node) QueueSignal(time time.Time, signal Signal) { + node.SignalQueue = append(node.SignalQueue, QueuedSignal{signal, time}) + node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue) +} + +func (node *Node) ClearSignalQueue() { + node.SignalQueue = []QueuedSignal{} + node.NextSignal = nil + node.TimeoutChan = nil +} + +func SoonestSignal(signals []QueuedSignal) (*QueuedSignal, <-chan time.Time) { + var soonest_signal *QueuedSignal + var soonest_time time.Time + for _, signal := range(signals) { + if signal.Time.Compare(soonest_time) == -1 || soonest_signal == nil { + soonest_signal = &signal + soonest_time = signal.Time + } + } + + if soonest_signal != nil { + return soonest_signal, time.After(time.Until(soonest_time)) + } else { + return nil, nil + } +} + +func RunNode(ctx *Context, node *Node) { + ctx.Log.Logf("node", "RUN_START: %s", node.ID) + err := NodeLoop(ctx, node) + if err != nil { + panic(err) + } + ctx.Log.Logf("node", "RUN_STOP: %s", node.ID) +} + +type Msg struct { + Source NodeID + Signal Signal +} + +// Main Loop for Threads, starts a write context, so cannot be called from a write or read context +func NodeLoop(ctx *Context, node *Node) error { + node.LoopLock.Lock() + defer node.LoopLock.Unlock() + + node.Active = true + for true { + var signal Signal + var source NodeID + select { + case msg := <- node.MsgChan: + signal = msg.Signal + source = msg.Source + err := Allowed(ctx, msg.Source, string(signal.Type()), node) + if err != nil { + ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s", err) + continue + } + case <-node.TimeoutChan: + signal = node.NextSignal.Signal + source = node.ID + node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue) + ctx.Log.Logf("node", "NODE_TIMEOUT %s - NEXT_SIGNAL: %s", node.ID, signal) + } + + // Handle special signal types + if signal.Type() == StopSignalType { + node.Process(ctx, node.ID, NewStatusSignal("stopped", node.ID)) + break + } + node.Process(ctx, source, signal) + } + return nil +} + +func (node *Node) Process(ctx *Context, source NodeID, signal Signal) { + for ext_type, ext := range(node.Extensions) { + ctx.Log.Logf("signal", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type) + ext.Process(ctx, source, node, signal) + } +} + +func (node *Node) Signal(ctx *Context, dest NodeID, signal Signal) error { + target, exists := ctx.Nodes[dest] + if exists == false { + return fmt.Errorf("%s does not exist, cannot signal it", dest) + } + select { + case target.MsgChan <- Msg{node.ID, signal}: + default: + buf := make([]byte, 4096) + n := runtime.Stack(buf, false) + stack_str := string(buf[:n]) + return fmt.Errorf("SIGNAL_OVERFLOW: %s - %s", dest, stack_str) + } + return nil } func GetCtx[T Extension, C any](ctx *Context) (C, error) { @@ -118,8 +243,10 @@ func (node *Node) Serialize() ([]byte, error) { Magic: NODE_DB_MAGIC, TypeHash: node.Type.Hash(), NumExtensions: uint32(len(extensions)), + NumQueuedSignals: uint32(len(node.SignalQueue)), }, Extensions: extensions, + QueuedSignals: node.SignalQueue, } i := 0 @@ -141,7 +268,8 @@ func (node *Node) Serialize() ([]byte, error) { return node_db.Serialize(), nil } -func NewNode(ctx *Context, id NodeID, node_type NodeType, extensions ...Extension) *Node { +// Create a new node in memory and start it's event loop +func NewNode(ctx *Context, id NodeID, node_type NodeType, queued_signals []QueuedSignal, extensions ...Extension) *Node { _, exists := ctx.Nodes[id] if exists == true { panic("Attempted to create an existing node") @@ -168,18 +296,31 @@ func NewNode(ctx *Context, id NodeID, node_type NodeType, extensions ...Extensio } } + if queued_signals == nil { + queued_signals = []QueuedSignal{} + } + + next_signal, timeout_chan := SoonestSignal(queued_signals) + node := &Node{ ID: id, Type: node_type, Extensions: ext_map, + MsgChan: make(chan Msg, NODE_MSG_CHAN_DEFAULT), + TimeoutChan: timeout_chan, + SignalQueue: queued_signals, + NextSignal: next_signal, } ctx.Nodes[id] = node + WriteNode(ctx, node) + + go RunNode(ctx, node) return node } -func Allowed(context *StateContext, principal_id NodeID, action string, node *Node) error { - context.Graph.Log.Logf("policy", "POLICY_CHECK: %s %s.%s", principal_id, node.ID, action) +func Allowed(ctx *Context, principal_id NodeID, action string, node *Node) error { + ctx.Log.Logf("policy", "POLICY_CHECK: %s %s.%s", principal_id, node.ID, action) // Nodes are allowed to perform all actions on themselves regardless of whether or not they have an ACL extension if principal_id == node.ID { return nil @@ -191,43 +332,24 @@ func Allowed(context *StateContext, principal_id NodeID, action string, node *No return err } - return policy_ext.Allows(context, principal_id, action, node) -} - -// Check that princ is allowed to signal this action, -// then send the signal to all the extensions of the node -func (node *Node) Process(context *StateContext, princ_id NodeID, signal Signal) error { - ser, _ := signal.Serialize() - context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID, string(ser)) - - err := Allowed(context, princ_id, fmt.Sprintf("signal.%s", signal.Type()), node) - if err != nil { - return err - } - - for ext_type, ext := range(node.Extensions) { - err = ext.Process(context, node, signal) - if err != nil { - context.Graph.Log.Logf("signal", "EXTENSION_SIGNAL_ERR: %s/%s - %s", node.ID, ext_type, err) - } - } - - return nil + return policy_ext.Allows(ctx, principal_id, action, node) } // Magic first four bytes of serialized DB content, stored big endian const NODE_DB_MAGIC = 0x2491df14 // Total length of the node database header, has magic to verify and type_hash to map to load function -const NODE_DB_HEADER_LEN = 16 +const NODE_DB_HEADER_LEN = 20 // A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node type NodeDBHeader struct { Magic uint32 NumExtensions uint32 + NumQueuedSignals uint32 TypeHash uint64 } type NodeDB struct { Header NodeDBHeader + QueuedSignals []QueuedSignal Extensions []ExtensionDB } @@ -239,7 +361,8 @@ func NewNodeDB(data []byte) (NodeDB, error) { magic := binary.BigEndian.Uint32(data[0:4]) num_extensions := binary.BigEndian.Uint32(data[4:8]) - node_type_hash := binary.BigEndian.Uint64(data[8:16]) + num_queued_signals := binary.BigEndian.Uint32(data[8:12]) + node_type_hash := binary.BigEndian.Uint64(data[12:20]) ptr += NODE_DB_HEADER_LEN @@ -269,13 +392,20 @@ func NewNodeDB(data []byte) (NodeDB, error) { ptr += int(EXTENSION_DB_HEADER_LEN + length) } + queued_signals := make([]QueuedSignal, num_queued_signals) + for i, _ := range(queued_signals) { + queued_signals[i] = QueuedSignal{} + } + return NodeDB{ Header: NodeDBHeader{ Magic: magic, TypeHash: node_type_hash, NumExtensions: num_extensions, + NumQueuedSignals: num_queued_signals, }, Extensions: extensions, + QueuedSignals: queued_signals, }, nil } @@ -287,7 +417,8 @@ func (header NodeDBHeader) Serialize() []byte { ret := make([]byte, NODE_DB_HEADER_LEN) binary.BigEndian.PutUint32(ret[0:4], header.Magic) binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions) - binary.BigEndian.PutUint64(ret[8:16], header.TypeHash) + binary.BigEndian.PutUint32(ret[8:12], header.NumQueuedSignals) + binary.BigEndian.PutUint64(ret[12:20], header.TypeHash) return ret } @@ -324,6 +455,20 @@ type ExtensionDB struct { } // Write multiple nodes to the database in a single transaction +func WriteNode(ctx *Context, node *Node) error { + ctx.Log.Logf("db", "DB_WRITE: %s", node.ID) + + bytes, err := node.Serialize() + if err != nil { + return err + } + + id_bytes := node.ID.Serialize() + + return ctx.DB.Update(func(txn *badger.Txn) error { + return txn.Set(id_bytes, bytes) + }) +} func WriteNodes(context *StateContext) error { err := ValidateStateContext(context, "write", true) if err != nil { @@ -368,10 +513,13 @@ func WriteNodes(context *StateContext) error { // Recursively load a node from the database. func LoadNode(ctx * Context, id NodeID) (*Node, error) { + ctx.Log.Logf("db", "LOOKING_FOR_NODE: %s", id) node, exists := ctx.Nodes[id] if exists == true { + ctx.Log.Logf("db", "NODE_ALREADY_LOADED: %s", id) return node,nil } + ctx.Log.Logf("db", "LOADING_NODE: %s", id) var bytes []byte err := ctx.DB.View(func(txn *badger.Txn) error { @@ -400,10 +548,15 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return nil, fmt.Errorf("Tried to load node %s of type 0x%x, which is not a known node type", id, node_db.Header.TypeHash) } + next_signal, timeout_chan := SoonestSignal(node_db.QueuedSignals) node = &Node{ ID: id, Type: node_type.Type, Extensions: map[ExtType]Extension{}, + MsgChan: make(chan Msg, NODE_MSG_CHAN_DEFAULT), + TimeoutChan: timeout_chan, + SignalQueue: node_db.QueuedSignals, + NextSignal: next_signal, } ctx.Nodes[id] = node @@ -462,6 +615,9 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { } ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id) + + go RunNode(ctx, node) + return node, nil } @@ -605,197 +761,3 @@ func del[K comparable](list []K, val K) []K { list[idx] = list[len(list)-1] return list[:len(list)-1] } - -// Add nodes to an existing read context and call nodes_fn with new_nodes locked for read -// Check that the node has read permissions for the nodes, then add them to the read context and call nodes_fn with the nodes locked for read -func UseStates(context *StateContext, principal *Node, new_nodes ACLMap, state_fn StateFn) error { - if principal == nil || new_nodes == nil || state_fn == nil { - return fmt.Errorf("nil passed to UseStates") - } - - err := ValidateStateContext(context, "read", false) - if err != nil { - return err - } - - if context.Started == false { - context.Started = true - } - - new_locks := []*Node{} - _, princ_locked := context.Locked[principal.ID] - if princ_locked == false { - new_locks = append(new_locks, principal) - context.Graph.Log.Logf("mutex", "RLOCKING_PRINC %s", principal.ID.String()) - principal.Lock.RLock() - } - - princ_permissions, princ_exists := context.Permissions[principal.ID] - new_permissions := ACLMap{} - if princ_exists == true { - for id, info := range(princ_permissions) { - new_permissions[id] = info - } - } - - for _, request := range(new_nodes) { - node := request.Node - if node == nil { - return fmt.Errorf("node in request list is nil") - } - id := node.ID - - if id != principal.ID { - _, locked := context.Locked[id] - if locked == false { - new_locks = append(new_locks, node) - context.Graph.Log.Logf("mutex", "RLOCKING %s", id.String()) - node.Lock.RLock() - } - } - - node_permissions, node_exists := new_permissions[id] - if node_exists == false { - node_permissions = ACLInfo{Node: node, Resources: []string{}} - } - - for _, resource := range(request.Resources) { - already_granted := false - for _, granted := range(node_permissions.Resources) { - if resource == granted { - already_granted = true - } - } - - if already_granted == false { - err := Allowed(context, principal.ID, fmt.Sprintf("%s.read", resource), node) - if err != nil { - for _, n := range(new_locks) { - context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String()) - n.Lock.RUnlock() - } - return err - } - } - } - new_permissions[id] = node_permissions - } - - for _, node := range(new_locks) { - context.Locked[node.ID] = node - } - - context.Permissions[principal.ID] = new_permissions - - err = state_fn(context) - - context.Permissions[principal.ID] = princ_permissions - - for _, node := range(new_locks) { - context.Graph.Log.Logf("mutex", "RUNLOCKING %s", node.ID.String()) - delete(context.Locked, node.ID) - node.Lock.RUnlock() - } - - return err -} - -// Add nodes to an existing write context and call nodes_fn with nodes locked for read -// If context is nil -func UpdateStates(context *StateContext, principal *Node, new_nodes ACLMap, state_fn StateFn) error { - if principal == nil || new_nodes == nil || state_fn == nil { - return fmt.Errorf("nil passed to UpdateStates") - } - - err := ValidateStateContext(context, "write", false) - if err != nil { - return err - } - - final := false - if context.Started == false { - context.Started = true - final = true - } - - new_locks := []*Node{} - _, princ_locked := context.Locked[principal.ID] - if princ_locked == false { - new_locks = append(new_locks, principal) - context.Graph.Log.Logf("mutex", "LOCKING_PRINC %s", principal.ID.String()) - principal.Lock.Lock() - } - - princ_permissions, princ_exists := context.Permissions[principal.ID] - new_permissions := ACLMap{} - if princ_exists == true { - for id, info := range(princ_permissions) { - new_permissions[id] = info - } - } - - for _, request := range(new_nodes) { - node := request.Node - if node == nil { - return fmt.Errorf("node in request list is nil") - } - id := node.ID - - if id != principal.ID { - _, locked := context.Locked[id] - if locked == false { - new_locks = append(new_locks, node) - context.Graph.Log.Logf("mutex", "LOCKING %s", id.String()) - node.Lock.Lock() - } - } - - node_permissions, node_exists := new_permissions[id] - if node_exists == false { - node_permissions = ACLInfo{Node: node, Resources: []string{}} - } - - for _, resource := range(request.Resources) { - already_granted := false - for _, granted := range(node_permissions.Resources) { - if resource == granted { - already_granted = true - } - } - - if already_granted == false { - err := Allowed(context, principal.ID, fmt.Sprintf("%s.write", resource), node) - if err != nil { - for _, n := range(new_locks) { - context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String()) - n.Lock.Unlock() - } - return err - } - } - } - new_permissions[id] = node_permissions - } - - for _, node := range(new_locks) { - context.Locked[node.ID] = node - } - - context.Permissions[principal.ID] = new_permissions - - err = state_fn(context) - - if final == true { - context.Finished = true - if err == nil { - err = WriteNodes(context) - } - for id, node := range(context.Locked) { - context.Graph.Log.Logf("mutex", "UNLOCKING %s", id.String()) - node.Lock.Unlock() - } - } - - return err -} - diff --git a/node_test.go b/node_test.go index 6baab96..384c875 100644 --- a/node_test.go +++ b/node_test.go @@ -10,14 +10,10 @@ func TestNodeDB(t *testing.T) { err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) fatalErr(t, err) - node := NewNode(ctx, RandID(), node_type, NewGroupExt(nil)) + node := NewNode(ctx, RandID(), node_type, nil, NewGroupExt(nil)) - context := NewWriteContext(ctx) - err = UpdateStates(context, node, NewACLInfo(node, []string{"test"}), func(context *StateContext) error { - ser, err := node.Serialize() - ctx.Log.Logf("test", "NODE_SER: %+v", ser) - return err - }) + ser, err := node.Serialize() + ctx.Log.Logf("test", "NODE_SER: %+v", ser) fatalErr(t, err) ctx.Nodes = NodeMap{} diff --git a/policy.go b/policy.go index 2b13c8a..5c80318 100644 --- a/policy.go +++ b/policy.go @@ -7,15 +7,15 @@ import ( type Policy interface { Serializable[PolicyType] - Allows(context *StateContext, principal_id NodeID, action string, node *Node) error + Allows(principal_id NodeID, action string, node *Node) error } //TODO: Update with change from principal *Node to principal_id so sane policies can still be made -func (policy *AllNodesPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { +func (policy *AllNodesPolicy) Allows(principal_id NodeID, action string, node *Node) error { return policy.Actions.Allows(action) } -func (policy *PerNodePolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { +func (policy *PerNodePolicy) Allows(principal_id NodeID, action string, node *Node) error { for id, actions := range(policy.NodeActions) { if id != principal_id { continue @@ -29,13 +29,13 @@ func (policy *PerNodePolicy) Allows(context *StateContext, principal_id NodeID, return fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID) } -func (policy *RequirementOfPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { +func (policy *RequirementOfPolicy) Allows(principal_id NodeID, action string, node *Node) error { lockable_ext, err := GetExt[*LockableExt](node) if err != nil { return err } - for id, _ := range(lockable_ext.Requirements) { + for _, id := range(lockable_ext.Requirements) { if id == principal_id { return policy.Actions.Allows(action) } @@ -44,36 +44,6 @@ func (policy *RequirementOfPolicy) Allows(context *StateContext, principal_id No return fmt.Errorf("%s is not a requirement of %s", principal_id, node.ID) } -func (policy *ParentOfPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { - thread_ext, err := GetExt[*ThreadExt](node) - if err != nil { - return err - } - - if thread_ext.Parent != nil { - if thread_ext.Parent.ID == principal_id { - return policy.Actions.Allows(action) - } - } - - return fmt.Errorf("%s is not a parent of %s", principal_id, node.ID) -} - -func (policy *ChildOfPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { - thread_ext, err := GetExt[*ThreadExt](node) - if err != nil { - return err - } - - for id, _ := range(thread_ext.Children) { - if id == principal_id { - return policy.Actions.Allows(action) - } - } - - return fmt.Errorf("%s is not a child of %s", principal_id, node.ID) -} - const RequirementOfPolicyType = PolicyType("REQUIREMENT_OF") type RequirementOfPolicy struct { AllNodesPolicy @@ -88,14 +58,6 @@ func NewRequirementOfPolicy(actions Actions) RequirementOfPolicy { } } -const ChildOfPolicyType = PolicyType("CHILD_OF") -type ChildOfPolicy struct { - AllNodesPolicy -} -func (policy *ChildOfPolicy) Type() PolicyType { - return ChildOfPolicyType -} - type Actions []string func (actions Actions) Allows(action string) error { @@ -153,26 +115,6 @@ func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context, } } -func NewChildOfPolicy(actions Actions) ChildOfPolicy { - return ChildOfPolicy{ - AllNodesPolicy: NewAllNodesPolicy(actions), - } -} - -const ParentOfPolicyType = PolicyType("PARENT_OF") -type ParentOfPolicy struct { - AllNodesPolicy -} -func (policy *ParentOfPolicy) Type() PolicyType { - return ParentOfPolicyType -} - -func NewParentOfPolicy(actions Actions) ParentOfPolicy { - return ParentOfPolicy{ - AllNodesPolicy: NewAllNodesPolicy(actions), - } -} - func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy { if node_actions == nil { node_actions = NodeActions{} @@ -268,18 +210,6 @@ func NewACLExtContext() *ACLExtContext { return &policy, nil }), }, - ParentOfPolicyType: PolicyInfo{ - Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){ - policy := NewParentOfPolicy(actions) - return &policy, nil - }), - }, - ChildOfPolicyType: PolicyInfo{ - Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){ - policy := NewChildOfPolicy(actions) - return &policy, nil - }), - }, RequirementOfPolicyType: PolicyInfo{ Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){ policy := NewRequirementOfPolicy(actions) @@ -307,8 +237,7 @@ func (ext *ACLExt) Serialize() ([]byte, error) { }, "", " ") } -func (ext *ACLExt) Process(context *StateContext, node *Node, signal Signal) error { - return nil +func (ext *ACLExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) { } func NewACLExt(policies ...Policy) *ACLExt { @@ -362,11 +291,11 @@ func (ext *ACLExt) Type() ExtType { } // Check if the extension allows the principal to perform action on node -func (ext *ACLExt) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { - context.Graph.Log.Logf("policy", "POLICY_EXT_ALLOWED: %+v", ext) +func (ext *ACLExt) Allows(ctx *Context, principal_id NodeID, action string, node *Node) error { + ctx.Log.Logf("policy", "POLICY_EXT_ALLOWED: %+v", ext) errs := []error{} for _, policy := range(ext.Policies) { - err := policy.Allows(context, principal_id, action, node) + err := policy.Allows(principal_id, action, node) if err == nil { return nil } diff --git a/signal.go b/signal.go index 033a81a..030945d 100644 --- a/signal.go +++ b/signal.go @@ -55,8 +55,8 @@ func NewDirectSignal(signal_type SignalType) BaseSignal { return NewBaseSignal(signal_type, Direct) } -var AbortSignal = NewBaseSignal("abort", Down) -var StopSignal = NewBaseSignal("stop", Down) +const StopSignalType = SignalType("STOP") +var StopSignal = NewDownSignal(StopSignalType) type IDSignal struct { BaseSignal diff --git a/user.go b/user.go index 0be303c..2b66b1a 100644 --- a/user.go +++ b/user.go @@ -20,8 +20,8 @@ type ECDHExtJSON struct { Shared []byte `json:"shared"` } -func (ext *ECDHExt) Process(context *StateContext, node *Node, signal Signal) error { - return nil +func (ext *ECDHExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) { + return } const ECDHExtType = ExtType("ECDH") @@ -115,6 +115,6 @@ func LoadGroupExt(ctx *Context, data []byte) (Extension, error) { return NewGroupExt(members), nil } -func (ext *GroupExt) Process(context *StateContext, node *Node, signal Signal) error { - return nil +func (ext *GroupExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) { + return }