Tests compile and run

graph-rework-2
noah metz 2023-07-24 16:04:56 -06:00
parent 201ee7234b
commit fc2e36043f
12 changed files with 713 additions and 983 deletions

@ -86,8 +86,8 @@ func (ctx * Context) RegisterNodeType(def NodeDef) error {
ctx.Types[type_hash] = def ctx.Types[type_hash] = def
node_type := reflect.TypeOf((*Node)(nil)).Elem() node_type := reflect.TypeOf((*Node)(nil)).Elem()
lockable_type := reflect.TypeOf((*Lockable)(nil)).Elem() lockable_type := reflect.TypeOf((*LockableNode)(nil)).Elem()
thread_type := reflect.TypeOf((*Thread)(nil)).Elem() thread_type := reflect.TypeOf((*ThreadNode)(nil)).Elem()
if def.Reflect.Implements(node_type) { if def.Reflect.Implements(node_type) {
ctx.GQL.ValidNodes[def.Reflect] = def.GQLType ctx.GQL.ValidNodes[def.Reflect] = def.GQLType
@ -154,7 +154,7 @@ func NewGQLContext() GQLContext {
Query: query, Query: query,
Mutation: mutation, Mutation: mutation,
Subscription: subscription, Subscription: subscription,
BaseNodeType: GQLTypeGraphNode.Type, BaseNodeType: GQLTypeSimpleNode.Type,
BaseLockableType: GQLTypeSimpleLockable.Type, BaseLockableType: GQLTypeSimpleLockable.Type,
BaseThreadType: GQLTypeSimpleThread.Type, BaseThreadType: GQLTypeSimpleThread.Type,
} }
@ -171,15 +171,19 @@ func NewContext(db * badger.DB, log Logger) * Context {
Types: map[uint64]NodeDef{}, Types: map[uint64]NodeDef{},
} }
err := ctx.RegisterNodeType(NewNodeDef((*GraphNode)(nil), LoadGraphNode, GQLTypeGraphNode.Type)) err := ctx.RegisterNodeType(NewNodeDef((*SimpleNode)(nil), LoadSimpleNode, GQLTypeSimpleNode.Type))
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = ctx.RegisterNodeType(NewNodeDef((*SimpleLockable)(nil), LoadSimpleLockable, GQLTypeSimpleLockable.Type)) err = ctx.RegisterNodeType(NewNodeDef((*Lockable)(nil), LoadLockable, GQLTypeSimpleLockable.Type))
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = ctx.RegisterNodeType(NewNodeDef((*SimpleThread)(nil), LoadSimpleThread, GQLTypeSimpleThread.Type)) err = ctx.RegisterNodeType(NewNodeDef((*Listener)(nil), LoadListener, GQLTypeSimpleLockable.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*Thread)(nil), LoadThread, GQLTypeSimpleThread.Type))
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -191,19 +195,19 @@ func NewContext(db * badger.DB, log Logger) * Context {
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = ctx.RegisterNodeType(NewNodeDef((*PerNodePolicy)(nil), LoadPerNodePolicy, GQLTypeGraphNode.Type)) err = ctx.RegisterNodeType(NewNodeDef((*PerNodePolicy)(nil), LoadPerNodePolicy, GQLTypeSimpleNode.Type))
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = ctx.RegisterNodeType(NewNodeDef((*SimplePolicy)(nil), LoadSimplePolicy, GQLTypeGraphNode.Type)) err = ctx.RegisterNodeType(NewNodeDef((*SimplePolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type))
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = ctx.RegisterNodeType(NewNodeDef((*PerTagPolicy)(nil), LoadPerTagPolicy, GQLTypeGraphNode.Type)) err = ctx.RegisterNodeType(NewNodeDef((*PerTagPolicy)(nil), LoadPerTagPolicy, GQLTypeSimpleNode.Type))
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = ctx.RegisterNodeType(NewNodeDef((*DependencyPolicy)(nil), LoadSimplePolicy, GQLTypeGraphNode.Type)) err = ctx.RegisterNodeType(NewNodeDef((*DependencyPolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type))
if err != nil { if err != nil {
panic(err) panic(err)
} }

126
gql.go

@ -629,7 +629,7 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *
} }
type GQLThread struct { type GQLThread struct {
SimpleThread Thread
tcp_listener net.Listener tcp_listener net.Listener
http_server *http.Server http_server *http.Server
http_done *sync.WaitGroup http_done *sync.WaitGroup
@ -639,6 +639,37 @@ type GQLThread struct {
Users map[NodeID]*User Users map[NodeID]*User
Key *ecdsa.PrivateKey Key *ecdsa.PrivateKey
ECDH ecdh.Curve ECDH ecdh.Curve
SubscribeLock sync.Mutex
SubscribeListeners []chan GraphSignal
}
func (thread *GQLThread) NewSubscriptionChannel(buffer int) chan GraphSignal {
thread.SubscribeLock.Lock()
defer thread.SubscribeLock.Unlock()
new_listener := make(chan GraphSignal, buffer)
thread.SubscribeListeners = append(thread.SubscribeListeners, new_listener)
return new_listener
}
func (thread *GQLThread) Process(context *StateContext, signal GraphSignal) error {
active_listeners := []chan GraphSignal{}
thread.SubscribeLock.Lock()
for _, listener := range(thread.SubscribeListeners) {
select {
case listener <- signal:
active_listeners = append(active_listeners, listener)
default:
go func(listener chan GraphSignal) {
listener <- NewDirectSignal("Channel Closed")
close(listener)
}(listener)
}
}
thread.SubscribeListeners = active_listeners
thread.SubscribeLock.Unlock()
return thread.Thread.Process(context, signal)
} }
func (thread * GQLThread) Type() NodeType { func (thread * GQLThread) Type() NodeType {
@ -650,17 +681,8 @@ func (thread * GQLThread) Serialize() ([]byte, error) {
return json.MarshalIndent(&thread_json, "", " ") return json.MarshalIndent(&thread_json, "", " ")
} }
func (thread * GQLThread) DeserializeInfo(ctx *Context, data []byte) (ThreadInfo, error) {
var info ParentThreadInfo
err := json.Unmarshal(data, &info)
if err != nil {
return nil, err
}
return &info, nil
}
type GQLThreadJSON struct { type GQLThreadJSON struct {
SimpleThreadJSON ThreadJSON
Listen string `json:"listen"` Listen string `json:"listen"`
Users []string `json:"users"` Users []string `json:"users"`
Key []byte `json:"key"` Key []byte `json:"key"`
@ -686,7 +708,7 @@ var ecdh_curve_ids = map[ecdh.Curve]uint8{
} }
func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON { func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON {
thread_json := NewSimpleThreadJSON(&thread.SimpleThread) thread_json := NewThreadJSON(&thread.Thread)
ser_key, err := x509.MarshalECPrivateKey(thread.Key) ser_key, err := x509.MarshalECPrivateKey(thread.Key)
if err != nil { if err != nil {
@ -701,7 +723,7 @@ func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON {
} }
return GQLThreadJSON{ return GQLThreadJSON{
SimpleThreadJSON: thread_json, ThreadJSON: thread_json,
Listen: thread.Listen, Listen: thread.Listen,
Users: users, Users: users,
Key: ser_key, Key: ser_key,
@ -744,7 +766,7 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e
thread.Users[id] = user.(*User) thread.Users[id] = user.(*User)
} }
err = RestoreSimpleThread(ctx, &thread, j.SimpleThreadJSON, nodes) err = RestoreThread(ctx, &thread.Thread, j.ThreadJSON, nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -793,8 +815,9 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh
tls_key = ssl_key_pem tls_key = ssl_key_pem
} }
return GQLThread{ return GQLThread{
SimpleThread: NewSimpleThread(id, name, state_name, reflect.TypeOf((*ParentThreadInfo)(nil)), gql_actions, gql_handlers), Thread: NewThread(id, name, state_name, []InfoType{"parent"}, gql_actions, gql_handlers),
Listen: listen, Listen: listen,
SubscribeListeners: []chan GraphSignal{},
Users: map[NodeID]*User{}, Users: map[NodeID]*User{},
http_done: &sync.WaitGroup{}, http_done: &sync.WaitGroup{},
Key: key, Key: key,
@ -806,40 +829,23 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh
var gql_actions ThreadActions = ThreadActions{ var gql_actions ThreadActions = ThreadActions{
"wait": ThreadWait, "wait": ThreadWait,
"restore": func(ctx * Context, thread Thread) (string, error) { "restore": func(ctx *Context, node ThreadNode) (string, error) {
ctx.Log.Logf("gql", "GQL_THREAD_RESTORE: %s", thread.ID()) return "start_server", ThreadRestore(ctx, node, false)
// Restore all the threads that have "Start" as true and arent in the "finished" state
err := ThreadRestore(ctx, thread, false)
if err != nil {
return "", err
}
return "start_server", nil
}, },
"start": func(ctx * Context, thread Thread) (string, error) { "start": func(ctx * Context, node ThreadNode) (string, error) {
ctx.Log.Logf("gql", "GQL_START") _, err := ThreadStart(ctx, node)
err := ThreadStart(ctx, thread)
if err != nil { if err != nil {
return "", err return "", err
} }
// Start all the threads that have "Start" as true and arent in the "finished" state return "start_server", ThreadRestore(ctx, node, true)
err = ThreadRestore(ctx, thread, true)
if err != nil {
return "", err
}
return "start_server", nil
}, },
"start_server": func(ctx * Context, thread Thread) (string, error) { "start_server": func(ctx * Context, node ThreadNode) (string, error) {
server, ok := thread.(*GQLThread) gql_thread := node.(*GQLThread)
if ok == false {
return "", fmt.Errorf("GQL_THREAD_START: %s is not GQLThread, %+v", thread.ID(), thread.State())
}
ctx.Log.Logf("gql", "GQL_START_SERVER")
// Serve the GQL http and ws handlers
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/auth", AuthHandler(ctx, server)) mux.HandleFunc("/auth", AuthHandler(ctx, gql_thread))
mux.HandleFunc("/gql", GQLHandler(ctx, server)) mux.HandleFunc("/gql", GQLHandler(ctx, gql_thread))
mux.HandleFunc("/gqlws", GQLWSHandler(ctx, server)) mux.HandleFunc("/gqlws", GQLWSHandler(ctx, gql_thread))
// Server a graphiql interface(TODO make configurable whether to start this) // Server a graphiql interface(TODO make configurable whether to start this)
mux.HandleFunc("/graphiql", GraphiQLHandler()) mux.HandleFunc("/graphiql", GraphiQLHandler())
@ -849,7 +855,7 @@ var gql_actions ThreadActions = ThreadActions{
mux.Handle("/site/", http.StripPrefix("/site", fs)) mux.Handle("/site/", http.StripPrefix("/site", fs))
http_server := &http.Server{ http_server := &http.Server{
Addr: server.Listen, Addr: gql_thread.Listen,
Handler: mux, Handler: mux,
} }
@ -858,7 +864,7 @@ var gql_actions ThreadActions = ThreadActions{
return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr) return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr)
} }
cert, err := tls.X509KeyPair(server.tls_cert, server.tls_key) cert, err := tls.X509KeyPair(gql_thread.tls_cert, gql_thread.tls_key)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -870,23 +876,23 @@ var gql_actions ThreadActions = ThreadActions{
listener := tls.NewListener(l, &config) listener := tls.NewListener(l, &config)
server.http_done.Add(1) gql_thread.http_done.Add(1)
go func(server *GQLThread) { go func(gql_thread *GQLThread) {
defer server.http_done.Done() defer gql_thread.http_done.Done()
err := http_server.Serve(listener) err := http_server.Serve(listener)
if err != http.ErrServerClosed { if err != http.ErrServerClosed {
panic(fmt.Sprintf("Failed to start gql server: %s", err)) panic(fmt.Sprintf("Failed to start gql server: %s", err))
} }
}(server) }(gql_thread)
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
err = UpdateStates(context, server, NewLockMap( err = UpdateStates(context, gql_thread, NewLockMap(
NewLockInfo(server, []string{"http_server"}), NewLockInfo(gql_thread, []string{"http_server"}),
), func(context *StateContext) error { ), func(context *StateContext) error {
server.tcp_listener = listener gql_thread.tcp_listener = listener
server.http_server = http_server gql_thread.http_server = http_server
return nil return nil
}) })
@ -895,24 +901,24 @@ var gql_actions ThreadActions = ThreadActions{
} }
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = Signal(context, server, server, NewStatusSignal("server_started", server.ID())) err = Signal(context, gql_thread, gql_thread, NewStatusSignal("server_started", gql_thread.ID()))
if err != nil { if err != nil {
return "", err return "", err
} }
return "wait", nil return "wait", nil
}, },
"finish": func(ctx *Context, thread Thread) (string, error) { "finish": func(ctx *Context, node ThreadNode) (string, error) {
server := thread.(*GQLThread) gql_thread := node.(*GQLThread)
server.http_server.Shutdown(context.TODO()) gql_thread.http_server.Shutdown(context.TODO())
server.http_done.Wait() gql_thread.http_done.Wait()
return "", ThreadFinish(ctx, thread) return ThreadFinish(ctx, node)
}, },
} }
var gql_handlers ThreadHandlers = ThreadHandlers{ var gql_handlers ThreadHandlers = ThreadHandlers{
"child_linked": ThreadParentChildLinked, "child_linked": ThreadChildLinked,
"start_child": ThreadParentStartChild, "start_child": ThreadStartChild,
"abort": ThreadAbort, "abort": ThreadAbort,
"stop": ThreadStop, "stop": ThreadStop,
} }

@ -28,7 +28,7 @@ var GQLMutationAbort = NewField(func()*graphql.Field {
err = UseStates(context, ctx.User, NewLockMap( err = UseStates(context, ctx.User, NewLockMap(
NewLockInfo(ctx.Server, []string{"children"}), NewLockInfo(ctx.Server, []string{"children"}),
), func(context *StateContext) (error){ ), func(context *StateContext) (error){
node = FindChild(context, ctx.User, ctx.Server, id) node = FindChild(context, ctx.User, &ctx.Server.Thread, id)
if node == nil { if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", id) return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
} }
@ -86,13 +86,13 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{
err = UseStates(context, ctx.User, NewLockMap( err = UseStates(context, ctx.User, NewLockMap(
NewLockInfo(ctx.Server, []string{"children"}), NewLockInfo(ctx.Server, []string{"children"}),
), func(context *StateContext) error { ), func(context *StateContext) error {
node := FindChild(context, ctx.User, ctx.Server, parent_id) parent := FindChild(context, ctx.User, &ctx.Server.Thread, parent_id)
if node == nil { if parent == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id) return fmt.Errorf("%s is not a child of %s", parent_id, ctx.Server.ID())
} }
signal = NewStartChildSignal(child_id, action) signal = NewStartChildSignal(child_id, action)
return Signal(context, node, ctx.User, signal) return Signal(context, ctx.User, parent, signal)
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -105,15 +105,15 @@ func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) {
return nil, err return nil, err
} }
node, ok := p.Source.(Thread) node, ok := p.Source.(*Thread)
if ok == false || node == nil { if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Thread") return nil, fmt.Errorf("Failed to cast source to Thread")
} }
var parent Thread = nil var parent ThreadNode = nil
context := NewReadContext(ctx.Context) context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"parent"}), func(context *StateContext) error { err = UseStates(context, ctx.User, NewLockInfo(node, []string{"parent"}), func(context *StateContext) error {
parent = node.Parent() parent = node.ThreadHandle().Parent
return nil return nil
}) })
@ -130,7 +130,7 @@ func GQLThreadState(p graphql.ResolveParams) (interface{}, error) {
return nil, err return nil, err
} }
node, ok := p.Source.(Thread) node, ok := p.Source.(ThreadNode)
if ok == false || node == nil { if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Thread") return nil, fmt.Errorf("Failed to cast source to Thread")
} }
@ -138,7 +138,7 @@ func GQLThreadState(p graphql.ResolveParams) (interface{}, error) {
var state string var state string
context := NewReadContext(ctx.Context) context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"state"}), func(context *StateContext) error { err = UseStates(context, ctx.User, NewLockInfo(node, []string{"state"}), func(context *StateContext) error {
state = node.State() state = node.ThreadHandle().StateName
return nil return nil
}) })
@ -155,15 +155,20 @@ func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) {
return nil, err return nil, err
} }
node, ok := p.Source.(Thread) node, ok := p.Source.(ThreadNode)
if ok == false || node == nil { if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Thread") return nil, fmt.Errorf("Failed to cast source to Thread")
} }
var children []Thread = nil var children []ThreadNode = nil
context := NewReadContext(ctx.Context) context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"children"}), func(context *StateContext) error { err = UseStates(context, ctx.User, NewLockInfo(node, []string{"children"}), func(context *StateContext) error {
children = node.Children() children = make([]ThreadNode, len(node.ThreadHandle().Children))
i := 0
for _, info := range(node.ThreadHandle().Children) {
children[i] = info.Child
i += 1
}
return nil return nil
}) })
@ -180,7 +185,7 @@ func GQLLockableName(p graphql.ResolveParams) (interface{}, error) {
return nil, err return nil, err
} }
node, ok := p.Source.(Lockable) node, ok := p.Source.(LockableNode)
if ok == false || node == nil { if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Lockable") return nil, fmt.Errorf("Failed to cast source to Lockable")
} }
@ -188,7 +193,7 @@ func GQLLockableName(p graphql.ResolveParams) (interface{}, error) {
name := "" name := ""
context := NewReadContext(ctx.Context) context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"name"}), func(context *StateContext) error { err = UseStates(context, ctx.User, NewLockInfo(node, []string{"name"}), func(context *StateContext) error {
name = node.Name() name = node.LockableHandle().Name
return nil return nil
}) })
@ -205,15 +210,20 @@ func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) {
return nil, err return nil, err
} }
node, ok := p.Source.(Lockable) node, ok := p.Source.(LockableNode)
if ok == false || node == nil { if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Lockable") return nil, fmt.Errorf("Failed to cast source to Lockable")
} }
var requirements []Lockable = nil var requirements []LockableNode = nil
context := NewReadContext(ctx.Context) context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"requirements"}), func(context *StateContext) error { err = UseStates(context, ctx.User, NewLockInfo(node, []string{"requirements"}), func(context *StateContext) error {
requirements = node.Requirements() requirements = make([]LockableNode, len(node.LockableHandle().Requirements))
i := 0
for _, req := range(node.LockableHandle().Requirements) {
requirements[i] = req
i += 1
}
return nil return nil
}) })
@ -230,15 +240,20 @@ func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) {
return nil, err return nil, err
} }
node, ok := p.Source.(Lockable) node, ok := p.Source.(LockableNode)
if ok == false || node == nil { if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Lockable") return nil, fmt.Errorf("Failed to cast source to Lockable")
} }
var dependencies []Lockable = nil var dependencies []LockableNode = nil
context := NewReadContext(ctx.Context) context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"dependencies"}), func(context *StateContext) error { err = UseStates(context, ctx.User, NewLockInfo(node, []string{"dependencies"}), func(context *StateContext) error {
dependencies = node.Dependencies() dependencies = make([]LockableNode, len(node.LockableHandle().Dependencies))
i := 0
for _, dep := range(node.LockableHandle().Dependencies) {
dependencies[i] = dep
i += 1
}
return nil return nil
}) })
@ -255,7 +270,7 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) {
return nil, err return nil, err
} }
node, ok := p.Source.(Lockable) node, ok := p.Source.(LockableNode)
if ok == false || node == nil { if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Lockable") return nil, fmt.Errorf("Failed to cast source to Lockable")
} }
@ -263,7 +278,7 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) {
var owner Node = nil var owner Node = nil
context := NewReadContext(ctx.Context) context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"owner"}), func(context *StateContext) error { err = UseStates(context, ctx.User, NewLockInfo(node, []string{"owner"}), func(context *StateContext) error {
owner = node.Owner() owner = node.LockableHandle().Owner
return nil return nil
}) })

@ -24,7 +24,7 @@ func GQLSubscribeFn(p graphql.ResolveParams, send_nil bool, fn func(*Context, *G
c := make(chan interface{}) c := make(chan interface{})
go func(c chan interface{}, server *GQLThread) { go func(c chan interface{}, server *GQLThread) {
ctx.Context.Log.Logf("gqlws", "GQL_SUBSCRIBE_THREAD_START") ctx.Context.Log.Logf("gqlws", "GQL_SUBSCRIBE_THREAD_START")
sig_c := UpdateChannel(server, 1, RandID()) sig_c := server.NewSubscriptionChannel(1)
if send_nil == true { if send_nil == true {
sig_c <- nil sig_c <- nil
} }

@ -20,46 +20,37 @@ import (
func TestGQLDBLoad(t * testing.T) { func TestGQLDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{"test", "signal", "policy", "thread"}) ctx := logTestContext(t, []string{"test", "signal", "policy", "thread"})
l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") l1 := NewListener(RandID(), "Test Lockable 1")
l1 := &l1_r
ctx.Log.Logf("test", "L1_ID: %s", l1.ID().String()) ctx.Log.Logf("test", "L1_ID: %s", l1.ID().String())
t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) t1 := NewThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
t1 := &t1_r
ctx.Log.Logf("test", "T1_ID: %s", t1.ID().String()) ctx.Log.Logf("test", "T1_ID: %s", t1.ID().String())
listen_id := RandID() listen_id := RandID()
ctx.Log.Logf("test", "LISTENER_ID: %s", listen_id.String()) ctx.Log.Logf("test", "LISTENER_ID: %s", listen_id.String())
update_channel := UpdateChannel(t1, 10, listen_id)
u1_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) u1_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err) fatalErr(t, err)
u1_shared := []byte{0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x23, 0x45, 0x67} u1 := NewUser("Test User", time.Now(), &u1_key.PublicKey, []byte{}, []string{"gql"})
u1_r := NewUser("Test User", time.Now(), &u1_key.PublicKey, u1_shared, []string{"gql"})
u1 := &u1_r
ctx.Log.Logf("test", "U1_ID: %s", u1.ID().String()) ctx.Log.Logf("test", "U1_ID: %s", u1.ID().String())
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err) fatalErr(t, err)
gql_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) gql := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil)
gql := &gql_r
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID().String()) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID().String())
// Policy to allow gql to perform all action on all resources // Policy to allow gql to perform all action on all resources
p1_r := NewPerNodePolicy(RandID(), map[NodeID]NodeActions{ p1 := NewPerNodePolicy(RandID(), map[NodeID]NodeActions{
gql.ID(): NewNodeActions(nil, []string{"*"}), gql.ID(): NewNodeActions(nil, []string{"*"}),
}) })
p1 := &p1_r p2 := NewSimplePolicy(RandID(), NewNodeActions(NodeActions{
p2_r := NewSimplePolicy(RandID(), NewNodeActions(NodeActions{
"signal": []string{"status"}, "signal": []string{"status"},
}, nil)) }, nil))
p2 := &p2_r
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
err = UpdateStates(context, gql, LockMap{ err = UpdateStates(context, &gql, LockMap{
p1.ID(): LockInfo{p1, nil}, p1.ID(): LockInfo{&p1, nil},
p2.ID(): LockInfo{p2, nil}, p2.ID(): LockInfo{&p2, nil},
}, func(context *StateContext) error { }, func(context *StateContext) error {
return nil return nil
}) })
@ -67,46 +58,48 @@ func TestGQLDBLoad(t * testing.T) {
ctx.Log.Logf("test", "P1_ID: %s", p1.ID().String()) ctx.Log.Logf("test", "P1_ID: %s", p1.ID().String())
ctx.Log.Logf("test", "P2_ID: %s", p2.ID().String()) ctx.Log.Logf("test", "P2_ID: %s", p2.ID().String())
err = AttachPolicies(ctx, gql, p1, p2) err = AttachPolicies(ctx, &gql.SimpleNode, &p1, &p2)
fatalErr(t, err) fatalErr(t, err)
err = AttachPolicies(ctx, l1, p1, p2) err = AttachPolicies(ctx, &l1.SimpleNode, &p1, &p2)
fatalErr(t, err) fatalErr(t, err)
err = AttachPolicies(ctx, t1, p1, p2) err = AttachPolicies(ctx, &t1.SimpleNode, &p1, &p2)
fatalErr(t, err) fatalErr(t, err)
err = AttachPolicies(ctx, u1, p1, p2) err = AttachPolicies(ctx, &u1.SimpleNode, &p1, &p2)
fatalErr(t, err) fatalErr(t, err)
info := NewParentThreadInfo(true, "start", "restore") info := NewParentThreadInfo(true, "start", "restore")
context = NewWriteContext(ctx) context = NewWriteContext(ctx)
err = UpdateStates(context, gql, NewLockMap( err = UpdateStates(context, &gql, NewLockMap(
NewLockInfo(gql, []string{"users"}), NewLockInfo(&gql, []string{"users"}),
), func(context *StateContext) error { ), func(context *StateContext) error {
gql.Users[KeyID(&u1_key.PublicKey)] = u1 gql.Users[KeyID(&u1_key.PublicKey)] = &u1
err := LinkThreads(context, gql, gql, t1, &info) err := LinkThreads(context, &gql, &gql, ChildInfo{&t1, map[InfoType]interface{}{
"parent": &info,
}})
if err != nil { if err != nil {
return err return err
} }
return LinkLockables(context, gql, gql, []Lockable{l1}) return LinkLockables(context, &gql, &gql, []LockableNode{&l1})
}) })
fatalErr(t, err) fatalErr(t, err)
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = Signal(context, gql, gql, NewStatusSignal("child_linked", t1.ID())) err = Signal(context, &gql, &gql, NewStatusSignal("child_linked", t1.ID()))
fatalErr(t, err) fatalErr(t, err)
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = Signal(context, gql, gql, AbortSignal) err = Signal(context, &gql, &gql, AbortSignal)
fatalErr(t, err) fatalErr(t, err)
err = ThreadLoop(ctx, gql, "start") err = ThreadLoop(ctx, &gql, "start")
if errors.Is(err, ThreadAbortedError) == false { if errors.Is(err, ThreadAbortedError) == false {
fatalErr(t, err) fatalErr(t, err)
} }
(*GraphTester)(t).WaitForStatus(ctx, update_channel, "aborted", 100*time.Millisecond, "Didn't receive aborted on update_channel") (*GraphTester)(t).WaitForStatus(ctx, l1.Chan, "aborted", 100*time.Millisecond, "Didn't receive aborted on listener")
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = UseStates(context, gql, LockList([]Node{gql, u1}, nil), func(context *StateContext) error { err = UseStates(context, &gql, LockList([]Node{&gql, &u1}, nil), func(context *StateContext) error {
ser1, err := gql.Serialize() ser1, err := gql.Serialize()
ser2, err := u1.Serialize() ser2, err := u1.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser1) ctx.Log.Logf("test", "\n%s\n\n", ser1)
@ -116,18 +109,15 @@ func TestGQLDBLoad(t * testing.T) {
gql_loaded, err := LoadNode(ctx, gql.ID()) gql_loaded, err := LoadNode(ctx, gql.ID())
fatalErr(t, err) fatalErr(t, err)
var t1_loaded *SimpleThread = nil var l1_loaded *Listener = nil
var update_channel_2 chan GraphSignal
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = UseStates(context, gql, NewLockInfo(gql_loaded, []string{"users", "children"}), func(context *StateContext) error { err = UseStates(context, gql_loaded, NewLockInfo(gql_loaded, []string{"users", "children", "requirements"}), func(context *StateContext) error {
ser, err := gql_loaded.Serialize() ser, err := gql_loaded.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser) ctx.Log.Logf("test", "\n%s\n\n", ser)
dependency := gql_loaded.(*GQLThread).Thread.Dependencies[l1.ID()].(*Listener)
l1_loaded = dependency
u_loaded := gql_loaded.(*GQLThread).Users[u1.ID()] u_loaded := gql_loaded.(*GQLThread).Users[u1.ID()]
child := gql_loaded.(Thread).Children()[0].(*SimpleThread) err = UseStates(context, gql_loaded, NewLockInfo(u_loaded, nil), func(context *StateContext) error {
t1_loaded = child
update_channel_2 = UpdateChannel(t1_loaded, 10, RandID())
err = UseStates(context, gql, NewLockInfo(u_loaded, nil), func(context *StateContext) error {
ser, err := u_loaded.Serialize() ser, err := u_loaded.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser) ctx.Log.Logf("test", "\n%s\n\n", ser)
return err return err
@ -136,9 +126,9 @@ func TestGQLDBLoad(t * testing.T) {
return err return err
}) })
err = ThreadLoop(ctx, gql_loaded.(Thread), "start") err = ThreadLoop(ctx, gql_loaded.(ThreadNode), "start")
fatalErr(t, err) fatalErr(t, err)
(*GraphTester)(t).WaitForStatus(ctx, update_channel_2, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2") (*GraphTester)(t).WaitForStatus(ctx, l1_loaded.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2")
} }
@ -153,23 +143,19 @@ func TestGQLAuth(t * testing.T) {
gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil)
gql_t := &gql_t_r gql_t := &gql_t_r
// p1 not written to DB, TODO: update write to follow links maybe l1 := NewListener(RandID(), "GQL Thread")
context := NewWriteContext(ctx) err = AttachPolicies(ctx, &l1.SimpleNode, p1)
err = UpdateStates(context, gql_t, NewLockInfo(gql_t, []string{"policies"}), func(context *StateContext) error { fatalErr(t, err)
return gql_t.AddPolicy(p1)
})
err = AttachPolicies(ctx, &gql_t.SimpleNode, p1)
done := make(chan error, 1) done := make(chan error, 1)
var update_channel chan GraphSignal context := NewWriteContext(ctx)
context = NewReadContext(ctx) err = LinkLockables(context, gql_t, gql_t, []LockableNode{&l1})
err = UseStates(context, gql_t, NewLockInfo(gql_t, nil), func(context *StateContext) error {
update_channel = UpdateChannel(gql_t, 10, NodeID{})
return nil
})
fatalErr(t, err) fatalErr(t, err)
go func(done chan error, thread Thread) {
go func(done chan error, thread ThreadNode) {
timeout := time.After(2*time.Second) timeout := time.After(2*time.Second)
select { select {
case <-timeout: case <-timeout:
@ -182,8 +168,8 @@ func TestGQLAuth(t * testing.T) {
fatalErr(t, err) fatalErr(t, err)
}(done, gql_t) }(done, gql_t)
go func(thread Thread){ go func(thread ThreadNode){
(*GraphTester)(t).WaitForStatus(ctx, update_channel, "server_started", 100*time.Millisecond, "Server didn't start") (*GraphTester)(t).WaitForStatus(ctx, l1.Chan, "server_started", 100*time.Millisecond, "Server didn't start")
port := gql_t.tcp_listener.Addr().(*net.TCPAddr).Port port := gql_t.tcp_listener.Addr().(*net.TCPAddr).Port
ctx.Log.Logf("test", "GQL_PORT: %d", port) ctx.Log.Logf("test", "GQL_PORT: %d", port)

@ -142,9 +142,9 @@ var GQLTypeSimpleLockable = NewSingleton(func() *graphql.Object {
return gql_type_simple_lockable return gql_type_simple_lockable
}, nil) }, nil)
var GQLTypeGraphNode = NewSingleton(func() *graphql.Object { var GQLTypeSimpleNode = NewSingleton(func() *graphql.Object {
object := graphql.NewObject(graphql.ObjectConfig{ object := graphql.NewObject(graphql.ObjectConfig{
Name: "GraphNode", Name: "SimpleNode",
Interfaces: []*graphql.Interface{ Interfaces: []*graphql.Interface{
GQLInterfaceNode.Type, GQLInterfaceNode.Type,
}, },

@ -5,219 +5,151 @@ import (
"encoding/json" "encoding/json"
) )
// A Lockable represents a Node that can be locked and hold other Nodes locks type Listener struct {
type Lockable interface { Lockable
// All Lockables are nodes Chan chan GraphSignal
Node
//// State Modification Function
// Record that lockable was returned to it's owner and is no longer held by this Node
// Returns the previous owner of the lockable
RecordUnlock(lockable Lockable) Lockable
// Record that lockable was locked by this node, and that it should be returned to last_owner
RecordLock(lockable Lockable, last_owner Lockable)
// Link a requirement to this Node
AddRequirement(requirement Lockable)
// Remove a requirement linked to this Node
RemoveRequirement(requirement Lockable)
// Link a dependency to this Node
AddDependency(dependency Lockable)
// Remove a dependency linked to this Node
RemoveDependency(dependency Lockable)
//
SetOwner(new_owner Lockable)
//// State Reading Functions
Name() string
// Called when new_owner wants to take lockable's lock but it's owned by this node
// A true return value means that the lock can be passed
AllowedToTakeLock(new_owner Lockable, lockable Lockable) bool
// Get all the linked requirements to this node
Requirements() []Lockable
// Get all the linked dependencies to this node
Dependencies() []Lockable
// Get the node's Owner
Owner() Lockable
// Called during the lock process after locking the state and before updating the Node's state
// a non-nil return value will abort the lock attempt
CanLock(new_owner Lockable) error
// Called during the unlock process after locking the state and before updating the Node's state
// a non-nil return value will abort the unlock attempt
CanUnlock(old_owner Lockable) error
}
// SimpleLockable is a simple Lockable implementation that can be embedded into more complex structures
type SimpleLockable struct {
GraphNode
name string
owner Lockable
requirements []Lockable
dependencies []Lockable
locks_held map[NodeID]Lockable
}
func (state * SimpleLockable) Type() NodeType {
return NodeType("simple_lockable")
}
type SimpleLockableJSON struct {
GraphNodeJSON
Name string `json:"name"`
Owner string `json:"owner"`
Dependencies []string `json:"dependencies"`
Requirements []string `json:"requirements"`
LocksHeld map[string]string `json:"locks_held"`
} }
func (lockable * SimpleLockable) Serialize() ([]byte, error) { func (node *Listener) Type() NodeType {
lockable_json := NewSimpleLockableJSON(lockable) return NodeType("listener")
return json.MarshalIndent(&lockable_json, "", " ")
}
func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON {
requirement_ids := make([]string, len(lockable.requirements))
for i, requirement := range(lockable.requirements) {
requirement_ids[i] = requirement.ID().String()
} }
dependency_ids := make([]string, len(lockable.dependencies)) func (node *Listener) Process(context *StateContext, signal GraphSignal) error {
for i, dependency := range(lockable.dependencies) { select {
dependency_ids[i] = dependency.ID().String() case node.Chan <- signal:
default:
return fmt.Errorf("LISTENER_OVERFLOW: %s - %s", node.ID(), signal)
} }
return node.Lockable.Process(context, signal)
owner_id := ""
if lockable.owner != nil {
owner_id = lockable.owner.ID().String()
} }
locks_held := map[string]string{} const LISTENER_CHANNEL_BUFFER = 1024
for lockable_id, node := range(lockable.locks_held) { func NewListener(id NodeID, name string) Listener {
if node == nil { return Listener{
locks_held[lockable_id.String()] = "" Lockable: NewLockable(id, name),
} else { Chan: make(chan GraphSignal, LISTENER_CHANNEL_BUFFER),
locks_held[lockable_id.String()] = node.ID().String()
} }
} }
node_json := NewGraphNodeJSON(&lockable.GraphNode) func LoadListener(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) {
var j LockableJSON
return SimpleLockableJSON{ err := json.Unmarshal(data, &j)
GraphNodeJSON: node_json, if err != nil {
Name: lockable.name, return nil, err
Owner: owner_id,
Dependencies: dependency_ids,
Requirements: requirement_ids,
LocksHeld: locks_held,
}
} }
func (lockable * SimpleLockable) Name() string { listener := NewListener(id, j.Name)
return lockable.name nodes[id] = &listener
}
func (lockable * SimpleLockable) RecordUnlock(l Lockable) Lockable { err = RestoreLockable(ctx, &listener.Lockable, j, nodes)
lockable_id := l.ID() if err != nil {
last_owner, exists := lockable.locks_held[lockable_id] return nil, err
if exists == false {
panic("Attempted to take a get the original lock holder of a lockable we don't own")
}
delete(lockable.locks_held, lockable_id)
return last_owner
} }
func (lockable * SimpleLockable) RecordLock(l Lockable, last_owner Lockable) { return &listener, nil
lockable_id := l.ID()
_, exists := lockable.locks_held[lockable_id]
if exists == true {
panic("Attempted to lock a lockable we're already holding(lock cycle)")
} }
lockable.locks_held[lockable_id] = last_owner type LockableNode interface {
Node
LockableHandle() *Lockable
} }
// Nothing can take a lock from a simple lockable // Lockable is a simple Lockable implementation that can be embedded into more complex structures
func (lockable * SimpleLockable) AllowedToTakeLock(l Lockable, new_owner Lockable) bool { type Lockable struct {
return false SimpleNode
Name string
Owner LockableNode
Requirements map[NodeID]LockableNode
Dependencies map[NodeID]LockableNode
LocksHeld map[NodeID]LockableNode
} }
func (lockable * SimpleLockable) Owner() Lockable { func (lockable *Lockable) LockableHandle() *Lockable {
return lockable.owner return lockable
} }
func (lockable * SimpleLockable) SetOwner(owner Lockable) { func (lockable *Lockable) Type() NodeType {
lockable.owner = owner return NodeType("lockable")
} }
func (lockable * SimpleLockable) Requirements() []Lockable { type LockableJSON struct {
return lockable.requirements SimpleNodeJSON
Name string `json:"name"`
Owner string `json:"owner"`
Dependencies []string `json:"dependencies"`
Requirements []string `json:"requirements"`
LocksHeld map[string]string `json:"locks_held"`
} }
func (lockable * SimpleLockable) AddRequirement(requirement Lockable) { func (lockable *Lockable) Serialize() ([]byte, error) {
if requirement == nil { lockable_json := NewLockableJSON(lockable)
panic("Will not connect nil to the DAG") return json.MarshalIndent(&lockable_json, "", " ")
}
lockable.requirements = append(lockable.requirements, requirement)
} }
func (lockable * SimpleLockable) Dependencies() []Lockable { func NewLockableJSON(lockable *Lockable) LockableJSON {
return lockable.dependencies requirement_ids := make([]string, len(lockable.Requirements))
req_n := 0
for id, _ := range(lockable.Requirements) {
requirement_ids[req_n] = id.String()
req_n++
} }
func (lockable * SimpleLockable) AddDependency(dependency Lockable) { dependency_ids := make([]string, len(lockable.Dependencies))
if dependency == nil { dep_n := 0
panic("Will not connect nil to the DAG") for id, _ := range(lockable.Dependencies) {
dependency_ids[dep_n] = id.String()
dep_n++
} }
lockable.dependencies = append(lockable.dependencies, dependency) owner_id := ""
if lockable.Owner != nil {
owner_id = lockable.Owner.ID().String()
} }
func (lockable * SimpleLockable) RemoveDependency(dependency Lockable) { locks_held := map[string]string{}
idx := -1 for lockable_id, node := range(lockable.LocksHeld) {
if node == nil {
for i, dep := range(lockable.dependencies) { locks_held[lockable_id.String()] = ""
if dep.ID() == dependency.ID() { } else {
idx = i locks_held[lockable_id.String()] = node.ID().String()
break
}
} }
if idx == -1 {
panic(fmt.Sprintf("%s is not a dependency of %s", dependency.ID(), lockable.Name()))
} }
dep_len := len(lockable.dependencies) node_json := NewSimpleNodeJSON(&lockable.SimpleNode)
lockable.dependencies[idx] = lockable.dependencies[dep_len-1]
lockable.dependencies = lockable.dependencies[0:(dep_len-1)]
}
func (lockable * SimpleLockable) RemoveRequirement(requirement Lockable) { return LockableJSON{
idx := -1 SimpleNodeJSON: node_json,
for i, req := range(lockable.requirements) { Name: lockable.Name,
if req.ID() == requirement.ID() { Owner: owner_id,
idx = i Dependencies: dependency_ids,
break Requirements: requirement_ids,
LocksHeld: locks_held,
} }
} }
if idx == -1 { func (lockable *Lockable) RecordUnlock(l LockableNode) LockableNode {
panic(fmt.Sprintf("%s is not a requirement of %s", requirement.ID(), lockable.Name())) lockable_id := l.ID()
last_owner, exists := lockable.LocksHeld[lockable_id]
if exists == false {
panic("Attempted to take a get the original lock holder of a lockable we don't own")
} }
delete(lockable.LocksHeld, lockable_id)
req_len := len(lockable.requirements) return last_owner
lockable.requirements[idx] = lockable.requirements[req_len-1]
lockable.requirements = lockable.requirements[0:(req_len-1)]
} }
func (lockable * SimpleLockable) CanLock(new_owner Lockable) error { func (lockable *Lockable) RecordLock(l LockableNode, last_owner LockableNode) {
return nil lockable_id := l.ID()
_, exists := lockable.LocksHeld[lockable_id]
if exists == true {
panic("Attempted to lock a lockable we're already holding(lock cycle)")
} }
func (lockable * SimpleLockable) CanUnlock(new_owner Lockable) error { lockable.LocksHeld[lockable_id] = last_owner
return nil
} }
// Assumed that lockable is already locked for signal // Assumed that lockable is already locked for signal
func (lockable * SimpleLockable) Process(context *StateContext, princ Node, signal GraphSignal) error { func (lockable *Lockable) Process(context *StateContext, signal GraphSignal) error {
err := lockable.GraphNode.Process(context, princ, signal) err := lockable.SimpleNode.Process(context, signal)
if err != nil { if err != nil {
return err return err
} }
@ -227,26 +159,26 @@ func (lockable * SimpleLockable) Process(context *StateContext, princ Node, sign
err = UseStates(context, lockable, err = UseStates(context, lockable,
NewLockInfo(lockable, []string{"dependencies", "owner"}), func(context *StateContext) error { NewLockInfo(lockable, []string{"dependencies", "owner"}), func(context *StateContext) error {
owner_sent := false owner_sent := false
for _, dependency := range(lockable.dependencies) { for _, dependency := range(lockable.Dependencies) {
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID()) context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID())
Signal(context, dependency, lockable, signal) Signal(context, dependency, lockable, signal)
if lockable.owner != nil { if lockable.Owner != nil {
if dependency.ID() == lockable.owner.ID() { if dependency.ID() == lockable.Owner.ID() {
owner_sent = true owner_sent = true
} }
} }
} }
if lockable.owner != nil && owner_sent == false { if lockable.Owner != nil && owner_sent == false {
if lockable.owner.ID() != lockable.ID() { if lockable.Owner.ID() != lockable.ID() {
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID()) context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.Owner.ID())
return Signal(context, lockable.owner, lockable, signal) return Signal(context, lockable.Owner, lockable, signal)
} }
} }
return nil return nil
}) })
case Down: case Down:
err = UseStates(context, lockable, NewLockInfo(lockable, []string{"requirements"}), func(context *StateContext) error { err = UseStates(context, lockable, NewLockInfo(lockable, []string{"requirements"}), func(context *StateContext) error {
for _, requirement := range(lockable.requirements) { for _, requirement := range(lockable.Requirements) {
err := Signal(context, requirement, lockable, signal) err := Signal(context, requirement, lockable, signal)
if err != nil { if err != nil {
return err return err
@ -265,13 +197,13 @@ func (lockable * SimpleLockable) Process(context *StateContext, princ Node, sign
// Removes requirement as a requirement from lockable // Removes requirement as a requirement from lockable
// Continues the write context with princ, getting requirents for lockable and dependencies for requirement // Continues the write context with princ, getting requirents for lockable and dependencies for requirement
// Assumes that an active write context exists with princ locked so that princ's state can be used in checks // Assumes that an active write context exists with princ locked so that princ's state can be used in checks
func UnlinkLockables(context *StateContext, princ Node, lockable Lockable, requirement Lockable) error { func UnlinkLockables(context *StateContext, princ Node, lockable LockableNode, requirement LockableNode) error {
return UpdateStates(context, princ, LockMap{ return UpdateStates(context, princ, LockMap{
lockable.ID(): LockInfo{Node: lockable, Resources: []string{"requirements"}}, lockable.ID(): LockInfo{Node: lockable, Resources: []string{"requirements"}},
requirement.ID(): LockInfo{Node: requirement, Resources: []string{"dependencies"}}, requirement.ID(): LockInfo{Node: requirement, Resources: []string{"dependencies"}},
}, func(context *StateContext) error { }, func(context *StateContext) error {
var found Node = nil var found Node = nil
for _, req := range(lockable.Requirements()) { for _, req := range(lockable.LockableHandle().Requirements) {
if requirement.ID() == req.ID() { if requirement.ID() == req.ID() {
found = req found = req
break break
@ -282,8 +214,8 @@ func UnlinkLockables(context *StateContext, princ Node, lockable Lockable, requi
return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID(), lockable.ID()) return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID(), lockable.ID())
} }
requirement.RemoveDependency(lockable) delete(requirement.LockableHandle().Dependencies, lockable.ID())
lockable.RemoveRequirement(requirement) delete(lockable.LockableHandle().Requirements, requirement.ID())
return nil return nil
}) })
@ -291,10 +223,11 @@ func UnlinkLockables(context *StateContext, princ Node, lockable Lockable, requi
// Link requirements as requirements to lockable // Link requirements as requirements to lockable
// Continues the wrtie context with princ, getting requirements for lockable and dependencies for requirements // Continues the wrtie context with princ, getting requirements for lockable and dependencies for requirements
func LinkLockables(context *StateContext, princ Node, lockable Lockable, requirements []Lockable) error { func LinkLockables(context *StateContext, princ Node, lockable_node LockableNode, requirements []LockableNode) error {
if lockable == nil { if lockable_node == nil {
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements") return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements")
} }
lockable := lockable_node.LockableHandle()
if len(requirements) == 0 { if len(requirements) == 0 {
return nil return nil
@ -323,8 +256,10 @@ func LinkLockables(context *StateContext, princ Node, lockable Lockable, require
), func(context *StateContext) error { ), func(context *StateContext) error {
// Check that all the requirements can be added // Check that all the requirements can be added
// If the lockable is already locked, need to lock this resource as well before we can add it // If the lockable is already locked, need to lock this resource as well before we can add it
for _, requirement := range(requirements) { for _, requirement_node := range(requirements) {
for _, req := range(requirements) { requirement := requirement_node.LockableHandle()
for _, req_node := range(requirements) {
req := req_node.LockableHandle()
if req.ID() == requirement.ID() { if req.ID() == requirement.ID() {
continue continue
} }
@ -339,22 +274,23 @@ func LinkLockables(context *StateContext, princ Node, lockable Lockable, require
if checkIfRequirement(context, requirement, lockable) == true { if checkIfRequirement(context, requirement, lockable) == true {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID()) return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID())
} }
if lockable.Owner() == nil { if lockable.Owner == nil {
// If the new owner isn't locked, we can add the requirement // If the new owner isn't locked, we can add the requirement
} else if requirement.Owner() == nil { } else if requirement.Owner == nil {
// if the new requirement isn't already locked but the owner is, the requirement needs to be locked first // if the new requirement isn't already locked but the owner is, the requirement needs to be locked first
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID(), requirement.ID()) return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID(), requirement.ID())
} else { } else {
// If the new requirement is already locked and the owner is already locked, their owners need to match // If the new requirement is already locked and the owner is already locked, their owners need to match
if requirement.Owner().ID() != lockable.Owner().ID() { if requirement.Owner.ID() != lockable.Owner.ID() {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID(), lockable.ID()) return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID(), lockable.ID())
} }
} }
} }
// Update the states of the requirements // Update the states of the requirements
for _, requirement := range(requirements) { for _, requirement_node := range(requirements) {
requirement.AddDependency(lockable) requirement := requirement_node.LockableHandle()
lockable.AddRequirement(requirement) requirement.Dependencies[lockable.ID()] = lockable_node
lockable.Requirements[lockable.ID()] = requirement_node
context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID(), lockable.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID(), lockable.ID())
} }
@ -364,8 +300,8 @@ func LinkLockables(context *StateContext, princ Node, lockable Lockable, require
} }
// Must be called withing update context // Must be called withing update context
func checkIfRequirement(context *StateContext, r Lockable, cur Lockable) bool { func checkIfRequirement(context *StateContext, r LockableNode, cur LockableNode) bool {
for _, c := range(cur.Requirements()) { for _, c := range(cur.LockableHandle().Requirements) {
if c.ID() == r.ID() { if c.ID() == r.ID() {
return true return true
} }
@ -385,7 +321,7 @@ func checkIfRequirement(context *StateContext, r Lockable, cur Lockable) bool {
// Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error // Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error
// Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call // Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call
func LockLockables(context *StateContext, to_lock []Lockable, new_owner Lockable) error { func LockLockables(context *StateContext, to_lock map[NodeID]LockableNode, new_owner_node LockableNode) error {
if to_lock == nil { if to_lock == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: no list provided") return fmt.Errorf("LOCKABLE_LOCK_ERR: no list provided")
} }
@ -396,44 +332,41 @@ func LockLockables(context *StateContext, to_lock []Lockable, new_owner Lockable
} }
} }
if new_owner == nil { if new_owner_node == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks") return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks")
} }
new_owner := new_owner_node.LockableHandle()
// Called with no requirements to lock, success // Called with no requirements to lock, success
if len(to_lock) == 0 { if len(to_lock) == 0 {
return nil return nil
} }
return UpdateStates(context, new_owner, NewLockMap( return UpdateStates(context, new_owner, NewLockMap(
LockList(to_lock, []string{"lock"}), LockListM(to_lock, []string{"lock"}),
NewLockInfo(new_owner, nil), NewLockInfo(new_owner, nil),
), func(context *StateContext) error { ), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements // First loop is to check that the states can be locked, and locks all requirements
for _, req := range(to_lock) { for _, req_node := range(to_lock) {
req := req_node.LockableHandle()
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID())
// Check custom lock conditions
err := req.CanLock(new_owner)
if err != nil {
return err
}
// If req is alreay locked, check that we can pass the lock // If req is alreay locked, check that we can pass the lock
if req.Owner() != nil { if req.Owner != nil {
owner := req.Owner() owner := req.Owner
if owner.ID() == new_owner.ID() { if owner.ID() == new_owner.ID() {
continue continue
} else { } else {
err := UpdateStates(context, new_owner, NewLockInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){ err := UpdateStates(context, new_owner, NewLockInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){
return LockLockables(context, req.Requirements(), req) return LockLockables(context, req.Requirements, req)
}) })
if err != nil { if err != nil {
return err return err
} }
} }
} else { } else {
err := LockLockables(context, req.Requirements(), req) err := LockLockables(context, req.Requirements, req)
if err != nil { if err != nil {
return err return err
} }
@ -441,19 +374,20 @@ func LockLockables(context *StateContext, to_lock []Lockable, new_owner Lockable
} }
// At this point state modification will be started, so no errors can be returned // At this point state modification will be started, so no errors can be returned
for _, req := range(to_lock) { for _, req_node := range(to_lock) {
old_owner := req.Owner() req := req_node.LockableHandle()
old_owner := req.Owner
// If the lockable was previously unowned, update the state // If the lockable was previously unowned, update the state
if old_owner == nil { if old_owner == nil {
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID(), req.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID(), req.ID())
req.SetOwner(new_owner) req.Owner = new_owner_node
new_owner.RecordLock(req, old_owner) new_owner.RecordLock(req, old_owner)
// Otherwise if the new owner already owns it, no need to update state // Otherwise if the new owner already owns it, no need to update state
} else if old_owner.ID() == new_owner.ID() { } else if old_owner.ID() == new_owner.ID() {
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID(), req.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID(), req.ID())
// Otherwise update the state // Otherwise update the state
} else { } else {
req.SetOwner(new_owner) req.Owner = new_owner
new_owner.RecordLock(req, old_owner) new_owner.RecordLock(req, old_owner)
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID(), req.ID(), old_owner.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID(), req.ID(), old_owner.ID())
} }
@ -463,7 +397,7 @@ func LockLockables(context *StateContext, to_lock []Lockable, new_owner Lockable
} }
func UnlockLockables(context *StateContext, to_unlock []Lockable, old_owner Lockable) error { func UnlockLockables(context *StateContext, to_unlock map[NodeID]LockableNode, old_owner_node LockableNode) error {
if to_unlock == nil { if to_unlock == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided") return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided")
} }
@ -474,48 +408,46 @@ func UnlockLockables(context *StateContext, to_unlock []Lockable, old_owner Lock
} }
} }
if old_owner == nil { if old_owner_node == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks") return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks")
} }
old_owner := old_owner_node.LockableHandle()
// Called with no requirements to unlock, success // Called with no requirements to unlock, success
if len(to_unlock) == 0 { if len(to_unlock) == 0 {
return nil return nil
} }
return UpdateStates(context, old_owner, NewLockMap( return UpdateStates(context, old_owner, NewLockMap(
LockList(to_unlock, []string{"lock"}), LockListM(to_unlock, []string{"lock"}),
NewLockInfo(old_owner, nil), NewLockInfo(old_owner, nil),
), func(context *StateContext) error { ), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements // First loop is to check that the states can be locked, and locks all requirements
for _, req := range(to_unlock) { for _, req_node := range(to_unlock) {
req := req_node.LockableHandle()
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID())
// Check if the owner is correct // Check if the owner is correct
if req.Owner() != nil { if req.Owner != nil {
if req.Owner().ID() != old_owner.ID() { if req.Owner.ID() != old_owner.ID() {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID(), old_owner.ID()) return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID(), old_owner.ID())
} }
} else { } else {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID()) return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID())
} }
// Check custom unlock conditions err := UnlockLockables(context, req.Requirements, req)
err := req.CanUnlock(old_owner)
if err != nil {
return err
}
err = UnlockLockables(context, req.Requirements(), req)
if err != nil { if err != nil {
return err return err
} }
} }
// At this point state modification will be started, so no errors can be returned // At this point state modification will be started, so no errors can be returned
for _, req := range(to_unlock) { for _, req_node := range(to_unlock) {
req := req_node.LockableHandle()
new_owner := old_owner.RecordUnlock(req) new_owner := old_owner.RecordUnlock(req)
req.SetOwner(new_owner) req.Owner = new_owner
if new_owner == nil { if new_owner == nil {
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID(), req.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID(), req.ID())
} else { } else {
@ -527,18 +459,18 @@ func UnlockLockables(context *StateContext, to_unlock []Lockable, old_owner Lock
}) })
} }
// Load function for SimpleLockable // Load function for Lockable
func LoadSimpleLockable(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { func LoadLockable(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) {
var j SimpleLockableJSON var j LockableJSON
err := json.Unmarshal(data, &j) err := json.Unmarshal(data, &j)
if err != nil { if err != nil {
return nil, err return nil, err
} }
lockable := NewSimpleLockable(id, j.Name) lockable := NewLockable(id, j.Name)
nodes[id] = &lockable nodes[id] = &lockable
err = RestoreSimpleLockable(ctx, &lockable, j, nodes) err = RestoreLockable(ctx, &lockable, j, nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -546,19 +478,19 @@ func LoadSimpleLockable(ctx *Context, id NodeID, data []byte, nodes NodeMap) (No
return &lockable, nil return &lockable, nil
} }
func NewSimpleLockable(id NodeID, name string) SimpleLockable { func NewLockable(id NodeID, name string) Lockable {
return SimpleLockable{ return Lockable{
GraphNode: NewGraphNode(id), SimpleNode: NewSimpleNode(id),
name: name, Name: name,
owner: nil, Owner: nil,
requirements: []Lockable{}, Requirements: map[NodeID]LockableNode{},
dependencies: []Lockable{}, Dependencies: map[NodeID]LockableNode{},
locks_held: map[NodeID]Lockable{}, LocksHeld: map[NodeID]LockableNode{},
} }
} }
// Helper function to load links when loading a struct that embeds SimpleLockable // Helper function to load links when loading a struct that embeds Lockable
func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSON, nodes NodeMap) error { func RestoreLockable(ctx * Context, lockable *Lockable, j LockableJSON, nodes NodeMap) error {
if j.Owner != "" { if j.Owner != "" {
owner_id, err := ParseID(j.Owner) owner_id, err := ParseID(j.Owner)
if err != nil { if err != nil {
@ -568,11 +500,11 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO
if err != nil { if err != nil {
return err return err
} }
owner, ok := owner_node.(Lockable) owner, ok := owner_node.(LockableNode)
if ok == false { if ok == false {
return fmt.Errorf("%s is not a Lockable", j.Owner) return fmt.Errorf("%s is not a Lockable", j.Owner)
} }
lockable.SetOwner(owner) lockable.Owner = owner
} }
for _, dep_str := range(j.Dependencies) { for _, dep_str := range(j.Dependencies) {
@ -584,11 +516,11 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO
if err != nil { if err != nil {
return err return err
} }
dep, ok := dep_node.(Lockable) dep, ok := dep_node.(LockableNode)
if ok == false { if ok == false {
return fmt.Errorf("%+v is not a Lockable as expected", dep_node) return fmt.Errorf("%+v is not a Lockable as expected", dep_node)
} }
lockable.AddDependency(dep) lockable.Dependencies[dep_id] = dep
} }
for _, req_str := range(j.Requirements) { for _, req_str := range(j.Requirements) {
@ -600,11 +532,11 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO
if err != nil { if err != nil {
return err return err
} }
req, ok := req_node.(Lockable) req, ok := req_node.(LockableNode)
if ok == false { if ok == false {
return fmt.Errorf("%+v is not a Lockable as expected", req_node) return fmt.Errorf("%+v is not a Lockable as expected", req_node)
} }
lockable.AddRequirement(req) lockable.Requirements[req_id] = req
} }
for l_id_str, h_str := range(j.LocksHeld) { for l_id_str, h_str := range(j.LocksHeld) {
@ -613,12 +545,12 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO
if err != nil { if err != nil {
return err return err
} }
l_l, ok := l.(Lockable) l_l, ok := l.(LockableNode)
if ok == false { if ok == false {
return fmt.Errorf("%s is not a Lockable", l.ID()) return fmt.Errorf("%s is not a Lockable", l.ID())
} }
var h_l Lockable = nil var h_l LockableNode
if h_str != "" { if h_str != "" {
h_id, err := ParseID(h_str) h_id, err := ParseID(h_str)
if err != nil { if err != nil {
@ -628,7 +560,7 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO
if err != nil { if err != nil {
return err return err
} }
h, ok := h_node.(Lockable) h, ok := h_node.(LockableNode)
if ok == false { if ok == false {
return err return err
} }
@ -637,5 +569,5 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO
lockable.RecordLock(l_l, h_l) lockable.RecordLock(l_l, h_l)
} }
return RestoreGraphNode(ctx, lockable, j.GraphNodeJSON, nodes) return RestoreSimpleNode(ctx, &lockable.SimpleNode, j.SimpleNodeJSON, nodes)
} }

@ -44,7 +44,7 @@ func KeyID(pub *ecdsa.PublicKey) NodeID {
// Types are how nodes are associated with structs at runtime(and from the DB) // Types are how nodes are associated with structs at runtime(and from the DB)
type NodeType string type NodeType string
func (node_type NodeType) Hash() uint64 { func (node_type NodeType) Hash() uint64 {
hash := sha512.Sum512([]byte(node_type)) hash := sha512.Sum512([]byte(fmt.Sprintf("NODE: %s", node_type)))
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)]) return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
} }
@ -54,120 +54,87 @@ func RandID() NodeID {
return NodeID(uuid.New()) return NodeID(uuid.New())
} }
// A Node represents data that can be read by multiple goroutines and written to by one, with a unique ID attached, and a method to process updates(including propagating them to connected nodes)
// RegisterChannel and UnregisterChannel are used to connect arbitrary listeners to the node
type Node interface { type Node interface {
// State Locking interface
sync.Locker
RLock()
RUnlock()
// Serialize the Node for the database
Serialize() ([]byte, error)
// Nodes have an ID, type, and ACL policies
ID() NodeID ID() NodeID
Type() NodeType Type() NodeType
Serialize() ([]byte, error)
Policies() map[NodeID]Policy LockState(write bool)
AddPolicy(Policy) error UnlockState(write bool)
RemovePolicy(Policy) error Process(context *StateContext, signal GraphSignal) error
Policies() []Policy
// Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate
Process(context *StateContext, princ Node, signal GraphSignal) error
// Register a channel to receive updates sent to the node
RegisterChannel(id NodeID, listener chan GraphSignal)
// Unregister a channel from receiving updates sent to the node
UnregisterChannel(id NodeID)
} }
// A GraphNode is an implementation of a Node that can be embedded into more complex structures type SimpleNode struct {
type GraphNode struct {
sync.RWMutex
listeners_lock sync.Mutex
id NodeID id NodeID
listeners map[NodeID]chan GraphSignal state_mutex sync.RWMutex
policies map[NodeID]Policy policies map[NodeID]Policy
} }
type GraphNodeJSON struct { func NewSimpleNode(id NodeID) SimpleNode {
Policies []string `json:"policies"` return SimpleNode{
id: id,
policies: map[NodeID]Policy{},
} }
func (node * GraphNode) Policies() map[NodeID]Policy {
return node.policies
} }
func (node * GraphNode) Serialize() ([]byte, error) { type SimpleNodeJSON struct {
node_json := NewGraphNodeJSON(node) Policies []string `json:"policies"`
return json.MarshalIndent(&node_json, "", " ")
} }
func Allowed(context *StateContext, policies map[NodeID]Policy, node Node, resource string, action string, princ Node) error { func (node *SimpleNode) Process(context *StateContext, signal GraphSignal) error {
if princ == nil { context.Graph.Log.Logf("signal", "SIMPLE_NODE_SIGNAL: %s - %+v", node.id, signal)
context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return fmt.Errorf("nil is not allowed to perform any actions")
}
if node.ID() == princ.ID() {
return nil return nil
} }
for _, policy := range(policies) {
if policy.Allows(node, resource, action, princ) == true {
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return nil
}
}
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return fmt.Errorf("%s is not allowed to perform %s.%s on %s", princ.ID(), resource, action, node.ID())
}
func (node *GraphNode) AddPolicy(policy Policy) error { func (node *SimpleNode) ID() NodeID {
if policy == nil { return node.id
return fmt.Errorf("Cannot add nil as a policy")
} }
_, exists := node.policies[policy.ID()] func (node *SimpleNode) Type() NodeType {
if exists == true { return NodeType("simple_node")
return fmt.Errorf("%s is already a policy for %s", policy.ID().String(), node.ID().String())
} }
node.policies[policy.ID()] = policy func (node *SimpleNode) Serialize() ([]byte, error) {
return nil j := NewSimpleNodeJSON(node)
return json.MarshalIndent(&j, "", " ")
} }
func (node *GraphNode) RemovePolicy(policy Policy) error { func (node *SimpleNode) LockState(write bool) {
if policy == nil { if write == true {
return fmt.Errorf("Cannot add nil as a policy") node.state_mutex.Lock()
} else {
node.state_mutex.RLock()
} }
_, exists := node.policies[policy.ID()]
if exists == false {
return fmt.Errorf("%s is not a policy for %s", policy.ID().String(), node.ID().String())
} }
delete(node.policies, policy.ID()) func (node *SimpleNode) UnlockState(write bool) {
return nil if write == true {
node.state_mutex.Unlock()
} else {
node.state_mutex.RUnlock()
}
} }
func NewGraphNodeJSON(node *GraphNode) GraphNodeJSON { func NewSimpleNodeJSON(node *SimpleNode) SimpleNodeJSON {
policies := make([]string, len(node.policies)) policy_ids := make([]string, len(node.policies))
i := 0 i := 0
for _, policy := range(node.policies) { for id, _ := range(node.policies) {
policies[i] = policy.ID().String() policy_ids[i] = id.String()
i += 1 i += 1
} }
return GraphNodeJSON{
Policies: policies, return SimpleNodeJSON{
Policies: policy_ids,
} }
} }
func RestoreGraphNode(ctx *Context, node Node, j GraphNodeJSON, nodes NodeMap) error { func RestoreSimpleNode(ctx *Context, node *SimpleNode, j SimpleNodeJSON, nodes NodeMap) error {
for _, policy_str := range(j.Policies) { for _, policy_str := range(j.Policies) {
policy_id, err := ParseID(policy_str) policy_id, err := ParseID(policy_str)
if err != nil { if err != nil {
return err return err
} }
policy_ptr, err := LoadNodeRecurse(ctx, policy_id, nodes) policy_ptr, err := LoadNodeRecurse(ctx, policy_id, nodes)
if err != nil { if err != nil {
return err return err
@ -177,26 +144,59 @@ func RestoreGraphNode(ctx *Context, node Node, j GraphNodeJSON, nodes NodeMap) e
if ok == false { if ok == false {
return fmt.Errorf("%s is not a Policy", policy_id) return fmt.Errorf("%s is not a Policy", policy_id)
} }
node.AddPolicy(policy) node.policies[policy_id] = policy
} }
return nil return nil
} }
func LoadGraphNode(ctx * Context, id NodeID, data []byte, nodes NodeMap)(Node, error) { func LoadSimpleNode(ctx *Context, id NodeID, data []byte, nodes NodeMap)(Node, error) {
if len(data) > 0 { var j SimpleNodeJSON
return nil, fmt.Errorf("Attempted to load a graph_node with data %+v, should have been 0 length", string(data)) err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
} }
node := NewGraphNode(id)
node := NewSimpleNode(id)
nodes[id] = &node
err = RestoreSimpleNode(ctx, &node, j, nodes)
if err != nil {
return nil, err
}
return &node, nil return &node, nil
} }
func (node * GraphNode) ID() NodeID { func (node *SimpleNode) Policies() []Policy {
return node.id ret := make([]Policy, len(node.policies))
i := 0
for _, policy := range(node.policies) {
ret[i] = policy
i += 1
}
return ret
} }
func (node * GraphNode) Type() NodeType { func Allowed(context *StateContext, policies []Policy, node Node, resource string, action string, princ Node) error {
return NodeType("graph_node") if princ == nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return fmt.Errorf("nil is not allowed to perform any actions")
}
if node.ID() == princ.ID() {
return nil
}
for _, policy := range(policies) {
if policy.Allows(node, resource, action, princ) == true {
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return nil
} }
}
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return fmt.Errorf("%s is not allowed to perform %s.%s on %s", princ.ID(), resource, action, node.ID())
}
// Propagate the signal to registered listeners, if a listener isn't ready to receive the update // Propagate the signal to registered listeners, if a listener isn't ready to receive the update
// send it a notification that it was closed and then close it // send it a notification that it was closed and then close it
@ -211,75 +211,19 @@ func Signal(context *StateContext, node Node, princ Node, signal GraphSignal) er
return nil return nil
} }
return node.Process(context, princ, signal) return node.Process(context, signal)
}
func (node * GraphNode) Process(context *StateContext, princ Node, signal GraphSignal) error {
node.listeners_lock.Lock()
defer node.listeners_lock.Unlock()
closed := []NodeID{}
for id, listener := range node.listeners {
context.Graph.Log.Logf("signal", "UPDATE_LISTENER %s: %s", node.ID(), id)
select {
case listener <- signal:
default:
context.Graph.Log.Logf("signal", "CLOSED_LISTENER %s: %s", node.ID(), id)
go func(node Node, listener chan GraphSignal) {
listener <- NewDirectSignal("listener_closed")
close(listener)
}(node, listener)
closed = append(closed, id)
}
}
for _, id := range(closed) {
delete(node.listeners, id)
}
return nil
}
func (node * GraphNode) RegisterChannel(id NodeID, listener chan GraphSignal) {
node.listeners_lock.Lock()
_, exists := node.listeners[id]
if exists == false {
node.listeners[id] = listener
}
node.listeners_lock.Unlock()
}
func (node * GraphNode) UnregisterChannel(id NodeID) {
node.listeners_lock.Lock()
_, exists := node.listeners[id]
if exists == false {
panic("Attempting to unregister non-registered listener")
} else {
delete(node.listeners, id)
}
node.listeners_lock.Unlock()
} }
func AttachPolicies(ctx *Context, node Node, policies ...Policy) error { func AttachPolicies(ctx *Context, node *SimpleNode, policies ...Policy) error {
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
return UpdateStates(context, node, NewLockInfo(node, []string{"policies"}), func(context *StateContext) error { return UpdateStates(context, node, NewLockInfo(node, []string{"policies"}), func(context *StateContext) error {
for _, policy := range(policies) { for _, policy := range(policies) {
err := node.AddPolicy(policy) node.policies[policy.ID()] = policy
if err != nil {
return err
}
} }
return nil return nil
}) })
} }
func NewGraphNode(id NodeID) GraphNode {
return GraphNode{
id: id,
listeners: map[NodeID]chan GraphSignal{},
policies: map[NodeID]Policy{},
}
}
// Magic first four bytes of serialized DB content, stored big endian // Magic first four bytes of serialized DB content, stored big endian
const NODE_DB_MAGIC = 0x2491df14 const NODE_DB_MAGIC = 0x2491df14
// Total length of the node database header, has magic to verify and type_hash to map to load function // Total length of the node database header, has magic to verify and type_hash to map to load function
@ -458,6 +402,17 @@ func NewLockMap(requests ...LockMap) LockMap {
return reqs return reqs
} }
func LockListM[K Node](m map[NodeID]K, resources[]string) LockMap {
reqs := LockMap{}
for _, node := range(m) {
reqs[node.ID()] = LockInfo{
Node: node,
Resources: resources,
}
}
return reqs
}
func LockList[K Node](list []K, resources []string) LockMap { func LockList[K Node](list []K, resources []string) LockMap {
reqs := LockMap{} reqs := LockMap{}
for _, node := range(list) { for _, node := range(list) {
@ -565,7 +520,7 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
if princ_locked == false { if princ_locked == false {
new_locks = append(new_locks, princ) new_locks = append(new_locks, princ)
context.Graph.Log.Logf("mutex", "RLOCKING_PRINC %s", princ.ID().String()) context.Graph.Log.Logf("mutex", "RLOCKING_PRINC %s", princ.ID().String())
princ.RLock() princ.LockState(false)
} }
princ_permissions, princ_exists := context.Permissions[princ.ID()] princ_permissions, princ_exists := context.Permissions[princ.ID()]
@ -588,7 +543,7 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
if locked == false { if locked == false {
new_locks = append(new_locks, node) new_locks = append(new_locks, node)
context.Graph.Log.Logf("mutex", "RLOCKING %s", id.String()) context.Graph.Log.Logf("mutex", "RLOCKING %s", id.String())
node.RLock() node.LockState(false)
} }
} }
@ -610,7 +565,7 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
if err != nil { if err != nil {
for _, n := range(new_locks) { for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String()) context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String())
n.RUnlock() n.UnlockState(false)
} }
return err return err
} }
@ -632,7 +587,7 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
for _, node := range(new_locks) { for _, node := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING %s", node.ID().String()) context.Graph.Log.Logf("mutex", "RUNLOCKING %s", node.ID().String())
delete(context.Locked, node.ID()) delete(context.Locked, node.ID())
node.RUnlock() node.UnlockState(false)
} }
return err return err
@ -661,7 +616,7 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
if princ_locked == false { if princ_locked == false {
new_locks = append(new_locks, princ) new_locks = append(new_locks, princ)
context.Graph.Log.Logf("mutex", "LOCKING_PRINC %s", princ.ID().String()) context.Graph.Log.Logf("mutex", "LOCKING_PRINC %s", princ.ID().String())
princ.Lock() princ.LockState(true)
} }
princ_permissions, princ_exists := context.Permissions[princ.ID()] princ_permissions, princ_exists := context.Permissions[princ.ID()]
@ -684,7 +639,7 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
if locked == false { if locked == false {
new_locks = append(new_locks, node) new_locks = append(new_locks, node)
context.Graph.Log.Logf("mutex", "LOCKING %s", id.String()) context.Graph.Log.Logf("mutex", "LOCKING %s", id.String())
node.Lock() node.LockState(true)
} }
} }
@ -706,7 +661,7 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
if err != nil { if err != nil {
for _, n := range(new_locks) { for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String()) context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String())
n.Unlock() n.UnlockState(true)
} }
return err return err
} }
@ -730,19 +685,10 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
} }
for id, node := range(context.Locked) { for id, node := range(context.Locked) {
context.Graph.Log.Logf("mutex", "UNLOCKING %s", id.String()) context.Graph.Log.Logf("mutex", "UNLOCKING %s", id.String())
node.Unlock() node.UnlockState(true)
} }
} }
return err return err
} }
// Create a new channel with a buffer the size of buffer, and register it to node with the id
func UpdateChannel(node Node, buffer int, id NodeID) chan GraphSignal {
if node == nil {
panic("Cannot get an update channel to nil")
}
new_listener := make(chan GraphSignal, buffer)
node.RegisterChannel(id, new_listener)
return new_listener
}

@ -44,12 +44,12 @@ func NewNodeActions(resource_actions NodeActions, wildcard_actions []string) Nod
} }
type PerNodePolicy struct { type PerNodePolicy struct {
GraphNode SimpleNode
Actions map[NodeID]NodeActions Actions map[NodeID]NodeActions
} }
type PerNodePolicyJSON struct { type PerNodePolicyJSON struct {
GraphNodeJSON SimpleNodeJSON
Actions map[string]map[string][]string `json:"actions"` Actions map[string]map[string][]string `json:"actions"`
} }
@ -64,7 +64,7 @@ func (policy *PerNodePolicy) Serialize() ([]byte, error) {
} }
return json.MarshalIndent(&PerNodePolicyJSON{ return json.MarshalIndent(&PerNodePolicyJSON{
GraphNodeJSON: NewGraphNodeJSON(&policy.GraphNode), SimpleNodeJSON: NewSimpleNodeJSON(&policy.SimpleNode),
Actions: actions, Actions: actions,
}, "", " ") }, "", " ")
} }
@ -75,7 +75,7 @@ func NewPerNodePolicy(id NodeID, actions map[NodeID]NodeActions) PerNodePolicy {
} }
return PerNodePolicy{ return PerNodePolicy{
GraphNode: NewGraphNode(id), SimpleNode: NewSimpleNode(id),
Actions: actions, Actions: actions,
} }
} }
@ -100,7 +100,7 @@ func LoadPerNodePolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Nod
policy := NewPerNodePolicy(id, actions) policy := NewPerNodePolicy(id, actions)
nodes[id] = &policy nodes[id] = &policy
err = RestoreGraphNode(ctx, &policy.GraphNode, j.GraphNodeJSON, nodes) err = RestoreSimpleNode(ctx, &policy.SimpleNode, j.SimpleNodeJSON, nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -122,12 +122,12 @@ func (policy *PerNodePolicy) Allows(node Node, resource string, action string, p
} }
type SimplePolicy struct { type SimplePolicy struct {
GraphNode SimpleNode
Actions NodeActions Actions NodeActions
} }
type SimplePolicyJSON struct { type SimplePolicyJSON struct {
GraphNodeJSON SimpleNodeJSON
Actions map[string][]string `json:"actions"` Actions map[string][]string `json:"actions"`
} }
@ -137,7 +137,7 @@ func (policy *SimplePolicy) Type() NodeType {
func (policy *SimplePolicy) Serialize() ([]byte, error) { func (policy *SimplePolicy) Serialize() ([]byte, error) {
return json.MarshalIndent(&SimplePolicyJSON{ return json.MarshalIndent(&SimplePolicyJSON{
GraphNodeJSON: NewGraphNodeJSON(&policy.GraphNode), SimpleNodeJSON: NewSimpleNodeJSON(&policy.SimpleNode),
Actions: policy.Actions, Actions: policy.Actions,
}, "", " ") }, "", " ")
} }
@ -148,7 +148,7 @@ func NewSimplePolicy(id NodeID, actions NodeActions) SimplePolicy {
} }
return SimplePolicy{ return SimplePolicy{
GraphNode: NewGraphNode(id), SimpleNode: NewSimpleNode(id),
Actions: actions, Actions: actions,
} }
} }
@ -163,7 +163,7 @@ func LoadSimplePolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node
policy := NewSimplePolicy(id, j.Actions) policy := NewSimplePolicy(id, j.Actions)
nodes[id] = &policy nodes[id] = &policy
err = RestoreGraphNode(ctx, &policy.GraphNode, j.GraphNodeJSON, nodes) err = RestoreSimpleNode(ctx, &policy.SimpleNode, j.SimpleNodeJSON, nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -176,12 +176,12 @@ func (policy *SimplePolicy) Allows(node Node, resource string, action string, pr
} }
type PerTagPolicy struct { type PerTagPolicy struct {
GraphNode SimpleNode
Actions map[string]NodeActions Actions map[string]NodeActions
} }
type PerTagPolicyJSON struct { type PerTagPolicyJSON struct {
GraphNodeJSON SimpleNodeJSON
Actions map[string]map[string][]string `json:"json"` Actions map[string]map[string][]string `json:"json"`
} }
@ -196,7 +196,7 @@ func (policy *PerTagPolicy) Serialize() ([]byte, error) {
} }
return json.MarshalIndent(&PerTagPolicyJSON{ return json.MarshalIndent(&PerTagPolicyJSON{
GraphNodeJSON: NewGraphNodeJSON(&policy.GraphNode), SimpleNodeJSON: NewSimpleNodeJSON(&policy.SimpleNode),
Actions: actions, Actions: actions,
}, "", " ") }, "", " ")
} }
@ -207,7 +207,7 @@ func NewPerTagPolicy(id NodeID, actions map[string]NodeActions) PerTagPolicy {
} }
return PerTagPolicy{ return PerTagPolicy{
GraphNode: NewGraphNode(id), SimpleNode: NewSimpleNode(id),
Actions: actions, Actions: actions,
} }
} }
@ -227,7 +227,7 @@ func LoadPerTagPolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node
policy := NewPerTagPolicy(id, actions) policy := NewPerTagPolicy(id, actions)
nodes[id] = &policy nodes[id] = &policy
err = RestoreGraphNode(ctx, &policy.GraphNode, j.GraphNodeJSON, nodes) err = RestoreSimpleNode(ctx, &policy.SimpleNode, j.SimpleNodeJSON, nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -268,12 +268,12 @@ func NewDependencyPolicy(id NodeID, actions NodeActions) DependencyPolicy {
} }
func (policy *DependencyPolicy) Allows(node Node, resource string, action string, principal Node) bool { func (policy *DependencyPolicy) Allows(node Node, resource string, action string, principal Node) bool {
lockable, ok := node.(Lockable) lockable, ok := node.(LockableNode)
if ok == false { if ok == false {
return false return false
} }
for _, dep := range(lockable.Dependencies()) { for _, dep := range(lockable.LockableHandle().Dependencies) {
if dep.ID() == principal.ID() { if dep.ID() == principal.ID() {
return policy.Actions.Allows(resource, action) return policy.Actions.Allows(resource, action)
} }

@ -5,13 +5,12 @@ import (
"time" "time"
"sync" "sync"
"errors" "errors"
"reflect"
"encoding/json" "encoding/json"
) )
// Assumed that thread is already locked for signal // Assumed that thread is already locked for signal
func (thread *SimpleThread) Process(context *StateContext, princ Node, signal GraphSignal) error { func (thread *Thread) Process(context *StateContext, signal GraphSignal) error {
err := thread.SimpleLockable.Process(context, princ, signal) err := thread.Lockable.Process(context, signal)
if err != nil { if err != nil {
return err return err
} }
@ -19,16 +18,16 @@ func (thread *SimpleThread) Process(context *StateContext, princ Node, signal Gr
switch signal.Direction() { switch signal.Direction() {
case Up: case Up:
err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error { err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error {
if thread.parent != nil { if thread.Parent != nil {
return Signal(context, thread.parent, thread, signal) return Signal(context, thread.Parent, thread, signal)
} else { } else {
return nil return nil
} }
}) })
case Down: case Down:
err = UseStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error { err = UseStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error {
for _, child := range(thread.children) { for _, info := range(thread.Children) {
err := Signal(context, child, thread, signal) err := Signal(context, info.Child, thread, signal)
if err != nil { if err != nil {
return err return err
} }
@ -44,136 +43,35 @@ func (thread *SimpleThread) Process(context *StateContext, princ Node, signal Gr
return err return err
} }
thread.signal <- signal thread.Chan <- signal
return nil return nil
} }
// Interface to represent any type of thread information
type ThreadInfo interface {
}
func (thread * SimpleThread) SetTimeout(timeout time.Time, action string) {
thread.timeout = timeout
thread.timeout_action = action
thread.timeout_chan = time.After(time.Until(timeout))
}
func (thread * SimpleThread) TimeoutAction() string {
return thread.timeout_action
}
func (thread * SimpleThread) State() string {
return thread.state_name
}
func (thread * SimpleThread) SetState(new_state string) error {
if new_state == "" {
return fmt.Errorf("Cannot set state to '' with SetState")
}
thread.state_name = new_state
return nil
}
func (thread * SimpleThread) Parent() Thread {
return thread.parent
}
func (thread * SimpleThread) SetParent(parent Thread) {
thread.parent = parent
}
func (thread * SimpleThread) Children() []Thread {
return thread.children
}
func (thread * SimpleThread) Child(id NodeID) Thread {
for _, child := range(thread.children) {
if child.ID() == id {
return child
}
}
return nil
}
func (thread * SimpleThread) ChildInfo(child NodeID) ThreadInfo {
return thread.child_info[child]
}
// Requires thread and childs thread to be locked for write // Requires thread and childs thread to be locked for write
func UnlinkThreads(ctx * Context, thread Thread, child Thread) error { func UnlinkThreads(ctx * Context, node ThreadNode, child_node ThreadNode) error {
var found Node = nil thread := node.ThreadHandle()
for _, c := range(thread.Children()) { child := child_node.ThreadHandle()
if child.ID() == c.ID() { _, is_child := thread.Children[child_node.ID()]
found = c if is_child == false {
break
}
}
if found == nil {
return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID(), thread.ID()) return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID(), thread.ID())
} }
child.SetParent(nil) child.Parent = nil
thread.RemoveChild(child) delete(thread.Children, child.ID())
return nil return nil
} }
func (thread * SimpleThread) RemoveChild(child Thread) { func checkIfChild(context *StateContext, target ThreadNode, cur ThreadNode) bool {
idx := -1 for _, info := range(cur.ThreadHandle().Children) {
for i, c := range(thread.children) { if info.Child.ID() == target.ID() {
if c.ID() == child.ID() {
idx = i
break
}
}
if idx == -1 {
panic(fmt.Sprintf("%s is not a child of %s", child.ID(), thread.Name()))
}
child_len := len(thread.children)
thread.children[idx] = thread.children[child_len-1]
thread.children = thread.children[0:child_len-1]
}
func (thread * SimpleThread) AddChild(child Thread, info ThreadInfo) error {
if child == nil {
return fmt.Errorf("Will not connect nil to the thread tree")
}
_, exists := thread.child_info[child.ID()]
if exists == true {
return fmt.Errorf("Will not connect the same child twice")
}
if info == nil && thread.InfoType != nil {
return fmt.Errorf("nil info passed when expecting info")
} else if info != nil {
if reflect.TypeOf(info) != thread.InfoType {
return fmt.Errorf("info type mismatch, expecting %+v - %+v", thread.InfoType, reflect.TypeOf(info))
}
}
thread.children = append(thread.children, child)
thread.child_info[child.ID()] = info
return nil
}
func checkIfChild(context *StateContext, target Thread, cur Thread) bool {
for _, child := range(cur.Children()) {
if child.ID() == target.ID() {
return true return true
} }
is_child := false is_child := false
UpdateStates(context, cur, NewLockMap( UpdateStates(context, cur, NewLockMap(
NewLockInfo(child, []string{"children"}), NewLockInfo(info.Child, []string{"children"}),
), func(context *StateContext) error { ), func(context *StateContext) error {
is_child = checkIfChild(context, target, child) is_child = checkIfChild(context, target, info.Child)
return nil return nil
}) })
if is_child { if is_child {
@ -186,10 +84,12 @@ func checkIfChild(context *StateContext, target Thread, cur Thread) bool {
// Links child to parent with info as the associated info // Links child to parent with info as the associated info
// Continues the write context with princ, getting children for thread and parent for child // Continues the write context with princ, getting children for thread and parent for child
func LinkThreads(context *StateContext, princ Node, thread Thread, child Thread, info ThreadInfo) error { func LinkThreads(context *StateContext, princ Node, thread_node ThreadNode, info ChildInfo) error {
if context == nil || thread == nil || child == nil { if context == nil || thread_node == nil || info.Child == nil {
return fmt.Errorf("invalid input") return fmt.Errorf("invalid input")
} }
thread := thread_node.ThreadHandle()
child := info.Child.ThreadHandle()
if thread.ID() == child.ID() { if thread.ID() == child.ID() {
return fmt.Errorf("Will not link %s as a child of itself", thread.ID()) return fmt.Errorf("Will not link %s as a child of itself", thread.ID())
@ -199,7 +99,7 @@ func LinkThreads(context *StateContext, princ Node, thread Thread, child Thread,
child.ID(): LockInfo{Node: child, Resources: []string{"parent"}}, child.ID(): LockInfo{Node: child, Resources: []string{"parent"}},
thread.ID(): LockInfo{Node: thread, Resources: []string{"children"}}, thread.ID(): LockInfo{Node: thread, Resources: []string{"children"}},
}, func(context *StateContext) error { }, func(context *StateContext) error {
if child.Parent() != nil { if child.Parent != nil {
return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID()) return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID())
} }
@ -211,60 +111,23 @@ func LinkThreads(context *StateContext, princ Node, thread Thread, child Thread,
return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID()) return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID())
} }
err := thread.AddChild(child, info) // TODO check for info types
if err != nil {
return fmt.Errorf("EVENT_LINK_ERR: error adding %s as child to %s: %e", child.ID(), thread.ID(), err)
}
child.SetParent(thread)
if err != nil { thread.Children[child.ID()] = info
return err child.Parent = thread_node
}
return nil return nil
}) })
} }
type ThreadAction func(* Context, Thread)(string, error) type ThreadAction func(*Context, ThreadNode)(string, error)
type ThreadActions map[string]ThreadAction type ThreadActions map[string]ThreadAction
type ThreadHandler func(* Context, Thread, GraphSignal)(string, error) type ThreadHandler func(*Context, ThreadNode, GraphSignal)(string, error)
type ThreadHandlers map[string]ThreadHandler type ThreadHandlers map[string]ThreadHandler
type Thread interface { type InfoType string
// All Threads are Lockables func (t InfoType) String() string {
Lockable return string(t)
/// State Modification Functions
SetParent(parent Thread)
AddChild(child Thread, info ThreadInfo) error
RemoveChild(child Thread)
SetState(new_thread string) error
SetTimeout(end_time time.Time, action string)
/// State Reading Functions
Parent() Thread
Children() []Thread
Child(id NodeID) Thread
ChildInfo(child NodeID) ThreadInfo
State() string
TimeoutAction() string
/// Functions that dont read/write thread
// Deserialize the attribute map from json.Unmarshal
DeserializeInfo(ctx *Context, data []byte) (ThreadInfo, error)
SetActive(active bool) error
Action(action string) (ThreadAction, bool)
Handler(signal_type string) (ThreadHandler, bool)
// Internal timeout channel for thread
Timeout() <-chan time.Time
// Internal signal channel for thread
SignalChannel() <-chan GraphSignal
ClearTimeout()
ChildWaits() *sync.WaitGroup
}
type ParentInfo interface {
Parent() *ParentThreadInfo
} }
// Data required by a parent thread to restore it's children // Data required by a parent thread to restore it's children
@ -274,10 +137,6 @@ type ParentThreadInfo struct {
RestoreAction string `json:"restore_action"` RestoreAction string `json:"restore_action"`
} }
func (info * ParentThreadInfo) Parent() *ParentThreadInfo{
return info
}
func NewParentThreadInfo(start bool, start_action string, restore_action string) ParentThreadInfo { func NewParentThreadInfo(start bool, start_action string, restore_action string) ParentThreadInfo {
return ParentThreadInfo{ return ParentThreadInfo{
Start: start, Start: start,
@ -286,83 +145,123 @@ func NewParentThreadInfo(start bool, start_action string, restore_action string)
} }
} }
type SimpleThread struct { type ChildInfo struct {
SimpleLockable Child ThreadNode
Infos map[InfoType]interface{}
}
actions ThreadActions func NewChildInfo(child ThreadNode, infos map[InfoType]interface{}) ChildInfo {
handlers ThreadHandlers if infos == nil {
infos = map[InfoType]interface{}{}
}
timeout_chan <-chan time.Time return ChildInfo{
signal chan GraphSignal Child: child,
child_waits *sync.WaitGroup Infos: infos,
active bool }
active_lock *sync.Mutex }
state_name string type QueuedAction struct {
parent Thread Timeout time.Time
children []Thread Action string
child_info map[NodeID] ThreadInfo
InfoType reflect.Type
timeout time.Time
timeout_action string
} }
func (thread * SimpleThread) Type() NodeType { type ThreadNode interface {
LockableNode
ThreadHandle() *Thread
}
type Thread struct {
Lockable
Actions ThreadActions
Handlers ThreadHandlers
TimeoutChan <-chan time.Time
Chan chan GraphSignal
ChildWaits sync.WaitGroup
Active bool
ActiveLock sync.Mutex
StateName string
Parent ThreadNode
Children map[NodeID]ChildInfo
InfoTypes []InfoType
TimeoutAction string
Timeout time.Time
}
func (thread *Thread) ThreadHandle() *Thread {
return thread
}
func (thread *Thread) Type() NodeType {
return NodeType("simple_thread") return NodeType("simple_thread")
} }
func (thread * SimpleThread) Serialize() ([]byte, error) { func (thread *Thread) Serialize() ([]byte, error) {
thread_json := NewSimpleThreadJSON(thread) thread_json := NewThreadJSON(thread)
return json.MarshalIndent(&thread_json, "", " ") return json.MarshalIndent(&thread_json, "", " ")
} }
func (thread * SimpleThread) SignalChannel() <-chan GraphSignal { func (thread *Thread) ChildList() []ThreadNode {
return thread.signal ret := make([]ThreadNode, len(thread.Children))
i := 0
for _, info := range(thread.Children) {
ret[i] = info.Child
i += 1
}
return ret
} }
type SimpleThreadJSON struct { type ThreadJSON struct {
Parent string `json:"parent"` Parent string `json:"parent"`
Children map[string]interface{} `json:"children"` Children map[string]map[string]interface{} `json:"children"`
Timeout time.Time `json:"timeout"` Timeout time.Time `json:"timeout"`
TimeoutAction string `json:"timeout_action"` TimeoutAction string `json:"timeout_action"`
StateName string `json:"state_name"` StateName string `json:"state_name"`
SimpleLockableJSON LockableJSON
} }
func NewSimpleThreadJSON(thread *SimpleThread) SimpleThreadJSON { func NewThreadJSON(thread *Thread) ThreadJSON {
children := map[string]interface{}{} children := map[string]map[string]interface{}{}
for _, child := range(thread.children) { for id, info := range(thread.Children) {
children[child.ID().String()] = thread.child_info[child.ID()] tmp := map[string]interface{}{}
for name, i := range(info.Infos) {
tmp[name.String()] = i
}
children[id.String()] = tmp
} }
parent_id := "" parent_id := ""
if thread.parent != nil { if thread.Parent != nil {
parent_id = thread.parent.ID().String() parent_id = thread.Parent.ID().String()
} }
lockable_json := NewSimpleLockableJSON(&thread.SimpleLockable) lockable_json := NewLockableJSON(&thread.Lockable)
return SimpleThreadJSON{ return ThreadJSON{
Parent: parent_id, Parent: parent_id,
Children: children, Children: children,
Timeout: thread.timeout, Timeout: thread.Timeout,
TimeoutAction: thread.timeout_action, TimeoutAction: thread.TimeoutAction,
StateName: thread.state_name, StateName: thread.StateName,
SimpleLockableJSON: lockable_json, LockableJSON: lockable_json,
} }
} }
func LoadSimpleThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { func LoadThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) {
var j SimpleThreadJSON var j ThreadJSON
err := json.Unmarshal(data, &j) err := json.Unmarshal(data, &j)
if err != nil { if err != nil {
return nil, err return nil, err
} }
thread := NewSimpleThread(id, j.Name, j.StateName, nil, BaseThreadActions, BaseThreadHandlers) thread := NewThread(id, j.Name, j.StateName, nil, BaseThreadActions, BaseThreadHandlers)
nodes[id] = &thread nodes[id] = &thread
err = RestoreSimpleThread(ctx, &thread, j, nodes) err = RestoreThread(ctx, &thread, j, nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -370,17 +269,10 @@ func LoadSimpleThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node
return &thread, nil return &thread, nil
} }
// SimpleThread has no associated info with children func RestoreThread(ctx *Context, thread *Thread, j ThreadJSON, nodes NodeMap) error {
func (thread * SimpleThread) DeserializeInfo(ctx *Context, data []byte) (ThreadInfo, error) {
if len(data) > 0 {
return nil, fmt.Errorf("SimpleThread expected to deserialize no info but got %d length data: %s", len(data), string(data))
}
return nil, nil
}
func RestoreSimpleThread(ctx *Context, thread Thread, j SimpleThreadJSON, nodes NodeMap) error {
if j.TimeoutAction != "" { if j.TimeoutAction != "" {
thread.SetTimeout(j.Timeout, j.TimeoutAction) thread.Timeout = j.Timeout
thread.TimeoutAction = j.TimeoutAction
} }
if j.Parent != "" { if j.Parent != "" {
@ -392,11 +284,11 @@ func RestoreSimpleThread(ctx *Context, thread Thread, j SimpleThreadJSON, nodes
if err != nil { if err != nil {
return err return err
} }
p_t, ok := p.(Thread) p_t, ok := p.(ThreadNode)
if ok == false { if ok == false {
return err return err
} }
thread.SetParent(p_t) thread.Parent = p_t
} }
for id_str, info_raw := range(j.Children) { for id_str, info_raw := range(j.Children) {
@ -404,63 +296,94 @@ func RestoreSimpleThread(ctx *Context, thread Thread, j SimpleThreadJSON, nodes
if err != nil { if err != nil {
return err return err
} }
child_node, err := LoadNodeRecurse(ctx, id, nodes) child_node, err := LoadNodeRecurse(ctx, id, nodes)
if err != nil { if err != nil {
return err return err
} }
child_t, ok := child_node.(Thread)
child_t, ok := child_node.(ThreadNode)
if ok == false { if ok == false {
return fmt.Errorf("%+v is not a Thread as expected", child_node) return fmt.Errorf("%+v is not a Thread as expected", child_node)
} }
var info_ser []byte parsed_info, err := DeserializeChildInfo(ctx, info_raw)
if info_raw != nil {
info_ser, err = json.Marshal(info_raw)
if err != nil { if err != nil {
return err return err
} }
thread.Children[id] = ChildInfo{child_t, parsed_info}
} }
parsed_info, err := thread.DeserializeInfo(ctx, info_ser) return RestoreLockable(ctx, &thread.Lockable, j.LockableJSON, nodes)
if err != nil {
return err
} }
thread.AddChild(child_t, parsed_info) var deserializers = map[InfoType]func(interface{})(interface{}, error) {
}
func DeserializeChildInfo(ctx *Context, infos_raw map[string]interface{}) (map[InfoType]interface{}, error) {
ret := map[InfoType]interface{}{}
for type_str, info_raw := range(infos_raw) {
info_type := InfoType(type_str)
deserializer, exists := deserializers[info_type]
if exists == false {
return nil, fmt.Errorf("No deserializer for %s", info_type)
}
var err error
ret[info_type], err = deserializer(info_raw)
if err != nil {
return nil, err
}
} }
return RestoreSimpleLockable(ctx, thread, j.SimpleLockableJSON, nodes) return ret, nil
} }
const THREAD_SIGNAL_BUFFER_SIZE = 128 const THREAD_SIGNAL_BUFFER_SIZE = 128
func NewSimpleThread(id NodeID, name string, state_name string, info_type reflect.Type, actions ThreadActions, handlers ThreadHandlers) SimpleThread { func NewThread(id NodeID, name string, state_name string, info_types []InfoType, actions ThreadActions, handlers ThreadHandlers) Thread {
return SimpleThread{ return Thread{
SimpleLockable: NewSimpleLockable(id, name), Lockable: NewLockable(id, name),
InfoType: info_type, InfoTypes: info_types,
state_name: state_name, StateName: state_name,
signal: make(chan GraphSignal, THREAD_SIGNAL_BUFFER_SIZE), Chan: make(chan GraphSignal, THREAD_SIGNAL_BUFFER_SIZE),
children: []Thread{}, Children: map[NodeID]ChildInfo{},
child_info: map[NodeID]ThreadInfo{}, Actions: actions,
actions: actions, Handlers: handlers,
handlers: handlers, }
child_waits: &sync.WaitGroup{}, }
active_lock: &sync.Mutex{},
func (thread *Thread) SetActive(active bool) error {
thread.ActiveLock.Lock()
defer thread.ActiveLock.Unlock()
if thread.Active == true && active == true {
return fmt.Errorf("%s is active, cannot set active", thread.ID())
} else if thread.Active == false && active == false {
return fmt.Errorf("%s is already inactive, canot set inactive", thread.ID())
} }
thread.Active = active
return nil
}
func (thread *Thread) SetState(state string) error {
thread.StateName = state
return nil
} }
// Requires the read permission of threads children // Requires the read permission of threads children
func FindChild(context *StateContext, princ Node, thread Thread, id NodeID) Thread { func FindChild(context *StateContext, princ Node, node ThreadNode, id NodeID) ThreadNode {
if thread == nil { if node == nil {
panic("cannot recurse through nil") panic("cannot recurse through nil")
} }
thread := node.ThreadHandle()
if id == thread.ID() { if id == thread.ID() {
return thread return thread
} }
for _, child := range thread.Children() { for _, info := range thread.Children {
var result Thread = nil var result ThreadNode
UseStates(context, princ, NewLockInfo(child, []string{"children"}), func(context *StateContext) error { UseStates(context, princ, NewLockInfo(info.Child, []string{"children"}), func(context *StateContext) error {
result = FindChild(context, princ, child, id) result = FindChild(context, princ, info.Child, id)
return nil return nil
}) })
if result != nil { if result != nil {
@ -471,11 +394,11 @@ func FindChild(context *StateContext, princ Node, thread Thread, id NodeID) Thre
return nil return nil
} }
func ChildGo(ctx * Context, thread Thread, child Thread, first_action string) { func ChildGo(ctx * Context, thread *Thread, child ThreadNode, first_action string) {
thread.ChildWaits().Add(1) thread.ChildWaits.Add(1)
go func(child Thread) { go func(child ThreadNode) {
ctx.Log.Logf("thread", "THREAD_START_CHILD: %s from %s", thread.ID(), child.ID()) ctx.Log.Logf("thread", "THREAD_START_CHILD: %s from %s", thread.ID(), child.ID())
defer thread.ChildWaits().Done() defer thread.ChildWaits.Done()
err := ThreadLoop(ctx, child, first_action) err := ThreadLoop(ctx, child, first_action)
if err != nil { if err != nil {
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %e", child.ID(), err) ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %e", child.ID(), err)
@ -486,8 +409,9 @@ func ChildGo(ctx * Context, thread Thread, child Thread, first_action string) {
} }
// Main Loop for Threads, starts a write context, so cannot be called from a write or read context // Main Loop for Threads, starts a write context, so cannot be called from a write or read context
func ThreadLoop(ctx * Context, thread Thread, first_action string) error { func ThreadLoop(ctx * Context, node ThreadNode, first_action string) error {
// Start the thread, error if double-started // Start the thread, error if double-started
thread := node.ThreadHandle()
ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID(), first_action) ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID(), first_action)
err := thread.SetActive(true) err := thread.SetActive(true)
if err != nil { if err != nil {
@ -496,14 +420,14 @@ func ThreadLoop(ctx * Context, thread Thread, first_action string) error {
} }
next_action := first_action next_action := first_action
for next_action != "" { for next_action != "" {
action, exists := thread.Action(next_action) action, exists := thread.Actions[next_action]
if exists == false { if exists == false {
error_str := fmt.Sprintf("%s is not a valid action", next_action) error_str := fmt.Sprintf("%s is not a valid action", next_action)
return errors.New(error_str) return errors.New(error_str)
} }
ctx.Log.Logf("thread", "THREAD_ACTION: %s - %s", thread.ID(), next_action) ctx.Log.Logf("thread", "THREAD_ACTION: %s - %s", thread.ID(), next_action)
next_action, err = action(ctx, thread) next_action, err = action(ctx, node)
if err != nil { if err != nil {
return err return err
} }
@ -523,52 +447,8 @@ func ThreadLoop(ctx * Context, thread Thread, first_action string) error {
return nil return nil
} }
func ThreadChildLinked(ctx *Context, node ThreadNode, signal GraphSignal) (string, error) {
func (thread * SimpleThread) ChildWaits() *sync.WaitGroup { thread := node.ThreadHandle()
return thread.child_waits
}
func (thread * SimpleThread) SetActive(active bool) error {
thread.active_lock.Lock()
defer thread.active_lock.Unlock()
if thread.active == true && active == true {
return fmt.Errorf("%s is active, cannot set active", thread.ID())
} else if thread.active == false && active == false {
return fmt.Errorf("%s is already inactive, canot set inactive", thread.ID())
}
thread.active = active
return nil
}
func (thread * SimpleThread) Action(action string) (ThreadAction, bool) {
action_fn, exists := thread.actions[action]
return action_fn, exists
}
func (thread * SimpleThread) Handler(signal_type string) (ThreadHandler, bool) {
handler, exists := thread.handlers[signal_type]
return handler, exists
}
func (thread * SimpleThread) Timeout() <-chan time.Time {
return thread.timeout_chan
}
func (thread * SimpleThread) ClearTimeout() {
thread.timeout_chan = nil
thread.timeout_action = ""
}
func (thread * SimpleThread) AllowedToTakeLock(new_owner Lockable, lockable Lockable) bool {
for _, child := range(thread.children) {
if new_owner.ID() == child.ID() {
return true
}
}
return false
}
func ThreadParentChildLinked(ctx *Context, thread Thread, signal GraphSignal) (string, error) {
ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %+v", signal) ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %+v", signal)
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
err := UpdateStates(context, thread, NewLockMap( err := UpdateStates(context, thread, NewLockMap(
@ -579,18 +459,18 @@ func ThreadParentChildLinked(ctx *Context, thread Thread, signal GraphSignal) (s
ctx.Log.Logf("thread", "THREAD_NODE_LINKED_BAD_CAST") ctx.Log.Logf("thread", "THREAD_NODE_LINKED_BAD_CAST")
return nil return nil
} }
info_if := thread.ChildInfo(sig.ID) info, exists := thread.Children[sig.ID]
if info_if == nil { if exists == false {
ctx.Log.Logf("thread", "THREAD_NODE_LINKED: %s is not a child of %s", sig.ID) ctx.Log.Logf("thread", "THREAD_NODE_LINKED: %s is not a child of %s", sig.ID)
return nil return nil
} }
info_r, correct := info_if.(ParentInfo) parent_info, exists := info.Infos["parent"].(*ParentThreadInfo)
if correct == false { if exists == false {
ctx.Log.Logf("thread", "THREAD_NODE_LINKED_BAD_INFO_CAST") panic("ran ThreadChildLinked from a thread that doesn't require 'parent' child info. library party foul")
} }
info := info_r.Parent()
if info.Start == true { if parent_info.Start == true {
ChildGo(ctx, thread, thread.Child(sig.ID), info.StartAction) ChildGo(ctx, thread, info.Child, parent_info.StartAction)
} }
return nil return nil
}) })
@ -603,38 +483,30 @@ func ThreadParentChildLinked(ctx *Context, thread Thread, signal GraphSignal) (s
return "wait", nil return "wait", nil
} }
func ThreadParentStartChild(ctx *Context, thread Thread, signal GraphSignal) (string, error) { // Helper function to start a child from a thread during a signal handler
ctx.Log.Logf("thread", "THREAD_START_CHILD") // Starts a write context, so cannot be called from either a write or read context
func ThreadStartChild(ctx *Context, node ThreadNode, signal GraphSignal) (string, error) {
sig, ok := signal.(StartChildSignal) sig, ok := signal.(StartChildSignal)
if ok == false { if ok == false {
ctx.Log.Logf("thread", "THREAD_START_CHILD_BAD_SIGNAL: %+v", signal)
return "wait", nil return "wait", nil
} }
err := ThreadStartChild(ctx, thread, sig) thread := node.ThreadHandle()
if err != nil {
ctx.Log.Logf("thread", "THREAD_START_CHILD_ERR: %s", err)
} else {
ctx.Log.Logf("thread", "THREAD_START_CHILD: %s", sig.ID.String())
}
return "wait", nil
}
// Helper function to start a child from a thread during a signal handler
// Starts a write context, so cannot be called from either a write or read context
func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) error {
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
return UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error { return "wait", UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error {
child := thread.Child(signal.ID) info, exists:= thread.Children[sig.ID]
if child == nil { if exists == false {
return fmt.Errorf("%s is not a child of %s", signal.ID, thread.ID()) return fmt.Errorf("%s is not a child of %s", sig.ID, thread.ID())
} }
return UpdateStates(context, thread, NewLockInfo(child, []string{"start"}), func(context *StateContext) error { return UpdateStates(context, thread, NewLockInfo(info.Child, []string{"start"}), func(context *StateContext) error {
info := thread.ChildInfo(signal.ID).(*ParentThreadInfo) parent_info, exists := info.Infos["parent"].(*ParentThreadInfo)
info.Start = true if exists == false {
ChildGo(ctx, thread, child, signal.Action) return fmt.Errorf("Called ThreadStartChild from a thread that doesn't require parent child info")
}
parent_info.Start = true
ChildGo(ctx, thread, info.Child, sig.Action)
return nil return nil
}) })
@ -643,18 +515,19 @@ func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) erro
// Helper function to restore threads that should be running from a parents restore action // Helper function to restore threads that should be running from a parents restore action
// Starts a write context, so cannot be called from either a write or read context // Starts a write context, so cannot be called from either a write or read context
func ThreadRestore(ctx * Context, thread Thread, start bool) error { func ThreadRestore(ctx * Context, node ThreadNode, start bool) error {
thread := node.ThreadHandle()
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
return UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error { return UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error {
return UpdateStates(context, thread, LockList(thread.Children(), []string{"start"}), func(context *StateContext) error { return UpdateStates(context, thread, LockList(thread.ChildList(), []string{"start"}), func(context *StateContext) error {
for _, child := range(thread.Children()) { for _, info := range(thread.Children) {
info := (thread.ChildInfo(child.ID())).(ParentInfo).Parent() parent_info := info.Infos["parent"].(*ParentThreadInfo)
if info.Start == true && child.State() != "finished" { if parent_info.Start == true && info.Child.ThreadHandle().StateName != "finished" {
ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID(), child.ID()) ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID(), info.Child.ID())
if start == true { if start == true {
ChildGo(ctx, thread, child, info.StartAction) ChildGo(ctx, thread, info.Child, parent_info.StartAction)
} else { } else {
ChildGo(ctx, thread, child, info.RestoreAction) ChildGo(ctx, thread, info.Child, parent_info.RestoreAction)
} }
} }
} }
@ -665,10 +538,12 @@ func ThreadRestore(ctx * Context, thread Thread, start bool) error {
// Helper function to be called during a threads start action, sets the thread state to started // Helper function to be called during a threads start action, sets the thread state to started
// Starts a write context, so cannot be called from either a write or read context // Starts a write context, so cannot be called from either a write or read context
func ThreadStart(ctx * Context, thread Thread) error { // Returns "wait", nil on success, so the first return value can be ignored safely
func ThreadStart(ctx * Context, node ThreadNode) (string, error) {
thread := node.ThreadHandle()
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
return UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error { return "wait", UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error {
err := LockLockables(context, []Lockable{thread}, thread) err := LockLockables(context, map[NodeID]LockableNode{thread.ID(): thread}, thread)
if err != nil { if err != nil {
return err return err
} }
@ -676,39 +551,28 @@ func ThreadStart(ctx * Context, thread Thread) error {
}) })
} }
func ThreadDefaultStart(ctx * Context, thread Thread) (string, error) { func ThreadWait(ctx * Context, node ThreadNode) (string, error) {
ctx.Log.Logf("thread", "THREAD_DEFAULT_START: %s", thread.ID()) thread := node.ThreadHandle()
err := ThreadStart(ctx, thread) ctx.Log.Logf("thread", "THREAD_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout)
if err != nil {
return "", err
}
return "wait", nil
}
func ThreadDefaultRestore(ctx * Context, thread Thread) (string, error) {
ctx.Log.Logf("thread", "THREAD_DEFAULT_RESTORE: %s", thread.ID())
return "wait", nil
}
func ThreadWait(ctx * Context, thread Thread) (string, error) {
ctx.Log.Logf("thread", "THREAD_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout())
for { for {
select { select {
case signal := <- thread.SignalChannel(): case signal := <- thread.Chan:
ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal) ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal)
signal_fn, exists := thread.Handler(signal.Type()) signal_fn, exists := thread.Handlers[signal.Type()]
if exists == true { if exists == true {
ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID(), signal.Type()) ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID(), signal.Type())
return signal_fn(ctx, thread, signal) return signal_fn(ctx, thread, signal)
} else { } else {
ctx.Log.Logf("thread", "THREAD_NOHANDLER: %s - %s", thread.ID(), signal.Type()) ctx.Log.Logf("thread", "THREAD_NOHANDLER: %s - %s", thread.ID(), signal.Type())
} }
case <- thread.Timeout(): case <- thread.TimeoutChan:
timeout_action := "" timeout_action := ""
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
err := UpdateStates(context, thread, NewLockMap(NewLockInfo(thread, []string{"timeout"})), func(context *StateContext) error { err := UpdateStates(context, thread, NewLockMap(NewLockInfo(thread, []string{"timeout"})), func(context *StateContext) error {
timeout_action = thread.TimeoutAction() timeout_action = thread.TimeoutAction
thread.ClearTimeout() thread.TimeoutChan = nil
thread.TimeoutAction = ""
thread.Timeout = time.Time{}
return nil return nil
}) })
if err != nil { if err != nil {
@ -720,26 +584,23 @@ func ThreadWait(ctx * Context, thread Thread) (string, error) {
} }
} }
func ThreadDefaultFinish(ctx *Context, thread Thread) (string, error) { func ThreadFinish(ctx *Context, node ThreadNode) (string, error) {
ctx.Log.Logf("thread", "THREAD_DEFAULT_FINISH: %s", thread.ID().String()) thread := node.ThreadHandle()
return "", ThreadFinish(ctx, thread)
}
func ThreadFinish(ctx *Context, thread Thread) error {
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
return UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error { return "", UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error {
err := thread.SetState("finished") err := thread.SetState("finished")
if err != nil { if err != nil {
return err return err
} }
return UnlockLockables(context, []Lockable{thread}, thread) return UnlockLockables(context, map[NodeID]LockableNode{thread.ID(): thread}, thread)
}) })
} }
var ThreadAbortedError = errors.New("Thread aborted by signal") var ThreadAbortedError = errors.New("Thread aborted by signal")
// Default thread action function for "abort", sends a signal and returns a ThreadAbortedError // Default thread action function for "abort", sends a signal and returns a ThreadAbortedError
func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) { func ThreadAbort(ctx * Context, node ThreadNode, signal GraphSignal) (string, error) {
thread := node.ThreadHandle()
context := NewReadContext(ctx) context := NewReadContext(ctx)
err := Signal(context, thread, thread, NewStatusSignal("aborted", thread.ID())) err := Signal(context, thread, thread, NewStatusSignal("aborted", thread.ID()))
if err != nil { if err != nil {
@ -749,38 +610,18 @@ func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, erro
} }
// Default thread action for "stop", sends a signal and returns no error // Default thread action for "stop", sends a signal and returns no error
func ThreadStop(ctx * Context, thread Thread, signal GraphSignal) (string, error) { func ThreadStop(ctx * Context, node ThreadNode, signal GraphSignal) (string, error) {
thread := node.ThreadHandle()
context := NewReadContext(ctx) context := NewReadContext(ctx)
err := Signal(context, thread, thread, NewStatusSignal("stopped", thread.ID())) err := Signal(context, thread, thread, NewStatusSignal("stopped", thread.ID()))
return "finish", err return "finish", err
} }
// Copy the default thread actions to a new ThreadActions map
func NewThreadActions() ThreadActions{
actions := ThreadActions{}
for k, v := range(BaseThreadActions) {
actions[k] = v
}
return actions
}
// Copy the defult thread handlers to a new ThreadAction map
func NewThreadHandlers() ThreadHandlers{
handlers := ThreadHandlers{}
for k, v := range(BaseThreadHandlers) {
handlers[k] = v
}
return handlers
}
// Default thread actions // Default thread actions
var BaseThreadActions = ThreadActions{ var BaseThreadActions = ThreadActions{
"wait": ThreadWait, "wait": ThreadWait,
"start": ThreadDefaultStart, "start": ThreadStart,
"finish": ThreadDefaultFinish, "finish": ThreadFinish,
"restore": ThreadDefaultRestore,
} }
// Default thread signal handlers // Default thread signal handlers

@ -9,7 +9,7 @@ import (
) )
type User struct { type User struct {
SimpleLockable Lockable
Granted time.Time Granted time.Time
Pubkey *ecdsa.PublicKey Pubkey *ecdsa.PublicKey
@ -18,7 +18,7 @@ type User struct {
} }
type UserJSON struct { type UserJSON struct {
SimpleLockableJSON LockableJSON
Granted time.Time `json:"granted"` Granted time.Time `json:"granted"`
Pubkey []byte `json:"pubkey"` Pubkey []byte `json:"pubkey"`
Shared []byte `json:"shared"` Shared []byte `json:"shared"`
@ -30,14 +30,14 @@ func (user *User) Type() NodeType {
} }
func (user *User) Serialize() ([]byte, error) { func (user *User) Serialize() ([]byte, error) {
lockable_json := NewSimpleLockableJSON(&user.SimpleLockable) lockable_json := NewLockableJSON(&user.Lockable)
pubkey, err := x509.MarshalPKIXPublicKey(user.Pubkey) pubkey, err := x509.MarshalPKIXPublicKey(user.Pubkey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return json.MarshalIndent(&UserJSON{ return json.MarshalIndent(&UserJSON{
SimpleLockableJSON: lockable_json, LockableJSON: lockable_json,
Granted: user.Granted, Granted: user.Granted,
Shared: user.Shared, Shared: user.Shared,
Pubkey: pubkey, Pubkey: pubkey,
@ -68,7 +68,7 @@ func LoadUser(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error)
user := NewUser(j.Name, j.Granted, pubkey, j.Shared, j.Tags) user := NewUser(j.Name, j.Granted, pubkey, j.Shared, j.Tags)
nodes[id] = &user nodes[id] = &user
err = RestoreSimpleLockable(ctx, &user, j.SimpleLockableJSON, nodes) err = RestoreLockable(ctx, &user.Lockable, j.LockableJSON, nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -79,7 +79,7 @@ func LoadUser(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error)
func NewUser(name string, granted time.Time, pubkey *ecdsa.PublicKey, shared []byte, tags []string) User { func NewUser(name string, granted time.Time, pubkey *ecdsa.PublicKey, shared []byte, tags []string) User {
id := KeyID(pubkey) id := KeyID(pubkey)
return User{ return User{
SimpleLockable: NewSimpleLockable(id, name), Lockable: NewLockable(id, name),
Granted: granted, Granted: granted,
Pubkey: pubkey, Pubkey: pubkey,
Shared: shared, Shared: shared,