diff --git a/context.go b/context.go index fd71a14..b51d62b 100644 --- a/context.go +++ b/context.go @@ -86,8 +86,8 @@ func (ctx * Context) RegisterNodeType(def NodeDef) error { ctx.Types[type_hash] = def node_type := reflect.TypeOf((*Node)(nil)).Elem() - lockable_type := reflect.TypeOf((*Lockable)(nil)).Elem() - thread_type := reflect.TypeOf((*Thread)(nil)).Elem() + lockable_type := reflect.TypeOf((*LockableNode)(nil)).Elem() + thread_type := reflect.TypeOf((*ThreadNode)(nil)).Elem() if def.Reflect.Implements(node_type) { ctx.GQL.ValidNodes[def.Reflect] = def.GQLType @@ -154,7 +154,7 @@ func NewGQLContext() GQLContext { Query: query, Mutation: mutation, Subscription: subscription, - BaseNodeType: GQLTypeGraphNode.Type, + BaseNodeType: GQLTypeSimpleNode.Type, BaseLockableType: GQLTypeSimpleLockable.Type, BaseThreadType: GQLTypeSimpleThread.Type, } @@ -171,15 +171,19 @@ func NewContext(db * badger.DB, log Logger) * Context { Types: map[uint64]NodeDef{}, } - err := ctx.RegisterNodeType(NewNodeDef((*GraphNode)(nil), LoadGraphNode, GQLTypeGraphNode.Type)) + err := ctx.RegisterNodeType(NewNodeDef((*SimpleNode)(nil), LoadSimpleNode, GQLTypeSimpleNode.Type)) if err != nil { panic(err) } - err = ctx.RegisterNodeType(NewNodeDef((*SimpleLockable)(nil), LoadSimpleLockable, GQLTypeSimpleLockable.Type)) + err = ctx.RegisterNodeType(NewNodeDef((*Lockable)(nil), LoadLockable, GQLTypeSimpleLockable.Type)) if err != nil { panic(err) } - err = ctx.RegisterNodeType(NewNodeDef((*SimpleThread)(nil), LoadSimpleThread, GQLTypeSimpleThread.Type)) + err = ctx.RegisterNodeType(NewNodeDef((*Listener)(nil), LoadListener, GQLTypeSimpleLockable.Type)) + if err != nil { + panic(err) + } + err = ctx.RegisterNodeType(NewNodeDef((*Thread)(nil), LoadThread, GQLTypeSimpleThread.Type)) if err != nil { panic(err) } @@ -191,19 +195,19 @@ func NewContext(db * badger.DB, log Logger) * Context { if err != nil { panic(err) } - err = ctx.RegisterNodeType(NewNodeDef((*PerNodePolicy)(nil), LoadPerNodePolicy, GQLTypeGraphNode.Type)) + err = ctx.RegisterNodeType(NewNodeDef((*PerNodePolicy)(nil), LoadPerNodePolicy, GQLTypeSimpleNode.Type)) if err != nil { panic(err) } - err = ctx.RegisterNodeType(NewNodeDef((*SimplePolicy)(nil), LoadSimplePolicy, GQLTypeGraphNode.Type)) + err = ctx.RegisterNodeType(NewNodeDef((*SimplePolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type)) if err != nil { panic(err) } - err = ctx.RegisterNodeType(NewNodeDef((*PerTagPolicy)(nil), LoadPerTagPolicy, GQLTypeGraphNode.Type)) + err = ctx.RegisterNodeType(NewNodeDef((*PerTagPolicy)(nil), LoadPerTagPolicy, GQLTypeSimpleNode.Type)) if err != nil { panic(err) } - err = ctx.RegisterNodeType(NewNodeDef((*DependencyPolicy)(nil), LoadSimplePolicy, GQLTypeGraphNode.Type)) + err = ctx.RegisterNodeType(NewNodeDef((*DependencyPolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type)) if err != nil { panic(err) } diff --git a/gql.go b/gql.go index 8faa230..d6b1025 100644 --- a/gql.go +++ b/gql.go @@ -629,7 +629,7 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, * } type GQLThread struct { - SimpleThread + Thread tcp_listener net.Listener http_server *http.Server http_done *sync.WaitGroup @@ -639,6 +639,37 @@ type GQLThread struct { Users map[NodeID]*User Key *ecdsa.PrivateKey ECDH ecdh.Curve + SubscribeLock sync.Mutex + SubscribeListeners []chan GraphSignal +} + +func (thread *GQLThread) NewSubscriptionChannel(buffer int) chan GraphSignal { + thread.SubscribeLock.Lock() + defer thread.SubscribeLock.Unlock() + + new_listener := make(chan GraphSignal, buffer) + thread.SubscribeListeners = append(thread.SubscribeListeners, new_listener) + + return new_listener +} + +func (thread *GQLThread) Process(context *StateContext, signal GraphSignal) error { + active_listeners := []chan GraphSignal{} + thread.SubscribeLock.Lock() + for _, listener := range(thread.SubscribeListeners) { + select { + case listener <- signal: + active_listeners = append(active_listeners, listener) + default: + go func(listener chan GraphSignal) { + listener <- NewDirectSignal("Channel Closed") + close(listener) + }(listener) + } + } + thread.SubscribeListeners = active_listeners + thread.SubscribeLock.Unlock() + return thread.Thread.Process(context, signal) } func (thread * GQLThread) Type() NodeType { @@ -650,17 +681,8 @@ func (thread * GQLThread) Serialize() ([]byte, error) { return json.MarshalIndent(&thread_json, "", " ") } -func (thread * GQLThread) DeserializeInfo(ctx *Context, data []byte) (ThreadInfo, error) { - var info ParentThreadInfo - err := json.Unmarshal(data, &info) - if err != nil { - return nil, err - } - return &info, nil -} - type GQLThreadJSON struct { - SimpleThreadJSON + ThreadJSON Listen string `json:"listen"` Users []string `json:"users"` Key []byte `json:"key"` @@ -686,7 +708,7 @@ var ecdh_curve_ids = map[ecdh.Curve]uint8{ } func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON { - thread_json := NewSimpleThreadJSON(&thread.SimpleThread) + thread_json := NewThreadJSON(&thread.Thread) ser_key, err := x509.MarshalECPrivateKey(thread.Key) if err != nil { @@ -701,7 +723,7 @@ func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON { } return GQLThreadJSON{ - SimpleThreadJSON: thread_json, + ThreadJSON: thread_json, Listen: thread.Listen, Users: users, Key: ser_key, @@ -744,7 +766,7 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e thread.Users[id] = user.(*User) } - err = RestoreSimpleThread(ctx, &thread, j.SimpleThreadJSON, nodes) + err = RestoreThread(ctx, &thread.Thread, j.ThreadJSON, nodes) if err != nil { return nil, err } @@ -793,8 +815,9 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh tls_key = ssl_key_pem } return GQLThread{ - SimpleThread: NewSimpleThread(id, name, state_name, reflect.TypeOf((*ParentThreadInfo)(nil)), gql_actions, gql_handlers), + Thread: NewThread(id, name, state_name, []InfoType{"parent"}, gql_actions, gql_handlers), Listen: listen, + SubscribeListeners: []chan GraphSignal{}, Users: map[NodeID]*User{}, http_done: &sync.WaitGroup{}, Key: key, @@ -806,40 +829,23 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh var gql_actions ThreadActions = ThreadActions{ "wait": ThreadWait, - "restore": func(ctx * Context, thread Thread) (string, error) { - ctx.Log.Logf("gql", "GQL_THREAD_RESTORE: %s", thread.ID()) - // Restore all the threads that have "Start" as true and arent in the "finished" state - err := ThreadRestore(ctx, thread, false) - if err != nil { - return "", err - } - return "start_server", nil + "restore": func(ctx *Context, node ThreadNode) (string, error) { + return "start_server", ThreadRestore(ctx, node, false) }, - "start": func(ctx * Context, thread Thread) (string, error) { - ctx.Log.Logf("gql", "GQL_START") - err := ThreadStart(ctx, thread) + "start": func(ctx * Context, node ThreadNode) (string, error) { + _, err := ThreadStart(ctx, node) if err != nil { return "", err } - // Start all the threads that have "Start" as true and arent in the "finished" state - err = ThreadRestore(ctx, thread, true) - if err != nil { - return "", err - } - return "start_server", nil + return "start_server", ThreadRestore(ctx, node, true) }, - "start_server": func(ctx * Context, thread Thread) (string, error) { - server, ok := thread.(*GQLThread) - if ok == false { - return "", fmt.Errorf("GQL_THREAD_START: %s is not GQLThread, %+v", thread.ID(), thread.State()) - } + "start_server": func(ctx * Context, node ThreadNode) (string, error) { + gql_thread := node.(*GQLThread) - ctx.Log.Logf("gql", "GQL_START_SERVER") - // Serve the GQL http and ws handlers mux := http.NewServeMux() - mux.HandleFunc("/auth", AuthHandler(ctx, server)) - mux.HandleFunc("/gql", GQLHandler(ctx, server)) - mux.HandleFunc("/gqlws", GQLWSHandler(ctx, server)) + mux.HandleFunc("/auth", AuthHandler(ctx, gql_thread)) + mux.HandleFunc("/gql", GQLHandler(ctx, gql_thread)) + mux.HandleFunc("/gqlws", GQLWSHandler(ctx, gql_thread)) // Server a graphiql interface(TODO make configurable whether to start this) mux.HandleFunc("/graphiql", GraphiQLHandler()) @@ -849,7 +855,7 @@ var gql_actions ThreadActions = ThreadActions{ mux.Handle("/site/", http.StripPrefix("/site", fs)) http_server := &http.Server{ - Addr: server.Listen, + Addr: gql_thread.Listen, Handler: mux, } @@ -858,7 +864,7 @@ var gql_actions ThreadActions = ThreadActions{ return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr) } - cert, err := tls.X509KeyPair(server.tls_cert, server.tls_key) + cert, err := tls.X509KeyPair(gql_thread.tls_cert, gql_thread.tls_key) if err != nil { return "", err } @@ -870,23 +876,23 @@ var gql_actions ThreadActions = ThreadActions{ listener := tls.NewListener(l, &config) - server.http_done.Add(1) - go func(server *GQLThread) { - defer server.http_done.Done() + gql_thread.http_done.Add(1) + go func(gql_thread *GQLThread) { + defer gql_thread.http_done.Done() err := http_server.Serve(listener) if err != http.ErrServerClosed { panic(fmt.Sprintf("Failed to start gql server: %s", err)) } - }(server) + }(gql_thread) context := NewWriteContext(ctx) - err = UpdateStates(context, server, NewLockMap( - NewLockInfo(server, []string{"http_server"}), + err = UpdateStates(context, gql_thread, NewLockMap( + NewLockInfo(gql_thread, []string{"http_server"}), ), func(context *StateContext) error { - server.tcp_listener = listener - server.http_server = http_server + gql_thread.tcp_listener = listener + gql_thread.http_server = http_server return nil }) @@ -895,24 +901,24 @@ var gql_actions ThreadActions = ThreadActions{ } context = NewReadContext(ctx) - err = Signal(context, server, server, NewStatusSignal("server_started", server.ID())) + err = Signal(context, gql_thread, gql_thread, NewStatusSignal("server_started", gql_thread.ID())) if err != nil { return "", err } return "wait", nil }, - "finish": func(ctx *Context, thread Thread) (string, error) { - server := thread.(*GQLThread) - server.http_server.Shutdown(context.TODO()) - server.http_done.Wait() - return "", ThreadFinish(ctx, thread) + "finish": func(ctx *Context, node ThreadNode) (string, error) { + gql_thread := node.(*GQLThread) + gql_thread.http_server.Shutdown(context.TODO()) + gql_thread.http_done.Wait() + return ThreadFinish(ctx, node) }, } var gql_handlers ThreadHandlers = ThreadHandlers{ - "child_linked": ThreadParentChildLinked, - "start_child": ThreadParentStartChild, + "child_linked": ThreadChildLinked, + "start_child": ThreadStartChild, "abort": ThreadAbort, "stop": ThreadStop, } diff --git a/gql_mutation.go b/gql_mutation.go index ebbcdc8..2ad9f84 100644 --- a/gql_mutation.go +++ b/gql_mutation.go @@ -28,7 +28,7 @@ var GQLMutationAbort = NewField(func()*graphql.Field { err = UseStates(context, ctx.User, NewLockMap( NewLockInfo(ctx.Server, []string{"children"}), ), func(context *StateContext) (error){ - node = FindChild(context, ctx.User, ctx.Server, id) + node = FindChild(context, ctx.User, &ctx.Server.Thread, id) if node == nil { return fmt.Errorf("Failed to find ID: %s as child of server thread", id) } @@ -86,13 +86,13 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{ err = UseStates(context, ctx.User, NewLockMap( NewLockInfo(ctx.Server, []string{"children"}), ), func(context *StateContext) error { - node := FindChild(context, ctx.User, ctx.Server, parent_id) - if node == nil { - return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id) + parent := FindChild(context, ctx.User, &ctx.Server.Thread, parent_id) + if parent == nil { + return fmt.Errorf("%s is not a child of %s", parent_id, ctx.Server.ID()) } signal = NewStartChildSignal(child_id, action) - return Signal(context, node, ctx.User, signal) + return Signal(context, ctx.User, parent, signal) }) if err != nil { return nil, err diff --git a/gql_resolvers.go b/gql_resolvers.go index 9d4be7a..148e7a0 100644 --- a/gql_resolvers.go +++ b/gql_resolvers.go @@ -105,15 +105,15 @@ func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) { return nil, err } - node, ok := p.Source.(Thread) + node, ok := p.Source.(*Thread) if ok == false || node == nil { return nil, fmt.Errorf("Failed to cast source to Thread") } - var parent Thread = nil + var parent ThreadNode = nil context := NewReadContext(ctx.Context) err = UseStates(context, ctx.User, NewLockInfo(node, []string{"parent"}), func(context *StateContext) error { - parent = node.Parent() + parent = node.ThreadHandle().Parent return nil }) @@ -130,7 +130,7 @@ func GQLThreadState(p graphql.ResolveParams) (interface{}, error) { return nil, err } - node, ok := p.Source.(Thread) + node, ok := p.Source.(ThreadNode) if ok == false || node == nil { return nil, fmt.Errorf("Failed to cast source to Thread") } @@ -138,7 +138,7 @@ func GQLThreadState(p graphql.ResolveParams) (interface{}, error) { var state string context := NewReadContext(ctx.Context) err = UseStates(context, ctx.User, NewLockInfo(node, []string{"state"}), func(context *StateContext) error { - state = node.State() + state = node.ThreadHandle().StateName return nil }) @@ -155,15 +155,20 @@ func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) { return nil, err } - node, ok := p.Source.(Thread) + node, ok := p.Source.(ThreadNode) if ok == false || node == nil { return nil, fmt.Errorf("Failed to cast source to Thread") } - var children []Thread = nil + var children []ThreadNode = nil context := NewReadContext(ctx.Context) err = UseStates(context, ctx.User, NewLockInfo(node, []string{"children"}), func(context *StateContext) error { - children = node.Children() + children = make([]ThreadNode, len(node.ThreadHandle().Children)) + i := 0 + for _, info := range(node.ThreadHandle().Children) { + children[i] = info.Child + i += 1 + } return nil }) @@ -180,7 +185,7 @@ func GQLLockableName(p graphql.ResolveParams) (interface{}, error) { return nil, err } - node, ok := p.Source.(Lockable) + node, ok := p.Source.(LockableNode) if ok == false || node == nil { return nil, fmt.Errorf("Failed to cast source to Lockable") } @@ -188,7 +193,7 @@ func GQLLockableName(p graphql.ResolveParams) (interface{}, error) { name := "" context := NewReadContext(ctx.Context) err = UseStates(context, ctx.User, NewLockInfo(node, []string{"name"}), func(context *StateContext) error { - name = node.Name() + name = node.LockableHandle().Name return nil }) @@ -205,15 +210,20 @@ func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) { return nil, err } - node, ok := p.Source.(Lockable) + node, ok := p.Source.(LockableNode) if ok == false || node == nil { return nil, fmt.Errorf("Failed to cast source to Lockable") } - var requirements []Lockable = nil + var requirements []LockableNode = nil context := NewReadContext(ctx.Context) err = UseStates(context, ctx.User, NewLockInfo(node, []string{"requirements"}), func(context *StateContext) error { - requirements = node.Requirements() + requirements = make([]LockableNode, len(node.LockableHandle().Requirements)) + i := 0 + for _, req := range(node.LockableHandle().Requirements) { + requirements[i] = req + i += 1 + } return nil }) @@ -230,15 +240,20 @@ func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) { return nil, err } - node, ok := p.Source.(Lockable) + node, ok := p.Source.(LockableNode) if ok == false || node == nil { return nil, fmt.Errorf("Failed to cast source to Lockable") } - var dependencies []Lockable = nil + var dependencies []LockableNode = nil context := NewReadContext(ctx.Context) err = UseStates(context, ctx.User, NewLockInfo(node, []string{"dependencies"}), func(context *StateContext) error { - dependencies = node.Dependencies() + dependencies = make([]LockableNode, len(node.LockableHandle().Dependencies)) + i := 0 + for _, dep := range(node.LockableHandle().Dependencies) { + dependencies[i] = dep + i += 1 + } return nil }) @@ -255,7 +270,7 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) { return nil, err } - node, ok := p.Source.(Lockable) + node, ok := p.Source.(LockableNode) if ok == false || node == nil { return nil, fmt.Errorf("Failed to cast source to Lockable") } @@ -263,7 +278,7 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) { var owner Node = nil context := NewReadContext(ctx.Context) err = UseStates(context, ctx.User, NewLockInfo(node, []string{"owner"}), func(context *StateContext) error { - owner = node.Owner() + owner = node.LockableHandle().Owner return nil }) diff --git a/gql_subscribe.go b/gql_subscribe.go index 04be99c..3b4cef8 100644 --- a/gql_subscribe.go +++ b/gql_subscribe.go @@ -24,7 +24,7 @@ func GQLSubscribeFn(p graphql.ResolveParams, send_nil bool, fn func(*Context, *G c := make(chan interface{}) go func(c chan interface{}, server *GQLThread) { ctx.Context.Log.Logf("gqlws", "GQL_SUBSCRIBE_THREAD_START") - sig_c := UpdateChannel(server, 1, RandID()) + sig_c := server.NewSubscriptionChannel(1) if send_nil == true { sig_c <- nil } diff --git a/gql_test.go b/gql_test.go index 877637a..473d9cb 100644 --- a/gql_test.go +++ b/gql_test.go @@ -20,46 +20,37 @@ import ( func TestGQLDBLoad(t * testing.T) { ctx := logTestContext(t, []string{"test", "signal", "policy", "thread"}) - l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") - l1 := &l1_r + l1 := NewListener(RandID(), "Test Lockable 1") ctx.Log.Logf("test", "L1_ID: %s", l1.ID().String()) - t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) - t1 := &t1_r + t1 := NewThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) ctx.Log.Logf("test", "T1_ID: %s", t1.ID().String()) listen_id := RandID() ctx.Log.Logf("test", "LISTENER_ID: %s", listen_id.String()) - update_channel := UpdateChannel(t1, 10, listen_id) u1_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) - u1_shared := []byte{0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x23, 0x45, 0x67} - - u1_r := NewUser("Test User", time.Now(), &u1_key.PublicKey, u1_shared, []string{"gql"}) - u1 := &u1_r + u1 := NewUser("Test User", time.Now(), &u1_key.PublicKey, []byte{}, []string{"gql"}) ctx.Log.Logf("test", "U1_ID: %s", u1.ID().String()) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) - gql_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) - gql := &gql_r + gql := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID().String()) // Policy to allow gql to perform all action on all resources - p1_r := NewPerNodePolicy(RandID(), map[NodeID]NodeActions{ + p1 := NewPerNodePolicy(RandID(), map[NodeID]NodeActions{ gql.ID(): NewNodeActions(nil, []string{"*"}), }) - p1 := &p1_r - p2_r := NewSimplePolicy(RandID(), NewNodeActions(NodeActions{ + p2 := NewSimplePolicy(RandID(), NewNodeActions(NodeActions{ "signal": []string{"status"}, }, nil)) - p2 := &p2_r context := NewWriteContext(ctx) - err = UpdateStates(context, gql, LockMap{ - p1.ID(): LockInfo{p1, nil}, - p2.ID(): LockInfo{p2, nil}, + err = UpdateStates(context, &gql, LockMap{ + p1.ID(): LockInfo{&p1, nil}, + p2.ID(): LockInfo{&p2, nil}, }, func(context *StateContext) error { return nil }) @@ -67,46 +58,48 @@ func TestGQLDBLoad(t * testing.T) { ctx.Log.Logf("test", "P1_ID: %s", p1.ID().String()) ctx.Log.Logf("test", "P2_ID: %s", p2.ID().String()) - err = AttachPolicies(ctx, gql, p1, p2) + err = AttachPolicies(ctx, &gql.SimpleNode, &p1, &p2) fatalErr(t, err) - err = AttachPolicies(ctx, l1, p1, p2) + err = AttachPolicies(ctx, &l1.SimpleNode, &p1, &p2) fatalErr(t, err) - err = AttachPolicies(ctx, t1, p1, p2) + err = AttachPolicies(ctx, &t1.SimpleNode, &p1, &p2) fatalErr(t, err) - err = AttachPolicies(ctx, u1, p1, p2) + err = AttachPolicies(ctx, &u1.SimpleNode, &p1, &p2) fatalErr(t, err) info := NewParentThreadInfo(true, "start", "restore") context = NewWriteContext(ctx) - err = UpdateStates(context, gql, NewLockMap( - NewLockInfo(gql, []string{"users"}), + err = UpdateStates(context, &gql, NewLockMap( + NewLockInfo(&gql, []string{"users"}), ), func(context *StateContext) error { - gql.Users[KeyID(&u1_key.PublicKey)] = u1 + gql.Users[KeyID(&u1_key.PublicKey)] = &u1 - err := LinkThreads(context, gql, gql, t1, &info) + err := LinkThreads(context, &gql, &gql, ChildInfo{&t1, map[InfoType]interface{}{ + "parent": &info, + }}) if err != nil { return err } - return LinkLockables(context, gql, gql, []Lockable{l1}) + return LinkLockables(context, &gql, &gql, []LockableNode{&l1}) }) fatalErr(t, err) context = NewReadContext(ctx) - err = Signal(context, gql, gql, NewStatusSignal("child_linked", t1.ID())) + err = Signal(context, &gql, &gql, NewStatusSignal("child_linked", t1.ID())) fatalErr(t, err) context = NewReadContext(ctx) - err = Signal(context, gql, gql, AbortSignal) + err = Signal(context, &gql, &gql, AbortSignal) fatalErr(t, err) - err = ThreadLoop(ctx, gql, "start") + err = ThreadLoop(ctx, &gql, "start") if errors.Is(err, ThreadAbortedError) == false { fatalErr(t, err) } - (*GraphTester)(t).WaitForStatus(ctx, update_channel, "aborted", 100*time.Millisecond, "Didn't receive aborted on update_channel") + (*GraphTester)(t).WaitForStatus(ctx, l1.Chan, "aborted", 100*time.Millisecond, "Didn't receive aborted on listener") context = NewReadContext(ctx) - err = UseStates(context, gql, LockList([]Node{gql, u1}, nil), func(context *StateContext) error { + err = UseStates(context, &gql, LockList([]Node{&gql, &u1}, nil), func(context *StateContext) error { ser1, err := gql.Serialize() ser2, err := u1.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser1) @@ -116,18 +109,15 @@ func TestGQLDBLoad(t * testing.T) { gql_loaded, err := LoadNode(ctx, gql.ID()) fatalErr(t, err) - var t1_loaded *SimpleThread = nil - - var update_channel_2 chan GraphSignal + var l1_loaded *Listener = nil context = NewReadContext(ctx) - err = UseStates(context, gql, NewLockInfo(gql_loaded, []string{"users", "children"}), func(context *StateContext) error { + err = UseStates(context, gql_loaded, NewLockInfo(gql_loaded, []string{"users", "children", "requirements"}), func(context *StateContext) error { ser, err := gql_loaded.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser) + dependency := gql_loaded.(*GQLThread).Thread.Dependencies[l1.ID()].(*Listener) + l1_loaded = dependency u_loaded := gql_loaded.(*GQLThread).Users[u1.ID()] - child := gql_loaded.(Thread).Children()[0].(*SimpleThread) - t1_loaded = child - update_channel_2 = UpdateChannel(t1_loaded, 10, RandID()) - err = UseStates(context, gql, NewLockInfo(u_loaded, nil), func(context *StateContext) error { + err = UseStates(context, gql_loaded, NewLockInfo(u_loaded, nil), func(context *StateContext) error { ser, err := u_loaded.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser) return err @@ -136,9 +126,9 @@ func TestGQLDBLoad(t * testing.T) { return err }) - err = ThreadLoop(ctx, gql_loaded.(Thread), "start") + err = ThreadLoop(ctx, gql_loaded.(ThreadNode), "start") fatalErr(t, err) - (*GraphTester)(t).WaitForStatus(ctx, update_channel_2, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2") + (*GraphTester)(t).WaitForStatus(ctx, l1_loaded.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2") } @@ -153,23 +143,19 @@ func TestGQLAuth(t * testing.T) { gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) gql_t := &gql_t_r - // p1 not written to DB, TODO: update write to follow links maybe - context := NewWriteContext(ctx) - err = UpdateStates(context, gql_t, NewLockInfo(gql_t, []string{"policies"}), func(context *StateContext) error { - return gql_t.AddPolicy(p1) - }) + l1 := NewListener(RandID(), "GQL Thread") + err = AttachPolicies(ctx, &l1.SimpleNode, p1) + fatalErr(t, err) + err = AttachPolicies(ctx, &gql_t.SimpleNode, p1) done := make(chan error, 1) - var update_channel chan GraphSignal - context = NewReadContext(ctx) - err = UseStates(context, gql_t, NewLockInfo(gql_t, nil), func(context *StateContext) error { - update_channel = UpdateChannel(gql_t, 10, NodeID{}) - return nil - }) + context := NewWriteContext(ctx) + err = LinkLockables(context, gql_t, gql_t, []LockableNode{&l1}) fatalErr(t, err) - go func(done chan error, thread Thread) { + + go func(done chan error, thread ThreadNode) { timeout := time.After(2*time.Second) select { case <-timeout: @@ -182,8 +168,8 @@ func TestGQLAuth(t * testing.T) { fatalErr(t, err) }(done, gql_t) - go func(thread Thread){ - (*GraphTester)(t).WaitForStatus(ctx, update_channel, "server_started", 100*time.Millisecond, "Server didn't start") + go func(thread ThreadNode){ + (*GraphTester)(t).WaitForStatus(ctx, l1.Chan, "server_started", 100*time.Millisecond, "Server didn't start") port := gql_t.tcp_listener.Addr().(*net.TCPAddr).Port ctx.Log.Logf("test", "GQL_PORT: %d", port) diff --git a/gql_types.go b/gql_types.go index ecbf1db..4a88fc0 100644 --- a/gql_types.go +++ b/gql_types.go @@ -142,9 +142,9 @@ var GQLTypeSimpleLockable = NewSingleton(func() *graphql.Object { return gql_type_simple_lockable }, nil) -var GQLTypeGraphNode = NewSingleton(func() *graphql.Object { +var GQLTypeSimpleNode = NewSingleton(func() *graphql.Object { object := graphql.NewObject(graphql.ObjectConfig{ - Name: "GraphNode", + Name: "SimpleNode", Interfaces: []*graphql.Interface{ GQLInterfaceNode.Type, }, diff --git a/lockable.go b/lockable.go index 09c3beb..33c3417 100644 --- a/lockable.go +++ b/lockable.go @@ -5,62 +5,75 @@ import ( "encoding/json" ) -// A Lockable represents a Node that can be locked and hold other Nodes locks -type Lockable interface { - // All Lockables are nodes +type Listener struct { + Lockable + Chan chan GraphSignal +} + +func (node *Listener) Type() NodeType { + return NodeType("listener") +} + +func (node *Listener) Process(context *StateContext, signal GraphSignal) error { + select { + case node.Chan <- signal: + default: + return fmt.Errorf("LISTENER_OVERFLOW: %s - %s", node.ID(), signal) + } + return node.Lockable.Process(context, signal) +} + +const LISTENER_CHANNEL_BUFFER = 1024 +func NewListener(id NodeID, name string) Listener { + return Listener{ + Lockable: NewLockable(id, name), + Chan: make(chan GraphSignal, LISTENER_CHANNEL_BUFFER), + } +} + +func LoadListener(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { + var j LockableJSON + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } + + listener := NewListener(id, j.Name) + nodes[id] = &listener + + err = RestoreLockable(ctx, &listener.Lockable, j, nodes) + if err != nil { + return nil, err + } + + return &listener, nil +} + +type LockableNode interface { Node - //// State Modification Function - // Record that lockable was returned to it's owner and is no longer held by this Node - // Returns the previous owner of the lockable - RecordUnlock(lockable Lockable) Lockable - // Record that lockable was locked by this node, and that it should be returned to last_owner - RecordLock(lockable Lockable, last_owner Lockable) - // Link a requirement to this Node - AddRequirement(requirement Lockable) - // Remove a requirement linked to this Node - RemoveRequirement(requirement Lockable) - // Link a dependency to this Node - AddDependency(dependency Lockable) - // Remove a dependency linked to this Node - RemoveDependency(dependency Lockable) - // - SetOwner(new_owner Lockable) - - //// State Reading Functions - Name() string - // Called when new_owner wants to take lockable's lock but it's owned by this node - // A true return value means that the lock can be passed - AllowedToTakeLock(new_owner Lockable, lockable Lockable) bool - // Get all the linked requirements to this node - Requirements() []Lockable - // Get all the linked dependencies to this node - Dependencies() []Lockable - // Get the node's Owner - Owner() Lockable - // Called during the lock process after locking the state and before updating the Node's state - // a non-nil return value will abort the lock attempt - CanLock(new_owner Lockable) error - // Called during the unlock process after locking the state and before updating the Node's state - // a non-nil return value will abort the unlock attempt - CanUnlock(old_owner Lockable) error + LockableHandle() *Lockable } -// SimpleLockable is a simple Lockable implementation that can be embedded into more complex structures -type SimpleLockable struct { - GraphNode - name string - owner Lockable - requirements []Lockable - dependencies []Lockable - locks_held map[NodeID]Lockable +// Lockable is a simple Lockable implementation that can be embedded into more complex structures +type Lockable struct { + SimpleNode + Name string + Owner LockableNode + Requirements map[NodeID]LockableNode + Dependencies map[NodeID]LockableNode + LocksHeld map[NodeID]LockableNode } -func (state * SimpleLockable) Type() NodeType { - return NodeType("simple_lockable") +func (lockable *Lockable) LockableHandle() *Lockable { + return lockable } -type SimpleLockableJSON struct { - GraphNodeJSON +func (lockable *Lockable) Type() NodeType { + return NodeType("lockable") +} + +type LockableJSON struct { + SimpleNodeJSON Name string `json:"name"` Owner string `json:"owner"` Dependencies []string `json:"dependencies"` @@ -68,29 +81,33 @@ type SimpleLockableJSON struct { LocksHeld map[string]string `json:"locks_held"` } -func (lockable * SimpleLockable) Serialize() ([]byte, error) { - lockable_json := NewSimpleLockableJSON(lockable) +func (lockable *Lockable) Serialize() ([]byte, error) { + lockable_json := NewLockableJSON(lockable) return json.MarshalIndent(&lockable_json, "", " ") } -func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON { - requirement_ids := make([]string, len(lockable.requirements)) - for i, requirement := range(lockable.requirements) { - requirement_ids[i] = requirement.ID().String() +func NewLockableJSON(lockable *Lockable) LockableJSON { + requirement_ids := make([]string, len(lockable.Requirements)) + req_n := 0 + for id, _ := range(lockable.Requirements) { + requirement_ids[req_n] = id.String() + req_n++ } - dependency_ids := make([]string, len(lockable.dependencies)) - for i, dependency := range(lockable.dependencies) { - dependency_ids[i] = dependency.ID().String() + dependency_ids := make([]string, len(lockable.Dependencies)) + dep_n := 0 + for id, _ := range(lockable.Dependencies) { + dependency_ids[dep_n] = id.String() + dep_n++ } owner_id := "" - if lockable.owner != nil { - owner_id = lockable.owner.ID().String() + if lockable.Owner != nil { + owner_id = lockable.Owner.ID().String() } locks_held := map[string]string{} - for lockable_id, node := range(lockable.locks_held) { + for lockable_id, node := range(lockable.LocksHeld) { if node == nil { locks_held[lockable_id.String()] = "" } else { @@ -98,11 +115,11 @@ func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON { } } - node_json := NewGraphNodeJSON(&lockable.GraphNode) + node_json := NewSimpleNodeJSON(&lockable.SimpleNode) - return SimpleLockableJSON{ - GraphNodeJSON: node_json, - Name: lockable.name, + return LockableJSON{ + SimpleNodeJSON: node_json, + Name: lockable.Name, Owner: owner_id, Dependencies: dependency_ids, Requirements: requirement_ids, @@ -110,114 +127,29 @@ func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON { } } -func (lockable * SimpleLockable) Name() string { - return lockable.name -} - -func (lockable * SimpleLockable) RecordUnlock(l Lockable) Lockable { +func (lockable *Lockable) RecordUnlock(l LockableNode) LockableNode { lockable_id := l.ID() - last_owner, exists := lockable.locks_held[lockable_id] + last_owner, exists := lockable.LocksHeld[lockable_id] if exists == false { panic("Attempted to take a get the original lock holder of a lockable we don't own") } - delete(lockable.locks_held, lockable_id) + delete(lockable.LocksHeld, lockable_id) return last_owner } -func (lockable * SimpleLockable) RecordLock(l Lockable, last_owner Lockable) { +func (lockable *Lockable) RecordLock(l LockableNode, last_owner LockableNode) { lockable_id := l.ID() - _, exists := lockable.locks_held[lockable_id] + _, exists := lockable.LocksHeld[lockable_id] if exists == true { panic("Attempted to lock a lockable we're already holding(lock cycle)") } - lockable.locks_held[lockable_id] = last_owner -} - -// Nothing can take a lock from a simple lockable -func (lockable * SimpleLockable) AllowedToTakeLock(l Lockable, new_owner Lockable) bool { - return false -} - -func (lockable * SimpleLockable) Owner() Lockable { - return lockable.owner -} - -func (lockable * SimpleLockable) SetOwner(owner Lockable) { - lockable.owner = owner -} - -func (lockable * SimpleLockable) Requirements() []Lockable { - return lockable.requirements -} - -func (lockable * SimpleLockable) AddRequirement(requirement Lockable) { - if requirement == nil { - panic("Will not connect nil to the DAG") - } - lockable.requirements = append(lockable.requirements, requirement) -} - -func (lockable * SimpleLockable) Dependencies() []Lockable { - return lockable.dependencies -} - -func (lockable * SimpleLockable) AddDependency(dependency Lockable) { - if dependency == nil { - panic("Will not connect nil to the DAG") - } - - lockable.dependencies = append(lockable.dependencies, dependency) -} - -func (lockable * SimpleLockable) RemoveDependency(dependency Lockable) { - idx := -1 - - for i, dep := range(lockable.dependencies) { - if dep.ID() == dependency.ID() { - idx = i - break - } - } - - if idx == -1 { - panic(fmt.Sprintf("%s is not a dependency of %s", dependency.ID(), lockable.Name())) - } - - dep_len := len(lockable.dependencies) - lockable.dependencies[idx] = lockable.dependencies[dep_len-1] - lockable.dependencies = lockable.dependencies[0:(dep_len-1)] -} - -func (lockable * SimpleLockable) RemoveRequirement(requirement Lockable) { - idx := -1 - for i, req := range(lockable.requirements) { - if req.ID() == requirement.ID() { - idx = i - break - } - } - - if idx == -1 { - panic(fmt.Sprintf("%s is not a requirement of %s", requirement.ID(), lockable.Name())) - } - - req_len := len(lockable.requirements) - lockable.requirements[idx] = lockable.requirements[req_len-1] - lockable.requirements = lockable.requirements[0:(req_len-1)] -} - -func (lockable * SimpleLockable) CanLock(new_owner Lockable) error { - return nil -} - -func (lockable * SimpleLockable) CanUnlock(new_owner Lockable) error { - return nil + lockable.LocksHeld[lockable_id] = last_owner } // Assumed that lockable is already locked for signal -func (lockable * SimpleLockable) Process(context *StateContext, princ Node, signal GraphSignal) error { - err := lockable.GraphNode.Process(context, princ, signal) +func (lockable *Lockable) Process(context *StateContext, signal GraphSignal) error { + err := lockable.SimpleNode.Process(context, signal) if err != nil { return err } @@ -227,26 +159,26 @@ func (lockable * SimpleLockable) Process(context *StateContext, princ Node, sign err = UseStates(context, lockable, NewLockInfo(lockable, []string{"dependencies", "owner"}), func(context *StateContext) error { owner_sent := false - for _, dependency := range(lockable.dependencies) { + for _, dependency := range(lockable.Dependencies) { context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID()) Signal(context, dependency, lockable, signal) - if lockable.owner != nil { - if dependency.ID() == lockable.owner.ID() { + if lockable.Owner != nil { + if dependency.ID() == lockable.Owner.ID() { owner_sent = true } } } - if lockable.owner != nil && owner_sent == false { - if lockable.owner.ID() != lockable.ID() { - context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID()) - return Signal(context, lockable.owner, lockable, signal) + if lockable.Owner != nil && owner_sent == false { + if lockable.Owner.ID() != lockable.ID() { + context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.Owner.ID()) + return Signal(context, lockable.Owner, lockable, signal) } } return nil }) case Down: err = UseStates(context, lockable, NewLockInfo(lockable, []string{"requirements"}), func(context *StateContext) error { - for _, requirement := range(lockable.requirements) { + for _, requirement := range(lockable.Requirements) { err := Signal(context, requirement, lockable, signal) if err != nil { return err @@ -265,13 +197,13 @@ func (lockable * SimpleLockable) Process(context *StateContext, princ Node, sign // Removes requirement as a requirement from lockable // Continues the write context with princ, getting requirents for lockable and dependencies for requirement // Assumes that an active write context exists with princ locked so that princ's state can be used in checks -func UnlinkLockables(context *StateContext, princ Node, lockable Lockable, requirement Lockable) error { +func UnlinkLockables(context *StateContext, princ Node, lockable LockableNode, requirement LockableNode) error { return UpdateStates(context, princ, LockMap{ lockable.ID(): LockInfo{Node: lockable, Resources: []string{"requirements"}}, requirement.ID(): LockInfo{Node: requirement, Resources: []string{"dependencies"}}, }, func(context *StateContext) error { var found Node = nil - for _, req := range(lockable.Requirements()) { + for _, req := range(lockable.LockableHandle().Requirements) { if requirement.ID() == req.ID() { found = req break @@ -282,8 +214,8 @@ func UnlinkLockables(context *StateContext, princ Node, lockable Lockable, requi return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID(), lockable.ID()) } - requirement.RemoveDependency(lockable) - lockable.RemoveRequirement(requirement) + delete(requirement.LockableHandle().Dependencies, lockable.ID()) + delete(lockable.LockableHandle().Requirements, requirement.ID()) return nil }) @@ -291,10 +223,11 @@ func UnlinkLockables(context *StateContext, princ Node, lockable Lockable, requi // Link requirements as requirements to lockable // Continues the wrtie context with princ, getting requirements for lockable and dependencies for requirements -func LinkLockables(context *StateContext, princ Node, lockable Lockable, requirements []Lockable) error { - if lockable == nil { +func LinkLockables(context *StateContext, princ Node, lockable_node LockableNode, requirements []LockableNode) error { + if lockable_node == nil { return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements") } + lockable := lockable_node.LockableHandle() if len(requirements) == 0 { return nil @@ -323,8 +256,10 @@ func LinkLockables(context *StateContext, princ Node, lockable Lockable, require ), func(context *StateContext) error { // Check that all the requirements can be added // If the lockable is already locked, need to lock this resource as well before we can add it - for _, requirement := range(requirements) { - for _, req := range(requirements) { + for _, requirement_node := range(requirements) { + requirement := requirement_node.LockableHandle() + for _, req_node := range(requirements) { + req := req_node.LockableHandle() if req.ID() == requirement.ID() { continue } @@ -339,22 +274,23 @@ func LinkLockables(context *StateContext, princ Node, lockable Lockable, require if checkIfRequirement(context, requirement, lockable) == true { return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID()) } - if lockable.Owner() == nil { + if lockable.Owner == nil { // If the new owner isn't locked, we can add the requirement - } else if requirement.Owner() == nil { + } else if requirement.Owner == nil { // if the new requirement isn't already locked but the owner is, the requirement needs to be locked first return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID(), requirement.ID()) } else { // If the new requirement is already locked and the owner is already locked, their owners need to match - if requirement.Owner().ID() != lockable.Owner().ID() { + if requirement.Owner.ID() != lockable.Owner.ID() { return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID(), lockable.ID()) } } } // Update the states of the requirements - for _, requirement := range(requirements) { - requirement.AddDependency(lockable) - lockable.AddRequirement(requirement) + for _, requirement_node := range(requirements) { + requirement := requirement_node.LockableHandle() + requirement.Dependencies[lockable.ID()] = lockable_node + lockable.Requirements[lockable.ID()] = requirement_node context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID(), lockable.ID()) } @@ -364,8 +300,8 @@ func LinkLockables(context *StateContext, princ Node, lockable Lockable, require } // Must be called withing update context -func checkIfRequirement(context *StateContext, r Lockable, cur Lockable) bool { - for _, c := range(cur.Requirements()) { +func checkIfRequirement(context *StateContext, r LockableNode, cur LockableNode) bool { + for _, c := range(cur.LockableHandle().Requirements) { if c.ID() == r.ID() { return true } @@ -385,7 +321,7 @@ func checkIfRequirement(context *StateContext, r Lockable, cur Lockable) bool { // Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error // Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call -func LockLockables(context *StateContext, to_lock []Lockable, new_owner Lockable) error { +func LockLockables(context *StateContext, to_lock map[NodeID]LockableNode, new_owner_node LockableNode) error { if to_lock == nil { return fmt.Errorf("LOCKABLE_LOCK_ERR: no list provided") } @@ -396,44 +332,41 @@ func LockLockables(context *StateContext, to_lock []Lockable, new_owner Lockable } } - if new_owner == nil { + if new_owner_node == nil { return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks") } + new_owner := new_owner_node.LockableHandle() + // Called with no requirements to lock, success if len(to_lock) == 0 { return nil } return UpdateStates(context, new_owner, NewLockMap( - LockList(to_lock, []string{"lock"}), + LockListM(to_lock, []string{"lock"}), NewLockInfo(new_owner, nil), ), func(context *StateContext) error { // First loop is to check that the states can be locked, and locks all requirements - for _, req := range(to_lock) { + for _, req_node := range(to_lock) { + req := req_node.LockableHandle() context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID()) - // Check custom lock conditions - err := req.CanLock(new_owner) - if err != nil { - return err - } - // If req is alreay locked, check that we can pass the lock - if req.Owner() != nil { - owner := req.Owner() + if req.Owner != nil { + owner := req.Owner if owner.ID() == new_owner.ID() { continue } else { err := UpdateStates(context, new_owner, NewLockInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){ - return LockLockables(context, req.Requirements(), req) + return LockLockables(context, req.Requirements, req) }) if err != nil { return err } } } else { - err := LockLockables(context, req.Requirements(), req) + err := LockLockables(context, req.Requirements, req) if err != nil { return err } @@ -441,19 +374,20 @@ func LockLockables(context *StateContext, to_lock []Lockable, new_owner Lockable } // At this point state modification will be started, so no errors can be returned - for _, req := range(to_lock) { - old_owner := req.Owner() + for _, req_node := range(to_lock) { + req := req_node.LockableHandle() + old_owner := req.Owner // If the lockable was previously unowned, update the state if old_owner == nil { context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID(), req.ID()) - req.SetOwner(new_owner) + req.Owner = new_owner_node new_owner.RecordLock(req, old_owner) // Otherwise if the new owner already owns it, no need to update state } else if old_owner.ID() == new_owner.ID() { context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID(), req.ID()) // Otherwise update the state } else { - req.SetOwner(new_owner) + req.Owner = new_owner new_owner.RecordLock(req, old_owner) context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID(), req.ID(), old_owner.ID()) } @@ -463,7 +397,7 @@ func LockLockables(context *StateContext, to_lock []Lockable, new_owner Lockable } -func UnlockLockables(context *StateContext, to_unlock []Lockable, old_owner Lockable) error { +func UnlockLockables(context *StateContext, to_unlock map[NodeID]LockableNode, old_owner_node LockableNode) error { if to_unlock == nil { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided") } @@ -474,48 +408,46 @@ func UnlockLockables(context *StateContext, to_unlock []Lockable, old_owner Lock } } - if old_owner == nil { + if old_owner_node == nil { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks") } + old_owner := old_owner_node.LockableHandle() + // Called with no requirements to unlock, success if len(to_unlock) == 0 { return nil } return UpdateStates(context, old_owner, NewLockMap( - LockList(to_unlock, []string{"lock"}), + LockListM(to_unlock, []string{"lock"}), NewLockInfo(old_owner, nil), ), func(context *StateContext) error { // First loop is to check that the states can be locked, and locks all requirements - for _, req := range(to_unlock) { + for _, req_node := range(to_unlock) { + req := req_node.LockableHandle() context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID()) // Check if the owner is correct - if req.Owner() != nil { - if req.Owner().ID() != old_owner.ID() { + if req.Owner != nil { + if req.Owner.ID() != old_owner.ID() { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID(), old_owner.ID()) } } else { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID()) } - // Check custom unlock conditions - err := req.CanUnlock(old_owner) - if err != nil { - return err - } - - err = UnlockLockables(context, req.Requirements(), req) + err := UnlockLockables(context, req.Requirements, req) if err != nil { return err } } // At this point state modification will be started, so no errors can be returned - for _, req := range(to_unlock) { + for _, req_node := range(to_unlock) { + req := req_node.LockableHandle() new_owner := old_owner.RecordUnlock(req) - req.SetOwner(new_owner) + req.Owner = new_owner if new_owner == nil { context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID(), req.ID()) } else { @@ -527,18 +459,18 @@ func UnlockLockables(context *StateContext, to_unlock []Lockable, old_owner Lock }) } -// Load function for SimpleLockable -func LoadSimpleLockable(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { - var j SimpleLockableJSON +// Load function for Lockable +func LoadLockable(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { + var j LockableJSON err := json.Unmarshal(data, &j) if err != nil { return nil, err } - lockable := NewSimpleLockable(id, j.Name) + lockable := NewLockable(id, j.Name) nodes[id] = &lockable - err = RestoreSimpleLockable(ctx, &lockable, j, nodes) + err = RestoreLockable(ctx, &lockable, j, nodes) if err != nil { return nil, err } @@ -546,19 +478,19 @@ func LoadSimpleLockable(ctx *Context, id NodeID, data []byte, nodes NodeMap) (No return &lockable, nil } -func NewSimpleLockable(id NodeID, name string) SimpleLockable { - return SimpleLockable{ - GraphNode: NewGraphNode(id), - name: name, - owner: nil, - requirements: []Lockable{}, - dependencies: []Lockable{}, - locks_held: map[NodeID]Lockable{}, +func NewLockable(id NodeID, name string) Lockable { + return Lockable{ + SimpleNode: NewSimpleNode(id), + Name: name, + Owner: nil, + Requirements: map[NodeID]LockableNode{}, + Dependencies: map[NodeID]LockableNode{}, + LocksHeld: map[NodeID]LockableNode{}, } } -// Helper function to load links when loading a struct that embeds SimpleLockable -func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSON, nodes NodeMap) error { +// Helper function to load links when loading a struct that embeds Lockable +func RestoreLockable(ctx * Context, lockable *Lockable, j LockableJSON, nodes NodeMap) error { if j.Owner != "" { owner_id, err := ParseID(j.Owner) if err != nil { @@ -568,11 +500,11 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO if err != nil { return err } - owner, ok := owner_node.(Lockable) + owner, ok := owner_node.(LockableNode) if ok == false { return fmt.Errorf("%s is not a Lockable", j.Owner) } - lockable.SetOwner(owner) + lockable.Owner = owner } for _, dep_str := range(j.Dependencies) { @@ -584,11 +516,11 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO if err != nil { return err } - dep, ok := dep_node.(Lockable) + dep, ok := dep_node.(LockableNode) if ok == false { return fmt.Errorf("%+v is not a Lockable as expected", dep_node) } - lockable.AddDependency(dep) + lockable.Dependencies[dep_id] = dep } for _, req_str := range(j.Requirements) { @@ -600,11 +532,11 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO if err != nil { return err } - req, ok := req_node.(Lockable) + req, ok := req_node.(LockableNode) if ok == false { return fmt.Errorf("%+v is not a Lockable as expected", req_node) } - lockable.AddRequirement(req) + lockable.Requirements[req_id] = req } for l_id_str, h_str := range(j.LocksHeld) { @@ -613,12 +545,12 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO if err != nil { return err } - l_l, ok := l.(Lockable) + l_l, ok := l.(LockableNode) if ok == false { return fmt.Errorf("%s is not a Lockable", l.ID()) } - var h_l Lockable = nil + var h_l LockableNode if h_str != "" { h_id, err := ParseID(h_str) if err != nil { @@ -628,7 +560,7 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO if err != nil { return err } - h, ok := h_node.(Lockable) + h, ok := h_node.(LockableNode) if ok == false { return err } @@ -637,5 +569,5 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO lockable.RecordLock(l_l, h_l) } - return RestoreGraphNode(ctx, lockable, j.GraphNodeJSON, nodes) + return RestoreSimpleNode(ctx, &lockable.SimpleNode, j.SimpleNodeJSON, nodes) } diff --git a/node.go b/node.go index 10b8496..3e5d925 100644 --- a/node.go +++ b/node.go @@ -44,7 +44,7 @@ func KeyID(pub *ecdsa.PublicKey) NodeID { // Types are how nodes are associated with structs at runtime(and from the DB) type NodeType string func (node_type NodeType) Hash() uint64 { - hash := sha512.Sum512([]byte(node_type)) + hash := sha512.Sum512([]byte(fmt.Sprintf("NODE: %s", node_type))) return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)]) } @@ -54,120 +54,87 @@ func RandID() NodeID { return NodeID(uuid.New()) } -// A Node represents data that can be read by multiple goroutines and written to by one, with a unique ID attached, and a method to process updates(including propagating them to connected nodes) -// RegisterChannel and UnregisterChannel are used to connect arbitrary listeners to the node type Node interface { - // State Locking interface - sync.Locker - RLock() - RUnlock() - - // Serialize the Node for the database - Serialize() ([]byte, error) - - // Nodes have an ID, type, and ACL policies ID() NodeID Type() NodeType - - Policies() map[NodeID]Policy - AddPolicy(Policy) error - RemovePolicy(Policy) error - - // Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate - Process(context *StateContext, princ Node, signal GraphSignal) error - // Register a channel to receive updates sent to the node - RegisterChannel(id NodeID, listener chan GraphSignal) - // Unregister a channel from receiving updates sent to the node - UnregisterChannel(id NodeID) + Serialize() ([]byte, error) + LockState(write bool) + UnlockState(write bool) + Process(context *StateContext, signal GraphSignal) error + Policies() []Policy } -// A GraphNode is an implementation of a Node that can be embedded into more complex structures -type GraphNode struct { - sync.RWMutex - listeners_lock sync.Mutex - +type SimpleNode struct { id NodeID - listeners map[NodeID]chan GraphSignal + state_mutex sync.RWMutex policies map[NodeID]Policy } -type GraphNodeJSON struct { - Policies []string `json:"policies"` +func NewSimpleNode(id NodeID) SimpleNode { + return SimpleNode{ + id: id, + policies: map[NodeID]Policy{}, + } } -func (node * GraphNode) Policies() map[NodeID]Policy { - return node.policies +type SimpleNodeJSON struct { + Policies []string `json:"policies"` } -func (node * GraphNode) Serialize() ([]byte, error) { - node_json := NewGraphNodeJSON(node) - return json.MarshalIndent(&node_json, "", " ") +func (node *SimpleNode) Process(context *StateContext, signal GraphSignal) error { + context.Graph.Log.Logf("signal", "SIMPLE_NODE_SIGNAL: %s - %+v", node.id, signal) + return nil } -func Allowed(context *StateContext, policies map[NodeID]Policy, node Node, resource string, action string, princ Node) error { - if princ == nil { - context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) - return fmt.Errorf("nil is not allowed to perform any actions") - } - if node.ID() == princ.ID() { - return nil - } - for _, policy := range(policies) { - if policy.Allows(node, resource, action, princ) == true { - context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) - return nil - } - } - context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) - return fmt.Errorf("%s is not allowed to perform %s.%s on %s", princ.ID(), resource, action, node.ID()) +func (node *SimpleNode) ID() NodeID { + return node.id } -func (node *GraphNode) AddPolicy(policy Policy) error { - if policy == nil { - return fmt.Errorf("Cannot add nil as a policy") - } - - _, exists := node.policies[policy.ID()] - if exists == true { - return fmt.Errorf("%s is already a policy for %s", policy.ID().String(), node.ID().String()) - } +func (node *SimpleNode) Type() NodeType { + return NodeType("simple_node") +} - node.policies[policy.ID()] = policy - return nil +func (node *SimpleNode) Serialize() ([]byte, error) { + j := NewSimpleNodeJSON(node) + return json.MarshalIndent(&j, "", " ") } -func (node *GraphNode) RemovePolicy(policy Policy) error { - if policy == nil { - return fmt.Errorf("Cannot add nil as a policy") +func (node *SimpleNode) LockState(write bool) { + if write == true { + node.state_mutex.Lock() + } else { + node.state_mutex.RLock() } +} - _, exists := node.policies[policy.ID()] - if exists == false { - return fmt.Errorf("%s is not a policy for %s", policy.ID().String(), node.ID().String()) +func (node *SimpleNode) UnlockState(write bool) { + if write == true { + node.state_mutex.Unlock() + } else { + node.state_mutex.RUnlock() } - - delete(node.policies, policy.ID()) - return nil } -func NewGraphNodeJSON(node *GraphNode) GraphNodeJSON { - policies := make([]string, len(node.policies)) +func NewSimpleNodeJSON(node *SimpleNode) SimpleNodeJSON { + policy_ids := make([]string, len(node.policies)) i := 0 - for _, policy := range(node.policies) { - policies[i] = policy.ID().String() + for id, _ := range(node.policies) { + policy_ids[i] = id.String() i += 1 } - return GraphNodeJSON{ - Policies: policies, + + return SimpleNodeJSON{ + Policies: policy_ids, } } -func RestoreGraphNode(ctx *Context, node Node, j GraphNodeJSON, nodes NodeMap) error { +func RestoreSimpleNode(ctx *Context, node *SimpleNode, j SimpleNodeJSON, nodes NodeMap) error { for _, policy_str := range(j.Policies) { policy_id, err := ParseID(policy_str) if err != nil { return err } + policy_ptr, err := LoadNodeRecurse(ctx, policy_id, nodes) if err != nil { return err @@ -177,27 +144,60 @@ func RestoreGraphNode(ctx *Context, node Node, j GraphNodeJSON, nodes NodeMap) e if ok == false { return fmt.Errorf("%s is not a Policy", policy_id) } - node.AddPolicy(policy) + node.policies[policy_id] = policy } + return nil } -func LoadGraphNode(ctx * Context, id NodeID, data []byte, nodes NodeMap)(Node, error) { - if len(data) > 0 { - return nil, fmt.Errorf("Attempted to load a graph_node with data %+v, should have been 0 length", string(data)) +func LoadSimpleNode(ctx *Context, id NodeID, data []byte, nodes NodeMap)(Node, error) { + var j SimpleNodeJSON + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } + + node := NewSimpleNode(id) + nodes[id] = &node + + err = RestoreSimpleNode(ctx, &node, j, nodes) + if err != nil { + return nil, err } - node := NewGraphNode(id) + return &node, nil } -func (node * GraphNode) ID() NodeID { - return node.id +func (node *SimpleNode) Policies() []Policy { + ret := make([]Policy, len(node.policies)) + i := 0 + for _, policy := range(node.policies) { + ret[i] = policy + i += 1 + } + + return ret } -func (node * GraphNode) Type() NodeType { - return NodeType("graph_node") +func Allowed(context *StateContext, policies []Policy, node Node, resource string, action string, princ Node) error { + if princ == nil { + context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) + return fmt.Errorf("nil is not allowed to perform any actions") + } + if node.ID() == princ.ID() { + return nil + } + for _, policy := range(policies) { + if policy.Allows(node, resource, action, princ) == true { + context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) + return nil + } + } + context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) + return fmt.Errorf("%s is not allowed to perform %s.%s on %s", princ.ID(), resource, action, node.ID()) } + // Propagate the signal to registered listeners, if a listener isn't ready to receive the update // send it a notification that it was closed and then close it func Signal(context *StateContext, node Node, princ Node, signal GraphSignal) error { @@ -211,75 +211,19 @@ func Signal(context *StateContext, node Node, princ Node, signal GraphSignal) er return nil } - return node.Process(context, princ, signal) -} - -func (node * GraphNode) Process(context *StateContext, princ Node, signal GraphSignal) error { - node.listeners_lock.Lock() - defer node.listeners_lock.Unlock() - closed := []NodeID{} - - for id, listener := range node.listeners { - context.Graph.Log.Logf("signal", "UPDATE_LISTENER %s: %s", node.ID(), id) - select { - case listener <- signal: - default: - context.Graph.Log.Logf("signal", "CLOSED_LISTENER %s: %s", node.ID(), id) - go func(node Node, listener chan GraphSignal) { - listener <- NewDirectSignal("listener_closed") - close(listener) - }(node, listener) - closed = append(closed, id) - } - } - - for _, id := range(closed) { - delete(node.listeners, id) - } - return nil + return node.Process(context, signal) } -func (node * GraphNode) RegisterChannel(id NodeID, listener chan GraphSignal) { - node.listeners_lock.Lock() - _, exists := node.listeners[id] - if exists == false { - node.listeners[id] = listener - } - node.listeners_lock.Unlock() -} - -func (node * GraphNode) UnregisterChannel(id NodeID) { - node.listeners_lock.Lock() - _, exists := node.listeners[id] - if exists == false { - panic("Attempting to unregister non-registered listener") - } else { - delete(node.listeners, id) - } - node.listeners_lock.Unlock() -} - -func AttachPolicies(ctx *Context, node Node, policies ...Policy) error { +func AttachPolicies(ctx *Context, node *SimpleNode, policies ...Policy) error { context := NewWriteContext(ctx) return UpdateStates(context, node, NewLockInfo(node, []string{"policies"}), func(context *StateContext) error { for _, policy := range(policies) { - err := node.AddPolicy(policy) - if err != nil { - return err - } + node.policies[policy.ID()] = policy } return nil }) } -func NewGraphNode(id NodeID) GraphNode { - return GraphNode{ - id: id, - listeners: map[NodeID]chan GraphSignal{}, - policies: map[NodeID]Policy{}, - } -} - // Magic first four bytes of serialized DB content, stored big endian const NODE_DB_MAGIC = 0x2491df14 // Total length of the node database header, has magic to verify and type_hash to map to load function @@ -458,6 +402,17 @@ func NewLockMap(requests ...LockMap) LockMap { return reqs } +func LockListM[K Node](m map[NodeID]K, resources[]string) LockMap { + reqs := LockMap{} + for _, node := range(m) { + reqs[node.ID()] = LockInfo{ + Node: node, + Resources: resources, + } + } + return reqs +} + func LockList[K Node](list []K, resources []string) LockMap { reqs := LockMap{} for _, node := range(list) { @@ -565,7 +520,7 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St if princ_locked == false { new_locks = append(new_locks, princ) context.Graph.Log.Logf("mutex", "RLOCKING_PRINC %s", princ.ID().String()) - princ.RLock() + princ.LockState(false) } princ_permissions, princ_exists := context.Permissions[princ.ID()] @@ -588,7 +543,7 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St if locked == false { new_locks = append(new_locks, node) context.Graph.Log.Logf("mutex", "RLOCKING %s", id.String()) - node.RLock() + node.LockState(false) } } @@ -610,7 +565,7 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St if err != nil { for _, n := range(new_locks) { context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String()) - n.RUnlock() + n.UnlockState(false) } return err } @@ -632,7 +587,7 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St for _, node := range(new_locks) { context.Graph.Log.Logf("mutex", "RUNLOCKING %s", node.ID().String()) delete(context.Locked, node.ID()) - node.RUnlock() + node.UnlockState(false) } return err @@ -661,7 +616,7 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn if princ_locked == false { new_locks = append(new_locks, princ) context.Graph.Log.Logf("mutex", "LOCKING_PRINC %s", princ.ID().String()) - princ.Lock() + princ.LockState(true) } princ_permissions, princ_exists := context.Permissions[princ.ID()] @@ -684,7 +639,7 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn if locked == false { new_locks = append(new_locks, node) context.Graph.Log.Logf("mutex", "LOCKING %s", id.String()) - node.Lock() + node.LockState(true) } } @@ -706,7 +661,7 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn if err != nil { for _, n := range(new_locks) { context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String()) - n.Unlock() + n.UnlockState(true) } return err } @@ -730,19 +685,10 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn } for id, node := range(context.Locked) { context.Graph.Log.Logf("mutex", "UNLOCKING %s", id.String()) - node.Unlock() + node.UnlockState(true) } } return err } -// Create a new channel with a buffer the size of buffer, and register it to node with the id -func UpdateChannel(node Node, buffer int, id NodeID) chan GraphSignal { - if node == nil { - panic("Cannot get an update channel to nil") - } - new_listener := make(chan GraphSignal, buffer) - node.RegisterChannel(id, new_listener) - return new_listener -} diff --git a/policy.go b/policy.go index ba33020..904b91e 100644 --- a/policy.go +++ b/policy.go @@ -44,12 +44,12 @@ func NewNodeActions(resource_actions NodeActions, wildcard_actions []string) Nod } type PerNodePolicy struct { - GraphNode + SimpleNode Actions map[NodeID]NodeActions } type PerNodePolicyJSON struct { - GraphNodeJSON + SimpleNodeJSON Actions map[string]map[string][]string `json:"actions"` } @@ -64,7 +64,7 @@ func (policy *PerNodePolicy) Serialize() ([]byte, error) { } return json.MarshalIndent(&PerNodePolicyJSON{ - GraphNodeJSON: NewGraphNodeJSON(&policy.GraphNode), + SimpleNodeJSON: NewSimpleNodeJSON(&policy.SimpleNode), Actions: actions, }, "", " ") } @@ -75,7 +75,7 @@ func NewPerNodePolicy(id NodeID, actions map[NodeID]NodeActions) PerNodePolicy { } return PerNodePolicy{ - GraphNode: NewGraphNode(id), + SimpleNode: NewSimpleNode(id), Actions: actions, } } @@ -100,7 +100,7 @@ func LoadPerNodePolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Nod policy := NewPerNodePolicy(id, actions) nodes[id] = &policy - err = RestoreGraphNode(ctx, &policy.GraphNode, j.GraphNodeJSON, nodes) + err = RestoreSimpleNode(ctx, &policy.SimpleNode, j.SimpleNodeJSON, nodes) if err != nil { return nil, err } @@ -122,12 +122,12 @@ func (policy *PerNodePolicy) Allows(node Node, resource string, action string, p } type SimplePolicy struct { - GraphNode + SimpleNode Actions NodeActions } type SimplePolicyJSON struct { - GraphNodeJSON + SimpleNodeJSON Actions map[string][]string `json:"actions"` } @@ -137,7 +137,7 @@ func (policy *SimplePolicy) Type() NodeType { func (policy *SimplePolicy) Serialize() ([]byte, error) { return json.MarshalIndent(&SimplePolicyJSON{ - GraphNodeJSON: NewGraphNodeJSON(&policy.GraphNode), + SimpleNodeJSON: NewSimpleNodeJSON(&policy.SimpleNode), Actions: policy.Actions, }, "", " ") } @@ -148,7 +148,7 @@ func NewSimplePolicy(id NodeID, actions NodeActions) SimplePolicy { } return SimplePolicy{ - GraphNode: NewGraphNode(id), + SimpleNode: NewSimpleNode(id), Actions: actions, } } @@ -163,7 +163,7 @@ func LoadSimplePolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node policy := NewSimplePolicy(id, j.Actions) nodes[id] = &policy - err = RestoreGraphNode(ctx, &policy.GraphNode, j.GraphNodeJSON, nodes) + err = RestoreSimpleNode(ctx, &policy.SimpleNode, j.SimpleNodeJSON, nodes) if err != nil { return nil, err } @@ -176,12 +176,12 @@ func (policy *SimplePolicy) Allows(node Node, resource string, action string, pr } type PerTagPolicy struct { - GraphNode + SimpleNode Actions map[string]NodeActions } type PerTagPolicyJSON struct { - GraphNodeJSON + SimpleNodeJSON Actions map[string]map[string][]string `json:"json"` } @@ -196,7 +196,7 @@ func (policy *PerTagPolicy) Serialize() ([]byte, error) { } return json.MarshalIndent(&PerTagPolicyJSON{ - GraphNodeJSON: NewGraphNodeJSON(&policy.GraphNode), + SimpleNodeJSON: NewSimpleNodeJSON(&policy.SimpleNode), Actions: actions, }, "", " ") } @@ -207,7 +207,7 @@ func NewPerTagPolicy(id NodeID, actions map[string]NodeActions) PerTagPolicy { } return PerTagPolicy{ - GraphNode: NewGraphNode(id), + SimpleNode: NewSimpleNode(id), Actions: actions, } } @@ -227,7 +227,7 @@ func LoadPerTagPolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node policy := NewPerTagPolicy(id, actions) nodes[id] = &policy - err = RestoreGraphNode(ctx, &policy.GraphNode, j.GraphNodeJSON, nodes) + err = RestoreSimpleNode(ctx, &policy.SimpleNode, j.SimpleNodeJSON, nodes) if err != nil { return nil, err } @@ -268,12 +268,12 @@ func NewDependencyPolicy(id NodeID, actions NodeActions) DependencyPolicy { } func (policy *DependencyPolicy) Allows(node Node, resource string, action string, principal Node) bool { - lockable, ok := node.(Lockable) + lockable, ok := node.(LockableNode) if ok == false { return false } - for _, dep := range(lockable.Dependencies()) { + for _, dep := range(lockable.LockableHandle().Dependencies) { if dep.ID() == principal.ID() { return policy.Actions.Allows(resource, action) } diff --git a/thread.go b/thread.go index 71ce97d..d0b49c6 100644 --- a/thread.go +++ b/thread.go @@ -5,13 +5,12 @@ import ( "time" "sync" "errors" - "reflect" "encoding/json" ) // Assumed that thread is already locked for signal -func (thread *SimpleThread) Process(context *StateContext, princ Node, signal GraphSignal) error { - err := thread.SimpleLockable.Process(context, princ, signal) +func (thread *Thread) Process(context *StateContext, signal GraphSignal) error { + err := thread.Lockable.Process(context, signal) if err != nil { return err } @@ -19,16 +18,16 @@ func (thread *SimpleThread) Process(context *StateContext, princ Node, signal Gr switch signal.Direction() { case Up: err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error { - if thread.parent != nil { - return Signal(context, thread.parent, thread, signal) + if thread.Parent != nil { + return Signal(context, thread.Parent, thread, signal) } else { return nil } }) case Down: err = UseStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error { - for _, child := range(thread.children) { - err := Signal(context, child, thread, signal) + for _, info := range(thread.Children) { + err := Signal(context, info.Child, thread, signal) if err != nil { return err } @@ -44,136 +43,35 @@ func (thread *SimpleThread) Process(context *StateContext, princ Node, signal Gr return err } - thread.signal <- signal + thread.Chan <- signal return nil } -// Interface to represent any type of thread information -type ThreadInfo interface { -} - -func (thread * SimpleThread) SetTimeout(timeout time.Time, action string) { - thread.timeout = timeout - thread.timeout_action = action - thread.timeout_chan = time.After(time.Until(timeout)) -} - -func (thread * SimpleThread) TimeoutAction() string { - return thread.timeout_action -} - -func (thread * SimpleThread) State() string { - return thread.state_name -} - -func (thread * SimpleThread) SetState(new_state string) error { - if new_state == "" { - return fmt.Errorf("Cannot set state to '' with SetState") - } - - thread.state_name = new_state - return nil -} - -func (thread * SimpleThread) Parent() Thread { - return thread.parent -} - -func (thread * SimpleThread) SetParent(parent Thread) { - thread.parent = parent -} - -func (thread * SimpleThread) Children() []Thread { - return thread.children -} - -func (thread * SimpleThread) Child(id NodeID) Thread { - for _, child := range(thread.children) { - if child.ID() == id { - return child - } - } - return nil -} - - - -func (thread * SimpleThread) ChildInfo(child NodeID) ThreadInfo { - return thread.child_info[child] -} - // Requires thread and childs thread to be locked for write -func UnlinkThreads(ctx * Context, thread Thread, child Thread) error { - var found Node = nil - for _, c := range(thread.Children()) { - if child.ID() == c.ID() { - found = c - break - } - } - - if found == nil { +func UnlinkThreads(ctx * Context, node ThreadNode, child_node ThreadNode) error { + thread := node.ThreadHandle() + child := child_node.ThreadHandle() + _, is_child := thread.Children[child_node.ID()] + if is_child == false { return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID(), thread.ID()) } - child.SetParent(nil) - thread.RemoveChild(child) + child.Parent = nil + delete(thread.Children, child.ID()) return nil } -func (thread * SimpleThread) RemoveChild(child Thread) { - idx := -1 - for i, c := range(thread.children) { - if c.ID() == child.ID() { - idx = i - break - } - } - - if idx == -1 { - panic(fmt.Sprintf("%s is not a child of %s", child.ID(), thread.Name())) - } - - child_len := len(thread.children) - thread.children[idx] = thread.children[child_len-1] - thread.children = thread.children[0:child_len-1] -} - -func (thread * SimpleThread) AddChild(child Thread, info ThreadInfo) error { - if child == nil { - return fmt.Errorf("Will not connect nil to the thread tree") - } - - _, exists := thread.child_info[child.ID()] - if exists == true { - return fmt.Errorf("Will not connect the same child twice") - } - - if info == nil && thread.InfoType != nil { - return fmt.Errorf("nil info passed when expecting info") - } else if info != nil { - if reflect.TypeOf(info) != thread.InfoType { - return fmt.Errorf("info type mismatch, expecting %+v - %+v", thread.InfoType, reflect.TypeOf(info)) - } - } - - thread.children = append(thread.children, child) - thread.child_info[child.ID()] = info - - return nil -} - -func checkIfChild(context *StateContext, target Thread, cur Thread) bool { - for _, child := range(cur.Children()) { - if child.ID() == target.ID() { +func checkIfChild(context *StateContext, target ThreadNode, cur ThreadNode) bool { + for _, info := range(cur.ThreadHandle().Children) { + if info.Child.ID() == target.ID() { return true } is_child := false UpdateStates(context, cur, NewLockMap( - NewLockInfo(child, []string{"children"}), + NewLockInfo(info.Child, []string{"children"}), ), func(context *StateContext) error { - is_child = checkIfChild(context, target, child) + is_child = checkIfChild(context, target, info.Child) return nil }) if is_child { @@ -186,10 +84,12 @@ func checkIfChild(context *StateContext, target Thread, cur Thread) bool { // Links child to parent with info as the associated info // Continues the write context with princ, getting children for thread and parent for child -func LinkThreads(context *StateContext, princ Node, thread Thread, child Thread, info ThreadInfo) error { - if context == nil || thread == nil || child == nil { +func LinkThreads(context *StateContext, princ Node, thread_node ThreadNode, info ChildInfo) error { + if context == nil || thread_node == nil || info.Child == nil { return fmt.Errorf("invalid input") } + thread := thread_node.ThreadHandle() + child := info.Child.ThreadHandle() if thread.ID() == child.ID() { return fmt.Errorf("Will not link %s as a child of itself", thread.ID()) @@ -199,7 +99,7 @@ func LinkThreads(context *StateContext, princ Node, thread Thread, child Thread, child.ID(): LockInfo{Node: child, Resources: []string{"parent"}}, thread.ID(): LockInfo{Node: thread, Resources: []string{"children"}}, }, func(context *StateContext) error { - if child.Parent() != nil { + if child.Parent != nil { return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID()) } @@ -211,60 +111,23 @@ func LinkThreads(context *StateContext, princ Node, thread Thread, child Thread, return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID()) } - err := thread.AddChild(child, info) - if err != nil { - return fmt.Errorf("EVENT_LINK_ERR: error adding %s as child to %s: %e", child.ID(), thread.ID(), err) - } - child.SetParent(thread) + // TODO check for info types - if err != nil { - return err - } + thread.Children[child.ID()] = info + child.Parent = thread_node return nil }) } -type ThreadAction func(* Context, Thread)(string, error) +type ThreadAction func(*Context, ThreadNode)(string, error) type ThreadActions map[string]ThreadAction -type ThreadHandler func(* Context, Thread, GraphSignal)(string, error) +type ThreadHandler func(*Context, ThreadNode, GraphSignal)(string, error) type ThreadHandlers map[string]ThreadHandler -type Thread interface { - // All Threads are Lockables - Lockable - /// State Modification Functions - SetParent(parent Thread) - AddChild(child Thread, info ThreadInfo) error - RemoveChild(child Thread) - SetState(new_thread string) error - SetTimeout(end_time time.Time, action string) - /// State Reading Functions - Parent() Thread - Children() []Thread - Child(id NodeID) Thread - ChildInfo(child NodeID) ThreadInfo - State() string - TimeoutAction() string - - /// Functions that dont read/write thread - // Deserialize the attribute map from json.Unmarshal - DeserializeInfo(ctx *Context, data []byte) (ThreadInfo, error) - SetActive(active bool) error - Action(action string) (ThreadAction, bool) - Handler(signal_type string) (ThreadHandler, bool) - - // Internal timeout channel for thread - Timeout() <-chan time.Time - // Internal signal channel for thread - SignalChannel() <-chan GraphSignal - ClearTimeout() - - ChildWaits() *sync.WaitGroup -} - -type ParentInfo interface { - Parent() *ParentThreadInfo +type InfoType string +func (t InfoType) String() string { + return string(t) } // Data required by a parent thread to restore it's children @@ -274,10 +137,6 @@ type ParentThreadInfo struct { RestoreAction string `json:"restore_action"` } -func (info * ParentThreadInfo) Parent() *ParentThreadInfo{ - return info -} - func NewParentThreadInfo(start bool, start_action string, restore_action string) ParentThreadInfo { return ParentThreadInfo{ Start: start, @@ -286,83 +145,123 @@ func NewParentThreadInfo(start bool, start_action string, restore_action string) } } -type SimpleThread struct { - SimpleLockable +type ChildInfo struct { + Child ThreadNode + Infos map[InfoType]interface{} +} - actions ThreadActions - handlers ThreadHandlers +func NewChildInfo(child ThreadNode, infos map[InfoType]interface{}) ChildInfo { + if infos == nil { + infos = map[InfoType]interface{}{} + } - timeout_chan <-chan time.Time - signal chan GraphSignal - child_waits *sync.WaitGroup - active bool - active_lock *sync.Mutex + return ChildInfo{ + Child: child, + Infos: infos, + } +} + +type QueuedAction struct { + Timeout time.Time + Action string +} + +type ThreadNode interface { + LockableNode + ThreadHandle() *Thread +} + +type Thread struct { + Lockable + + Actions ThreadActions + Handlers ThreadHandlers + + TimeoutChan <-chan time.Time + Chan chan GraphSignal + ChildWaits sync.WaitGroup + Active bool + ActiveLock sync.Mutex + + StateName string + Parent ThreadNode + Children map[NodeID]ChildInfo + InfoTypes []InfoType + TimeoutAction string + Timeout time.Time - state_name string - parent Thread - children []Thread - child_info map[NodeID] ThreadInfo - InfoType reflect.Type - timeout time.Time - timeout_action string } -func (thread * SimpleThread) Type() NodeType { +func (thread *Thread) ThreadHandle() *Thread { + return thread +} + +func (thread *Thread) Type() NodeType { return NodeType("simple_thread") } -func (thread * SimpleThread) Serialize() ([]byte, error) { - thread_json := NewSimpleThreadJSON(thread) +func (thread *Thread) Serialize() ([]byte, error) { + thread_json := NewThreadJSON(thread) return json.MarshalIndent(&thread_json, "", " ") } -func (thread * SimpleThread) SignalChannel() <-chan GraphSignal { - return thread.signal +func (thread *Thread) ChildList() []ThreadNode { + ret := make([]ThreadNode, len(thread.Children)) + i := 0 + for _, info := range(thread.Children) { + ret[i] = info.Child + i += 1 + } + return ret } -type SimpleThreadJSON struct { +type ThreadJSON struct { Parent string `json:"parent"` - Children map[string]interface{} `json:"children"` + Children map[string]map[string]interface{} `json:"children"` Timeout time.Time `json:"timeout"` TimeoutAction string `json:"timeout_action"` StateName string `json:"state_name"` - SimpleLockableJSON + LockableJSON } -func NewSimpleThreadJSON(thread *SimpleThread) SimpleThreadJSON { - children := map[string]interface{}{} - for _, child := range(thread.children) { - children[child.ID().String()] = thread.child_info[child.ID()] +func NewThreadJSON(thread *Thread) ThreadJSON { + children := map[string]map[string]interface{}{} + for id, info := range(thread.Children) { + tmp := map[string]interface{}{} + for name, i := range(info.Infos) { + tmp[name.String()] = i + } + children[id.String()] = tmp } parent_id := "" - if thread.parent != nil { - parent_id = thread.parent.ID().String() + if thread.Parent != nil { + parent_id = thread.Parent.ID().String() } - lockable_json := NewSimpleLockableJSON(&thread.SimpleLockable) + lockable_json := NewLockableJSON(&thread.Lockable) - return SimpleThreadJSON{ + return ThreadJSON{ Parent: parent_id, Children: children, - Timeout: thread.timeout, - TimeoutAction: thread.timeout_action, - StateName: thread.state_name, - SimpleLockableJSON: lockable_json, + Timeout: thread.Timeout, + TimeoutAction: thread.TimeoutAction, + StateName: thread.StateName, + LockableJSON: lockable_json, } } -func LoadSimpleThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { - var j SimpleThreadJSON +func LoadThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { + var j ThreadJSON err := json.Unmarshal(data, &j) if err != nil { return nil, err } - thread := NewSimpleThread(id, j.Name, j.StateName, nil, BaseThreadActions, BaseThreadHandlers) + thread := NewThread(id, j.Name, j.StateName, nil, BaseThreadActions, BaseThreadHandlers) nodes[id] = &thread - err = RestoreSimpleThread(ctx, &thread, j, nodes) + err = RestoreThread(ctx, &thread, j, nodes) if err != nil { return nil, err } @@ -370,17 +269,10 @@ func LoadSimpleThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node return &thread, nil } -// SimpleThread has no associated info with children -func (thread * SimpleThread) DeserializeInfo(ctx *Context, data []byte) (ThreadInfo, error) { - if len(data) > 0 { - return nil, fmt.Errorf("SimpleThread expected to deserialize no info but got %d length data: %s", len(data), string(data)) - } - return nil, nil -} - -func RestoreSimpleThread(ctx *Context, thread Thread, j SimpleThreadJSON, nodes NodeMap) error { +func RestoreThread(ctx *Context, thread *Thread, j ThreadJSON, nodes NodeMap) error { if j.TimeoutAction != "" { - thread.SetTimeout(j.Timeout, j.TimeoutAction) + thread.Timeout = j.Timeout + thread.TimeoutAction = j.TimeoutAction } if j.Parent != "" { @@ -392,11 +284,11 @@ func RestoreSimpleThread(ctx *Context, thread Thread, j SimpleThreadJSON, nodes if err != nil { return err } - p_t, ok := p.(Thread) + p_t, ok := p.(ThreadNode) if ok == false { return err } - thread.SetParent(p_t) + thread.Parent = p_t } for id_str, info_raw := range(j.Children) { @@ -404,63 +296,94 @@ func RestoreSimpleThread(ctx *Context, thread Thread, j SimpleThreadJSON, nodes if err != nil { return err } + child_node, err := LoadNodeRecurse(ctx, id, nodes) if err != nil { return err } - child_t, ok := child_node.(Thread) + + child_t, ok := child_node.(ThreadNode) if ok == false { return fmt.Errorf("%+v is not a Thread as expected", child_node) } - var info_ser []byte - if info_raw != nil { - info_ser, err = json.Marshal(info_raw) - if err != nil { - return err - } - } - - parsed_info, err := thread.DeserializeInfo(ctx, info_ser) + parsed_info, err := DeserializeChildInfo(ctx, info_raw) if err != nil { return err } - thread.AddChild(child_t, parsed_info) + thread.Children[id] = ChildInfo{child_t, parsed_info} + } + + return RestoreLockable(ctx, &thread.Lockable, j.LockableJSON, nodes) +} + +var deserializers = map[InfoType]func(interface{})(interface{}, error) { + +} + +func DeserializeChildInfo(ctx *Context, infos_raw map[string]interface{}) (map[InfoType]interface{}, error) { + ret := map[InfoType]interface{}{} + for type_str, info_raw := range(infos_raw) { + info_type := InfoType(type_str) + deserializer, exists := deserializers[info_type] + if exists == false { + return nil, fmt.Errorf("No deserializer for %s", info_type) + } + var err error + ret[info_type], err = deserializer(info_raw) + if err != nil { + return nil, err + } } - return RestoreSimpleLockable(ctx, thread, j.SimpleLockableJSON, nodes) + return ret, nil } const THREAD_SIGNAL_BUFFER_SIZE = 128 -func NewSimpleThread(id NodeID, name string, state_name string, info_type reflect.Type, actions ThreadActions, handlers ThreadHandlers) SimpleThread { - return SimpleThread{ - SimpleLockable: NewSimpleLockable(id, name), - InfoType: info_type, - state_name: state_name, - signal: make(chan GraphSignal, THREAD_SIGNAL_BUFFER_SIZE), - children: []Thread{}, - child_info: map[NodeID]ThreadInfo{}, - actions: actions, - handlers: handlers, - child_waits: &sync.WaitGroup{}, - active_lock: &sync.Mutex{}, +func NewThread(id NodeID, name string, state_name string, info_types []InfoType, actions ThreadActions, handlers ThreadHandlers) Thread { + return Thread{ + Lockable: NewLockable(id, name), + InfoTypes: info_types, + StateName: state_name, + Chan: make(chan GraphSignal, THREAD_SIGNAL_BUFFER_SIZE), + Children: map[NodeID]ChildInfo{}, + Actions: actions, + Handlers: handlers, + } +} + +func (thread *Thread) SetActive(active bool) error { + thread.ActiveLock.Lock() + defer thread.ActiveLock.Unlock() + if thread.Active == true && active == true { + return fmt.Errorf("%s is active, cannot set active", thread.ID()) + } else if thread.Active == false && active == false { + return fmt.Errorf("%s is already inactive, canot set inactive", thread.ID()) } + thread.Active = active + return nil +} + +func (thread *Thread) SetState(state string) error { + thread.StateName = state + return nil } // Requires the read permission of threads children -func FindChild(context *StateContext, princ Node, thread Thread, id NodeID) Thread { - if thread == nil { +func FindChild(context *StateContext, princ Node, node ThreadNode, id NodeID) ThreadNode { + if node == nil { panic("cannot recurse through nil") } + thread := node.ThreadHandle() if id == thread.ID() { return thread } - for _, child := range thread.Children() { - var result Thread = nil - UseStates(context, princ, NewLockInfo(child, []string{"children"}), func(context *StateContext) error { - result = FindChild(context, princ, child, id) + for _, info := range thread.Children { + var result ThreadNode + UseStates(context, princ, NewLockInfo(info.Child, []string{"children"}), func(context *StateContext) error { + result = FindChild(context, princ, info.Child, id) return nil }) if result != nil { @@ -471,11 +394,11 @@ func FindChild(context *StateContext, princ Node, thread Thread, id NodeID) Thre return nil } -func ChildGo(ctx * Context, thread Thread, child Thread, first_action string) { - thread.ChildWaits().Add(1) - go func(child Thread) { +func ChildGo(ctx * Context, thread *Thread, child ThreadNode, first_action string) { + thread.ChildWaits.Add(1) + go func(child ThreadNode) { ctx.Log.Logf("thread", "THREAD_START_CHILD: %s from %s", thread.ID(), child.ID()) - defer thread.ChildWaits().Done() + defer thread.ChildWaits.Done() err := ThreadLoop(ctx, child, first_action) if err != nil { ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %e", child.ID(), err) @@ -486,8 +409,9 @@ func ChildGo(ctx * Context, thread Thread, child Thread, first_action string) { } // Main Loop for Threads, starts a write context, so cannot be called from a write or read context -func ThreadLoop(ctx * Context, thread Thread, first_action string) error { +func ThreadLoop(ctx * Context, node ThreadNode, first_action string) error { // Start the thread, error if double-started + thread := node.ThreadHandle() ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID(), first_action) err := thread.SetActive(true) if err != nil { @@ -496,14 +420,14 @@ func ThreadLoop(ctx * Context, thread Thread, first_action string) error { } next_action := first_action for next_action != "" { - action, exists := thread.Action(next_action) + action, exists := thread.Actions[next_action] if exists == false { error_str := fmt.Sprintf("%s is not a valid action", next_action) return errors.New(error_str) } ctx.Log.Logf("thread", "THREAD_ACTION: %s - %s", thread.ID(), next_action) - next_action, err = action(ctx, thread) + next_action, err = action(ctx, node) if err != nil { return err } @@ -523,52 +447,8 @@ func ThreadLoop(ctx * Context, thread Thread, first_action string) error { return nil } - -func (thread * SimpleThread) ChildWaits() *sync.WaitGroup { - return thread.child_waits -} - -func (thread * SimpleThread) SetActive(active bool) error { - thread.active_lock.Lock() - defer thread.active_lock.Unlock() - if thread.active == true && active == true { - return fmt.Errorf("%s is active, cannot set active", thread.ID()) - } else if thread.active == false && active == false { - return fmt.Errorf("%s is already inactive, canot set inactive", thread.ID()) - } - thread.active = active - return nil -} - -func (thread * SimpleThread) Action(action string) (ThreadAction, bool) { - action_fn, exists := thread.actions[action] - return action_fn, exists -} - -func (thread * SimpleThread) Handler(signal_type string) (ThreadHandler, bool) { - handler, exists := thread.handlers[signal_type] - return handler, exists -} - -func (thread * SimpleThread) Timeout() <-chan time.Time { - return thread.timeout_chan -} - -func (thread * SimpleThread) ClearTimeout() { - thread.timeout_chan = nil - thread.timeout_action = "" -} - -func (thread * SimpleThread) AllowedToTakeLock(new_owner Lockable, lockable Lockable) bool { - for _, child := range(thread.children) { - if new_owner.ID() == child.ID() { - return true - } - } - return false -} - -func ThreadParentChildLinked(ctx *Context, thread Thread, signal GraphSignal) (string, error) { +func ThreadChildLinked(ctx *Context, node ThreadNode, signal GraphSignal) (string, error) { + thread := node.ThreadHandle() ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %+v", signal) context := NewWriteContext(ctx) err := UpdateStates(context, thread, NewLockMap( @@ -579,18 +459,18 @@ func ThreadParentChildLinked(ctx *Context, thread Thread, signal GraphSignal) (s ctx.Log.Logf("thread", "THREAD_NODE_LINKED_BAD_CAST") return nil } - info_if := thread.ChildInfo(sig.ID) - if info_if == nil { + info, exists := thread.Children[sig.ID] + if exists == false { ctx.Log.Logf("thread", "THREAD_NODE_LINKED: %s is not a child of %s", sig.ID) return nil } - info_r, correct := info_if.(ParentInfo) - if correct == false { - ctx.Log.Logf("thread", "THREAD_NODE_LINKED_BAD_INFO_CAST") + parent_info, exists := info.Infos["parent"].(*ParentThreadInfo) + if exists == false { + panic("ran ThreadChildLinked from a thread that doesn't require 'parent' child info. library party foul") } - info := info_r.Parent() - if info.Start == true { - ChildGo(ctx, thread, thread.Child(sig.ID), info.StartAction) + + if parent_info.Start == true { + ChildGo(ctx, thread, info.Child, parent_info.StartAction) } return nil }) @@ -603,38 +483,30 @@ func ThreadParentChildLinked(ctx *Context, thread Thread, signal GraphSignal) (s return "wait", nil } -func ThreadParentStartChild(ctx *Context, thread Thread, signal GraphSignal) (string, error) { - ctx.Log.Logf("thread", "THREAD_START_CHILD") +// Helper function to start a child from a thread during a signal handler +// Starts a write context, so cannot be called from either a write or read context +func ThreadStartChild(ctx *Context, node ThreadNode, signal GraphSignal) (string, error) { sig, ok := signal.(StartChildSignal) if ok == false { - ctx.Log.Logf("thread", "THREAD_START_CHILD_BAD_SIGNAL: %+v", signal) return "wait", nil } - err := ThreadStartChild(ctx, thread, sig) - if err != nil { - ctx.Log.Logf("thread", "THREAD_START_CHILD_ERR: %s", err) - } else { - ctx.Log.Logf("thread", "THREAD_START_CHILD: %s", sig.ID.String()) - } - - return "wait", nil -} + thread := node.ThreadHandle() -// Helper function to start a child from a thread during a signal handler -// Starts a write context, so cannot be called from either a write or read context -func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) error { context := NewWriteContext(ctx) - return UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error { - child := thread.Child(signal.ID) - if child == nil { - return fmt.Errorf("%s is not a child of %s", signal.ID, thread.ID()) + return "wait", UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error { + info, exists:= thread.Children[sig.ID] + if exists == false { + return fmt.Errorf("%s is not a child of %s", sig.ID, thread.ID()) } - return UpdateStates(context, thread, NewLockInfo(child, []string{"start"}), func(context *StateContext) error { + return UpdateStates(context, thread, NewLockInfo(info.Child, []string{"start"}), func(context *StateContext) error { - info := thread.ChildInfo(signal.ID).(*ParentThreadInfo) - info.Start = true - ChildGo(ctx, thread, child, signal.Action) + parent_info, exists := info.Infos["parent"].(*ParentThreadInfo) + if exists == false { + return fmt.Errorf("Called ThreadStartChild from a thread that doesn't require parent child info") + } + parent_info.Start = true + ChildGo(ctx, thread, info.Child, sig.Action) return nil }) @@ -643,18 +515,19 @@ func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) erro // Helper function to restore threads that should be running from a parents restore action // Starts a write context, so cannot be called from either a write or read context -func ThreadRestore(ctx * Context, thread Thread, start bool) error { +func ThreadRestore(ctx * Context, node ThreadNode, start bool) error { + thread := node.ThreadHandle() context := NewWriteContext(ctx) return UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error { - return UpdateStates(context, thread, LockList(thread.Children(), []string{"start"}), func(context *StateContext) error { - for _, child := range(thread.Children()) { - info := (thread.ChildInfo(child.ID())).(ParentInfo).Parent() - if info.Start == true && child.State() != "finished" { - ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID(), child.ID()) + return UpdateStates(context, thread, LockList(thread.ChildList(), []string{"start"}), func(context *StateContext) error { + for _, info := range(thread.Children) { + parent_info := info.Infos["parent"].(*ParentThreadInfo) + if parent_info.Start == true && info.Child.ThreadHandle().StateName != "finished" { + ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID(), info.Child.ID()) if start == true { - ChildGo(ctx, thread, child, info.StartAction) + ChildGo(ctx, thread, info.Child, parent_info.StartAction) } else { - ChildGo(ctx, thread, child, info.RestoreAction) + ChildGo(ctx, thread, info.Child, parent_info.RestoreAction) } } } @@ -665,10 +538,12 @@ func ThreadRestore(ctx * Context, thread Thread, start bool) error { // Helper function to be called during a threads start action, sets the thread state to started // Starts a write context, so cannot be called from either a write or read context -func ThreadStart(ctx * Context, thread Thread) error { +// Returns "wait", nil on success, so the first return value can be ignored safely +func ThreadStart(ctx * Context, node ThreadNode) (string, error) { + thread := node.ThreadHandle() context := NewWriteContext(ctx) - return UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error { - err := LockLockables(context, []Lockable{thread}, thread) + return "wait", UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error { + err := LockLockables(context, map[NodeID]LockableNode{thread.ID(): thread}, thread) if err != nil { return err } @@ -676,39 +551,28 @@ func ThreadStart(ctx * Context, thread Thread) error { }) } -func ThreadDefaultStart(ctx * Context, thread Thread) (string, error) { - ctx.Log.Logf("thread", "THREAD_DEFAULT_START: %s", thread.ID()) - err := ThreadStart(ctx, thread) - if err != nil { - return "", err - } - return "wait", nil -} - -func ThreadDefaultRestore(ctx * Context, thread Thread) (string, error) { - ctx.Log.Logf("thread", "THREAD_DEFAULT_RESTORE: %s", thread.ID()) - return "wait", nil -} - -func ThreadWait(ctx * Context, thread Thread) (string, error) { - ctx.Log.Logf("thread", "THREAD_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout()) +func ThreadWait(ctx * Context, node ThreadNode) (string, error) { + thread := node.ThreadHandle() + ctx.Log.Logf("thread", "THREAD_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout) for { select { - case signal := <- thread.SignalChannel(): + case signal := <- thread.Chan: ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal) - signal_fn, exists := thread.Handler(signal.Type()) + signal_fn, exists := thread.Handlers[signal.Type()] if exists == true { ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID(), signal.Type()) return signal_fn(ctx, thread, signal) } else { ctx.Log.Logf("thread", "THREAD_NOHANDLER: %s - %s", thread.ID(), signal.Type()) } - case <- thread.Timeout(): + case <- thread.TimeoutChan: timeout_action := "" context := NewWriteContext(ctx) err := UpdateStates(context, thread, NewLockMap(NewLockInfo(thread, []string{"timeout"})), func(context *StateContext) error { - timeout_action = thread.TimeoutAction() - thread.ClearTimeout() + timeout_action = thread.TimeoutAction + thread.TimeoutChan = nil + thread.TimeoutAction = "" + thread.Timeout = time.Time{} return nil }) if err != nil { @@ -720,26 +584,23 @@ func ThreadWait(ctx * Context, thread Thread) (string, error) { } } -func ThreadDefaultFinish(ctx *Context, thread Thread) (string, error) { - ctx.Log.Logf("thread", "THREAD_DEFAULT_FINISH: %s", thread.ID().String()) - return "", ThreadFinish(ctx, thread) -} - -func ThreadFinish(ctx *Context, thread Thread) error { +func ThreadFinish(ctx *Context, node ThreadNode) (string, error) { + thread := node.ThreadHandle() context := NewWriteContext(ctx) - return UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error { + return "", UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error { err := thread.SetState("finished") if err != nil { return err } - return UnlockLockables(context, []Lockable{thread}, thread) + return UnlockLockables(context, map[NodeID]LockableNode{thread.ID(): thread}, thread) }) } var ThreadAbortedError = errors.New("Thread aborted by signal") // Default thread action function for "abort", sends a signal and returns a ThreadAbortedError -func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) { +func ThreadAbort(ctx * Context, node ThreadNode, signal GraphSignal) (string, error) { + thread := node.ThreadHandle() context := NewReadContext(ctx) err := Signal(context, thread, thread, NewStatusSignal("aborted", thread.ID())) if err != nil { @@ -749,38 +610,18 @@ func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, erro } // Default thread action for "stop", sends a signal and returns no error -func ThreadStop(ctx * Context, thread Thread, signal GraphSignal) (string, error) { +func ThreadStop(ctx * Context, node ThreadNode, signal GraphSignal) (string, error) { + thread := node.ThreadHandle() context := NewReadContext(ctx) err := Signal(context, thread, thread, NewStatusSignal("stopped", thread.ID())) return "finish", err } -// Copy the default thread actions to a new ThreadActions map -func NewThreadActions() ThreadActions{ - actions := ThreadActions{} - for k, v := range(BaseThreadActions) { - actions[k] = v - } - - return actions -} - -// Copy the defult thread handlers to a new ThreadAction map -func NewThreadHandlers() ThreadHandlers{ - handlers := ThreadHandlers{} - for k, v := range(BaseThreadHandlers) { - handlers[k] = v - } - - return handlers -} - // Default thread actions var BaseThreadActions = ThreadActions{ "wait": ThreadWait, - "start": ThreadDefaultStart, - "finish": ThreadDefaultFinish, - "restore": ThreadDefaultRestore, + "start": ThreadStart, + "finish": ThreadFinish, } // Default thread signal handlers diff --git a/user.go b/user.go index 7c3db43..6eeeb35 100644 --- a/user.go +++ b/user.go @@ -9,7 +9,7 @@ import ( ) type User struct { - SimpleLockable + Lockable Granted time.Time Pubkey *ecdsa.PublicKey @@ -18,7 +18,7 @@ type User struct { } type UserJSON struct { - SimpleLockableJSON + LockableJSON Granted time.Time `json:"granted"` Pubkey []byte `json:"pubkey"` Shared []byte `json:"shared"` @@ -30,14 +30,14 @@ func (user *User) Type() NodeType { } func (user *User) Serialize() ([]byte, error) { - lockable_json := NewSimpleLockableJSON(&user.SimpleLockable) + lockable_json := NewLockableJSON(&user.Lockable) pubkey, err := x509.MarshalPKIXPublicKey(user.Pubkey) if err != nil { return nil, err } return json.MarshalIndent(&UserJSON{ - SimpleLockableJSON: lockable_json, + LockableJSON: lockable_json, Granted: user.Granted, Shared: user.Shared, Pubkey: pubkey, @@ -68,7 +68,7 @@ func LoadUser(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) user := NewUser(j.Name, j.Granted, pubkey, j.Shared, j.Tags) nodes[id] = &user - err = RestoreSimpleLockable(ctx, &user, j.SimpleLockableJSON, nodes) + err = RestoreLockable(ctx, &user.Lockable, j.LockableJSON, nodes) if err != nil { return nil, err } @@ -79,7 +79,7 @@ func LoadUser(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) func NewUser(name string, granted time.Time, pubkey *ecdsa.PublicKey, shared []byte, tags []string) User { id := KeyID(pubkey) return User{ - SimpleLockable: NewSimpleLockable(id, name), + Lockable: NewLockable(id, name), Granted: granted, Pubkey: pubkey, Shared: shared,