Tests compile and run

graph-rework-2
noah metz 2023-07-24 16:04:56 -06:00
parent 201ee7234b
commit fc2e36043f
12 changed files with 713 additions and 983 deletions

@ -86,8 +86,8 @@ func (ctx * Context) RegisterNodeType(def NodeDef) error {
ctx.Types[type_hash] = def
node_type := reflect.TypeOf((*Node)(nil)).Elem()
lockable_type := reflect.TypeOf((*Lockable)(nil)).Elem()
thread_type := reflect.TypeOf((*Thread)(nil)).Elem()
lockable_type := reflect.TypeOf((*LockableNode)(nil)).Elem()
thread_type := reflect.TypeOf((*ThreadNode)(nil)).Elem()
if def.Reflect.Implements(node_type) {
ctx.GQL.ValidNodes[def.Reflect] = def.GQLType
@ -154,7 +154,7 @@ func NewGQLContext() GQLContext {
Query: query,
Mutation: mutation,
Subscription: subscription,
BaseNodeType: GQLTypeGraphNode.Type,
BaseNodeType: GQLTypeSimpleNode.Type,
BaseLockableType: GQLTypeSimpleLockable.Type,
BaseThreadType: GQLTypeSimpleThread.Type,
}
@ -171,15 +171,19 @@ func NewContext(db * badger.DB, log Logger) * Context {
Types: map[uint64]NodeDef{},
}
err := ctx.RegisterNodeType(NewNodeDef((*GraphNode)(nil), LoadGraphNode, GQLTypeGraphNode.Type))
err := ctx.RegisterNodeType(NewNodeDef((*SimpleNode)(nil), LoadSimpleNode, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*SimpleLockable)(nil), LoadSimpleLockable, GQLTypeSimpleLockable.Type))
err = ctx.RegisterNodeType(NewNodeDef((*Lockable)(nil), LoadLockable, GQLTypeSimpleLockable.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*SimpleThread)(nil), LoadSimpleThread, GQLTypeSimpleThread.Type))
err = ctx.RegisterNodeType(NewNodeDef((*Listener)(nil), LoadListener, GQLTypeSimpleLockable.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*Thread)(nil), LoadThread, GQLTypeSimpleThread.Type))
if err != nil {
panic(err)
}
@ -191,19 +195,19 @@ func NewContext(db * badger.DB, log Logger) * Context {
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*PerNodePolicy)(nil), LoadPerNodePolicy, GQLTypeGraphNode.Type))
err = ctx.RegisterNodeType(NewNodeDef((*PerNodePolicy)(nil), LoadPerNodePolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*SimplePolicy)(nil), LoadSimplePolicy, GQLTypeGraphNode.Type))
err = ctx.RegisterNodeType(NewNodeDef((*SimplePolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*PerTagPolicy)(nil), LoadPerTagPolicy, GQLTypeGraphNode.Type))
err = ctx.RegisterNodeType(NewNodeDef((*PerTagPolicy)(nil), LoadPerTagPolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*DependencyPolicy)(nil), LoadSimplePolicy, GQLTypeGraphNode.Type))
err = ctx.RegisterNodeType(NewNodeDef((*DependencyPolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}

126
gql.go

@ -629,7 +629,7 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *
}
type GQLThread struct {
SimpleThread
Thread
tcp_listener net.Listener
http_server *http.Server
http_done *sync.WaitGroup
@ -639,6 +639,37 @@ type GQLThread struct {
Users map[NodeID]*User
Key *ecdsa.PrivateKey
ECDH ecdh.Curve
SubscribeLock sync.Mutex
SubscribeListeners []chan GraphSignal
}
func (thread *GQLThread) NewSubscriptionChannel(buffer int) chan GraphSignal {
thread.SubscribeLock.Lock()
defer thread.SubscribeLock.Unlock()
new_listener := make(chan GraphSignal, buffer)
thread.SubscribeListeners = append(thread.SubscribeListeners, new_listener)
return new_listener
}
func (thread *GQLThread) Process(context *StateContext, signal GraphSignal) error {
active_listeners := []chan GraphSignal{}
thread.SubscribeLock.Lock()
for _, listener := range(thread.SubscribeListeners) {
select {
case listener <- signal:
active_listeners = append(active_listeners, listener)
default:
go func(listener chan GraphSignal) {
listener <- NewDirectSignal("Channel Closed")
close(listener)
}(listener)
}
}
thread.SubscribeListeners = active_listeners
thread.SubscribeLock.Unlock()
return thread.Thread.Process(context, signal)
}
func (thread * GQLThread) Type() NodeType {
@ -650,17 +681,8 @@ func (thread * GQLThread) Serialize() ([]byte, error) {
return json.MarshalIndent(&thread_json, "", " ")
}
func (thread * GQLThread) DeserializeInfo(ctx *Context, data []byte) (ThreadInfo, error) {
var info ParentThreadInfo
err := json.Unmarshal(data, &info)
if err != nil {
return nil, err
}
return &info, nil
}
type GQLThreadJSON struct {
SimpleThreadJSON
ThreadJSON
Listen string `json:"listen"`
Users []string `json:"users"`
Key []byte `json:"key"`
@ -686,7 +708,7 @@ var ecdh_curve_ids = map[ecdh.Curve]uint8{
}
func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON {
thread_json := NewSimpleThreadJSON(&thread.SimpleThread)
thread_json := NewThreadJSON(&thread.Thread)
ser_key, err := x509.MarshalECPrivateKey(thread.Key)
if err != nil {
@ -701,7 +723,7 @@ func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON {
}
return GQLThreadJSON{
SimpleThreadJSON: thread_json,
ThreadJSON: thread_json,
Listen: thread.Listen,
Users: users,
Key: ser_key,
@ -744,7 +766,7 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e
thread.Users[id] = user.(*User)
}
err = RestoreSimpleThread(ctx, &thread, j.SimpleThreadJSON, nodes)
err = RestoreThread(ctx, &thread.Thread, j.ThreadJSON, nodes)
if err != nil {
return nil, err
}
@ -793,8 +815,9 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh
tls_key = ssl_key_pem
}
return GQLThread{
SimpleThread: NewSimpleThread(id, name, state_name, reflect.TypeOf((*ParentThreadInfo)(nil)), gql_actions, gql_handlers),
Thread: NewThread(id, name, state_name, []InfoType{"parent"}, gql_actions, gql_handlers),
Listen: listen,
SubscribeListeners: []chan GraphSignal{},
Users: map[NodeID]*User{},
http_done: &sync.WaitGroup{},
Key: key,
@ -806,40 +829,23 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh
var gql_actions ThreadActions = ThreadActions{
"wait": ThreadWait,
"restore": func(ctx * Context, thread Thread) (string, error) {
ctx.Log.Logf("gql", "GQL_THREAD_RESTORE: %s", thread.ID())
// Restore all the threads that have "Start" as true and arent in the "finished" state
err := ThreadRestore(ctx, thread, false)
if err != nil {
return "", err
}
return "start_server", nil
"restore": func(ctx *Context, node ThreadNode) (string, error) {
return "start_server", ThreadRestore(ctx, node, false)
},
"start": func(ctx * Context, thread Thread) (string, error) {
ctx.Log.Logf("gql", "GQL_START")
err := ThreadStart(ctx, thread)
"start": func(ctx * Context, node ThreadNode) (string, error) {
_, err := ThreadStart(ctx, node)
if err != nil {
return "", err
}
// Start all the threads that have "Start" as true and arent in the "finished" state
err = ThreadRestore(ctx, thread, true)
if err != nil {
return "", err
}
return "start_server", nil
return "start_server", ThreadRestore(ctx, node, true)
},
"start_server": func(ctx * Context, thread Thread) (string, error) {
server, ok := thread.(*GQLThread)
if ok == false {
return "", fmt.Errorf("GQL_THREAD_START: %s is not GQLThread, %+v", thread.ID(), thread.State())
}
"start_server": func(ctx * Context, node ThreadNode) (string, error) {
gql_thread := node.(*GQLThread)
ctx.Log.Logf("gql", "GQL_START_SERVER")
// Serve the GQL http and ws handlers
mux := http.NewServeMux()
mux.HandleFunc("/auth", AuthHandler(ctx, server))
mux.HandleFunc("/gql", GQLHandler(ctx, server))
mux.HandleFunc("/gqlws", GQLWSHandler(ctx, server))
mux.HandleFunc("/auth", AuthHandler(ctx, gql_thread))
mux.HandleFunc("/gql", GQLHandler(ctx, gql_thread))
mux.HandleFunc("/gqlws", GQLWSHandler(ctx, gql_thread))
// Server a graphiql interface(TODO make configurable whether to start this)
mux.HandleFunc("/graphiql", GraphiQLHandler())
@ -849,7 +855,7 @@ var gql_actions ThreadActions = ThreadActions{
mux.Handle("/site/", http.StripPrefix("/site", fs))
http_server := &http.Server{
Addr: server.Listen,
Addr: gql_thread.Listen,
Handler: mux,
}
@ -858,7 +864,7 @@ var gql_actions ThreadActions = ThreadActions{
return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr)
}
cert, err := tls.X509KeyPair(server.tls_cert, server.tls_key)
cert, err := tls.X509KeyPair(gql_thread.tls_cert, gql_thread.tls_key)
if err != nil {
return "", err
}
@ -870,23 +876,23 @@ var gql_actions ThreadActions = ThreadActions{
listener := tls.NewListener(l, &config)
server.http_done.Add(1)
go func(server *GQLThread) {
defer server.http_done.Done()
gql_thread.http_done.Add(1)
go func(gql_thread *GQLThread) {
defer gql_thread.http_done.Done()
err := http_server.Serve(listener)
if err != http.ErrServerClosed {
panic(fmt.Sprintf("Failed to start gql server: %s", err))
}
}(server)
}(gql_thread)
context := NewWriteContext(ctx)
err = UpdateStates(context, server, NewLockMap(
NewLockInfo(server, []string{"http_server"}),
err = UpdateStates(context, gql_thread, NewLockMap(
NewLockInfo(gql_thread, []string{"http_server"}),
), func(context *StateContext) error {
server.tcp_listener = listener
server.http_server = http_server
gql_thread.tcp_listener = listener
gql_thread.http_server = http_server
return nil
})
@ -895,24 +901,24 @@ var gql_actions ThreadActions = ThreadActions{
}
context = NewReadContext(ctx)
err = Signal(context, server, server, NewStatusSignal("server_started", server.ID()))
err = Signal(context, gql_thread, gql_thread, NewStatusSignal("server_started", gql_thread.ID()))
if err != nil {
return "", err
}
return "wait", nil
},
"finish": func(ctx *Context, thread Thread) (string, error) {
server := thread.(*GQLThread)
server.http_server.Shutdown(context.TODO())
server.http_done.Wait()
return "", ThreadFinish(ctx, thread)
"finish": func(ctx *Context, node ThreadNode) (string, error) {
gql_thread := node.(*GQLThread)
gql_thread.http_server.Shutdown(context.TODO())
gql_thread.http_done.Wait()
return ThreadFinish(ctx, node)
},
}
var gql_handlers ThreadHandlers = ThreadHandlers{
"child_linked": ThreadParentChildLinked,
"start_child": ThreadParentStartChild,
"child_linked": ThreadChildLinked,
"start_child": ThreadStartChild,
"abort": ThreadAbort,
"stop": ThreadStop,
}

@ -28,7 +28,7 @@ var GQLMutationAbort = NewField(func()*graphql.Field {
err = UseStates(context, ctx.User, NewLockMap(
NewLockInfo(ctx.Server, []string{"children"}),
), func(context *StateContext) (error){
node = FindChild(context, ctx.User, ctx.Server, id)
node = FindChild(context, ctx.User, &ctx.Server.Thread, id)
if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
}
@ -86,13 +86,13 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{
err = UseStates(context, ctx.User, NewLockMap(
NewLockInfo(ctx.Server, []string{"children"}),
), func(context *StateContext) error {
node := FindChild(context, ctx.User, ctx.Server, parent_id)
if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id)
parent := FindChild(context, ctx.User, &ctx.Server.Thread, parent_id)
if parent == nil {
return fmt.Errorf("%s is not a child of %s", parent_id, ctx.Server.ID())
}
signal = NewStartChildSignal(child_id, action)
return Signal(context, node, ctx.User, signal)
return Signal(context, ctx.User, parent, signal)
})
if err != nil {
return nil, err

@ -105,15 +105,15 @@ func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) {
return nil, err
}
node, ok := p.Source.(Thread)
node, ok := p.Source.(*Thread)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Thread")
}
var parent Thread = nil
var parent ThreadNode = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"parent"}), func(context *StateContext) error {
parent = node.Parent()
parent = node.ThreadHandle().Parent
return nil
})
@ -130,7 +130,7 @@ func GQLThreadState(p graphql.ResolveParams) (interface{}, error) {
return nil, err
}
node, ok := p.Source.(Thread)
node, ok := p.Source.(ThreadNode)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Thread")
}
@ -138,7 +138,7 @@ func GQLThreadState(p graphql.ResolveParams) (interface{}, error) {
var state string
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"state"}), func(context *StateContext) error {
state = node.State()
state = node.ThreadHandle().StateName
return nil
})
@ -155,15 +155,20 @@ func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) {
return nil, err
}
node, ok := p.Source.(Thread)
node, ok := p.Source.(ThreadNode)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Thread")
}
var children []Thread = nil
var children []ThreadNode = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"children"}), func(context *StateContext) error {
children = node.Children()
children = make([]ThreadNode, len(node.ThreadHandle().Children))
i := 0
for _, info := range(node.ThreadHandle().Children) {
children[i] = info.Child
i += 1
}
return nil
})
@ -180,7 +185,7 @@ func GQLLockableName(p graphql.ResolveParams) (interface{}, error) {
return nil, err
}
node, ok := p.Source.(Lockable)
node, ok := p.Source.(LockableNode)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Lockable")
}
@ -188,7 +193,7 @@ func GQLLockableName(p graphql.ResolveParams) (interface{}, error) {
name := ""
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"name"}), func(context *StateContext) error {
name = node.Name()
name = node.LockableHandle().Name
return nil
})
@ -205,15 +210,20 @@ func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) {
return nil, err
}
node, ok := p.Source.(Lockable)
node, ok := p.Source.(LockableNode)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Lockable")
}
var requirements []Lockable = nil
var requirements []LockableNode = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"requirements"}), func(context *StateContext) error {
requirements = node.Requirements()
requirements = make([]LockableNode, len(node.LockableHandle().Requirements))
i := 0
for _, req := range(node.LockableHandle().Requirements) {
requirements[i] = req
i += 1
}
return nil
})
@ -230,15 +240,20 @@ func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) {
return nil, err
}
node, ok := p.Source.(Lockable)
node, ok := p.Source.(LockableNode)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Lockable")
}
var dependencies []Lockable = nil
var dependencies []LockableNode = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"dependencies"}), func(context *StateContext) error {
dependencies = node.Dependencies()
dependencies = make([]LockableNode, len(node.LockableHandle().Dependencies))
i := 0
for _, dep := range(node.LockableHandle().Dependencies) {
dependencies[i] = dep
i += 1
}
return nil
})
@ -255,7 +270,7 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) {
return nil, err
}
node, ok := p.Source.(Lockable)
node, ok := p.Source.(LockableNode)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Lockable")
}
@ -263,7 +278,7 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) {
var owner Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"owner"}), func(context *StateContext) error {
owner = node.Owner()
owner = node.LockableHandle().Owner
return nil
})

@ -24,7 +24,7 @@ func GQLSubscribeFn(p graphql.ResolveParams, send_nil bool, fn func(*Context, *G
c := make(chan interface{})
go func(c chan interface{}, server *GQLThread) {
ctx.Context.Log.Logf("gqlws", "GQL_SUBSCRIBE_THREAD_START")
sig_c := UpdateChannel(server, 1, RandID())
sig_c := server.NewSubscriptionChannel(1)
if send_nil == true {
sig_c <- nil
}

@ -20,46 +20,37 @@ import (
func TestGQLDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{"test", "signal", "policy", "thread"})
l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
l1 := &l1_r
l1 := NewListener(RandID(), "Test Lockable 1")
ctx.Log.Logf("test", "L1_ID: %s", l1.ID().String())
t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
t1 := &t1_r
t1 := NewThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
ctx.Log.Logf("test", "T1_ID: %s", t1.ID().String())
listen_id := RandID()
ctx.Log.Logf("test", "LISTENER_ID: %s", listen_id.String())
update_channel := UpdateChannel(t1, 10, listen_id)
u1_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
u1_shared := []byte{0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x23, 0x45, 0x67}
u1_r := NewUser("Test User", time.Now(), &u1_key.PublicKey, u1_shared, []string{"gql"})
u1 := &u1_r
u1 := NewUser("Test User", time.Now(), &u1_key.PublicKey, []byte{}, []string{"gql"})
ctx.Log.Logf("test", "U1_ID: %s", u1.ID().String())
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
gql_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil)
gql := &gql_r
gql := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil)
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID().String())
// Policy to allow gql to perform all action on all resources
p1_r := NewPerNodePolicy(RandID(), map[NodeID]NodeActions{
p1 := NewPerNodePolicy(RandID(), map[NodeID]NodeActions{
gql.ID(): NewNodeActions(nil, []string{"*"}),
})
p1 := &p1_r
p2_r := NewSimplePolicy(RandID(), NewNodeActions(NodeActions{
p2 := NewSimplePolicy(RandID(), NewNodeActions(NodeActions{
"signal": []string{"status"},
}, nil))
p2 := &p2_r
context := NewWriteContext(ctx)
err = UpdateStates(context, gql, LockMap{
p1.ID(): LockInfo{p1, nil},
p2.ID(): LockInfo{p2, nil},
err = UpdateStates(context, &gql, LockMap{
p1.ID(): LockInfo{&p1, nil},
p2.ID(): LockInfo{&p2, nil},
}, func(context *StateContext) error {
return nil
})
@ -67,46 +58,48 @@ func TestGQLDBLoad(t * testing.T) {
ctx.Log.Logf("test", "P1_ID: %s", p1.ID().String())
ctx.Log.Logf("test", "P2_ID: %s", p2.ID().String())
err = AttachPolicies(ctx, gql, p1, p2)
err = AttachPolicies(ctx, &gql.SimpleNode, &p1, &p2)
fatalErr(t, err)
err = AttachPolicies(ctx, l1, p1, p2)
err = AttachPolicies(ctx, &l1.SimpleNode, &p1, &p2)
fatalErr(t, err)
err = AttachPolicies(ctx, t1, p1, p2)
err = AttachPolicies(ctx, &t1.SimpleNode, &p1, &p2)
fatalErr(t, err)
err = AttachPolicies(ctx, u1, p1, p2)
err = AttachPolicies(ctx, &u1.SimpleNode, &p1, &p2)
fatalErr(t, err)
info := NewParentThreadInfo(true, "start", "restore")
context = NewWriteContext(ctx)
err = UpdateStates(context, gql, NewLockMap(
NewLockInfo(gql, []string{"users"}),
err = UpdateStates(context, &gql, NewLockMap(
NewLockInfo(&gql, []string{"users"}),
), func(context *StateContext) error {
gql.Users[KeyID(&u1_key.PublicKey)] = u1
gql.Users[KeyID(&u1_key.PublicKey)] = &u1
err := LinkThreads(context, gql, gql, t1, &info)
err := LinkThreads(context, &gql, &gql, ChildInfo{&t1, map[InfoType]interface{}{
"parent": &info,
}})
if err != nil {
return err
}
return LinkLockables(context, gql, gql, []Lockable{l1})
return LinkLockables(context, &gql, &gql, []LockableNode{&l1})
})
fatalErr(t, err)
context = NewReadContext(ctx)
err = Signal(context, gql, gql, NewStatusSignal("child_linked", t1.ID()))
err = Signal(context, &gql, &gql, NewStatusSignal("child_linked", t1.ID()))
fatalErr(t, err)
context = NewReadContext(ctx)
err = Signal(context, gql, gql, AbortSignal)
err = Signal(context, &gql, &gql, AbortSignal)
fatalErr(t, err)
err = ThreadLoop(ctx, gql, "start")
err = ThreadLoop(ctx, &gql, "start")
if errors.Is(err, ThreadAbortedError) == false {
fatalErr(t, err)
}
(*GraphTester)(t).WaitForStatus(ctx, update_channel, "aborted", 100*time.Millisecond, "Didn't receive aborted on update_channel")
(*GraphTester)(t).WaitForStatus(ctx, l1.Chan, "aborted", 100*time.Millisecond, "Didn't receive aborted on listener")
context = NewReadContext(ctx)
err = UseStates(context, gql, LockList([]Node{gql, u1}, nil), func(context *StateContext) error {
err = UseStates(context, &gql, LockList([]Node{&gql, &u1}, nil), func(context *StateContext) error {
ser1, err := gql.Serialize()
ser2, err := u1.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser1)
@ -116,18 +109,15 @@ func TestGQLDBLoad(t * testing.T) {
gql_loaded, err := LoadNode(ctx, gql.ID())
fatalErr(t, err)
var t1_loaded *SimpleThread = nil
var update_channel_2 chan GraphSignal
var l1_loaded *Listener = nil
context = NewReadContext(ctx)
err = UseStates(context, gql, NewLockInfo(gql_loaded, []string{"users", "children"}), func(context *StateContext) error {
err = UseStates(context, gql_loaded, NewLockInfo(gql_loaded, []string{"users", "children", "requirements"}), func(context *StateContext) error {
ser, err := gql_loaded.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser)
dependency := gql_loaded.(*GQLThread).Thread.Dependencies[l1.ID()].(*Listener)
l1_loaded = dependency
u_loaded := gql_loaded.(*GQLThread).Users[u1.ID()]
child := gql_loaded.(Thread).Children()[0].(*SimpleThread)
t1_loaded = child
update_channel_2 = UpdateChannel(t1_loaded, 10, RandID())
err = UseStates(context, gql, NewLockInfo(u_loaded, nil), func(context *StateContext) error {
err = UseStates(context, gql_loaded, NewLockInfo(u_loaded, nil), func(context *StateContext) error {
ser, err := u_loaded.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser)
return err
@ -136,9 +126,9 @@ func TestGQLDBLoad(t * testing.T) {
return err
})
err = ThreadLoop(ctx, gql_loaded.(Thread), "start")
err = ThreadLoop(ctx, gql_loaded.(ThreadNode), "start")
fatalErr(t, err)
(*GraphTester)(t).WaitForStatus(ctx, update_channel_2, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2")
(*GraphTester)(t).WaitForStatus(ctx, l1_loaded.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2")
}
@ -153,23 +143,19 @@ func TestGQLAuth(t * testing.T) {
gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil)
gql_t := &gql_t_r
// p1 not written to DB, TODO: update write to follow links maybe
context := NewWriteContext(ctx)
err = UpdateStates(context, gql_t, NewLockInfo(gql_t, []string{"policies"}), func(context *StateContext) error {
return gql_t.AddPolicy(p1)
})
l1 := NewListener(RandID(), "GQL Thread")
err = AttachPolicies(ctx, &l1.SimpleNode, p1)
fatalErr(t, err)
err = AttachPolicies(ctx, &gql_t.SimpleNode, p1)
done := make(chan error, 1)
var update_channel chan GraphSignal
context = NewReadContext(ctx)
err = UseStates(context, gql_t, NewLockInfo(gql_t, nil), func(context *StateContext) error {
update_channel = UpdateChannel(gql_t, 10, NodeID{})
return nil
})
context := NewWriteContext(ctx)
err = LinkLockables(context, gql_t, gql_t, []LockableNode{&l1})
fatalErr(t, err)
go func(done chan error, thread Thread) {
go func(done chan error, thread ThreadNode) {
timeout := time.After(2*time.Second)
select {
case <-timeout:
@ -182,8 +168,8 @@ func TestGQLAuth(t * testing.T) {
fatalErr(t, err)
}(done, gql_t)
go func(thread Thread){
(*GraphTester)(t).WaitForStatus(ctx, update_channel, "server_started", 100*time.Millisecond, "Server didn't start")
go func(thread ThreadNode){
(*GraphTester)(t).WaitForStatus(ctx, l1.Chan, "server_started", 100*time.Millisecond, "Server didn't start")
port := gql_t.tcp_listener.Addr().(*net.TCPAddr).Port
ctx.Log.Logf("test", "GQL_PORT: %d", port)

@ -142,9 +142,9 @@ var GQLTypeSimpleLockable = NewSingleton(func() *graphql.Object {
return gql_type_simple_lockable
}, nil)
var GQLTypeGraphNode = NewSingleton(func() *graphql.Object {
var GQLTypeSimpleNode = NewSingleton(func() *graphql.Object {
object := graphql.NewObject(graphql.ObjectConfig{
Name: "GraphNode",
Name: "SimpleNode",
Interfaces: []*graphql.Interface{
GQLInterfaceNode.Type,
},

@ -5,62 +5,75 @@ import (
"encoding/json"
)
// A Lockable represents a Node that can be locked and hold other Nodes locks
type Lockable interface {
// All Lockables are nodes
type Listener struct {
Lockable
Chan chan GraphSignal
}
func (node *Listener) Type() NodeType {
return NodeType("listener")
}
func (node *Listener) Process(context *StateContext, signal GraphSignal) error {
select {
case node.Chan <- signal:
default:
return fmt.Errorf("LISTENER_OVERFLOW: %s - %s", node.ID(), signal)
}
return node.Lockable.Process(context, signal)
}
const LISTENER_CHANNEL_BUFFER = 1024
func NewListener(id NodeID, name string) Listener {
return Listener{
Lockable: NewLockable(id, name),
Chan: make(chan GraphSignal, LISTENER_CHANNEL_BUFFER),
}
}
func LoadListener(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) {
var j LockableJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
listener := NewListener(id, j.Name)
nodes[id] = &listener
err = RestoreLockable(ctx, &listener.Lockable, j, nodes)
if err != nil {
return nil, err
}
return &listener, nil
}
type LockableNode interface {
Node
//// State Modification Function
// Record that lockable was returned to it's owner and is no longer held by this Node
// Returns the previous owner of the lockable
RecordUnlock(lockable Lockable) Lockable
// Record that lockable was locked by this node, and that it should be returned to last_owner
RecordLock(lockable Lockable, last_owner Lockable)
// Link a requirement to this Node
AddRequirement(requirement Lockable)
// Remove a requirement linked to this Node
RemoveRequirement(requirement Lockable)
// Link a dependency to this Node
AddDependency(dependency Lockable)
// Remove a dependency linked to this Node
RemoveDependency(dependency Lockable)
//
SetOwner(new_owner Lockable)
//// State Reading Functions
Name() string
// Called when new_owner wants to take lockable's lock but it's owned by this node
// A true return value means that the lock can be passed
AllowedToTakeLock(new_owner Lockable, lockable Lockable) bool
// Get all the linked requirements to this node
Requirements() []Lockable
// Get all the linked dependencies to this node
Dependencies() []Lockable
// Get the node's Owner
Owner() Lockable
// Called during the lock process after locking the state and before updating the Node's state
// a non-nil return value will abort the lock attempt
CanLock(new_owner Lockable) error
// Called during the unlock process after locking the state and before updating the Node's state
// a non-nil return value will abort the unlock attempt
CanUnlock(old_owner Lockable) error
LockableHandle() *Lockable
}
// SimpleLockable is a simple Lockable implementation that can be embedded into more complex structures
type SimpleLockable struct {
GraphNode
name string
owner Lockable
requirements []Lockable
dependencies []Lockable
locks_held map[NodeID]Lockable
// Lockable is a simple Lockable implementation that can be embedded into more complex structures
type Lockable struct {
SimpleNode
Name string
Owner LockableNode
Requirements map[NodeID]LockableNode
Dependencies map[NodeID]LockableNode
LocksHeld map[NodeID]LockableNode
}
func (state * SimpleLockable) Type() NodeType {
return NodeType("simple_lockable")
func (lockable *Lockable) LockableHandle() *Lockable {
return lockable
}
type SimpleLockableJSON struct {
GraphNodeJSON
func (lockable *Lockable) Type() NodeType {
return NodeType("lockable")
}
type LockableJSON struct {
SimpleNodeJSON
Name string `json:"name"`
Owner string `json:"owner"`
Dependencies []string `json:"dependencies"`
@ -68,29 +81,33 @@ type SimpleLockableJSON struct {
LocksHeld map[string]string `json:"locks_held"`
}
func (lockable * SimpleLockable) Serialize() ([]byte, error) {
lockable_json := NewSimpleLockableJSON(lockable)
func (lockable *Lockable) Serialize() ([]byte, error) {
lockable_json := NewLockableJSON(lockable)
return json.MarshalIndent(&lockable_json, "", " ")
}
func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON {
requirement_ids := make([]string, len(lockable.requirements))
for i, requirement := range(lockable.requirements) {
requirement_ids[i] = requirement.ID().String()
func NewLockableJSON(lockable *Lockable) LockableJSON {
requirement_ids := make([]string, len(lockable.Requirements))
req_n := 0
for id, _ := range(lockable.Requirements) {
requirement_ids[req_n] = id.String()
req_n++
}
dependency_ids := make([]string, len(lockable.dependencies))
for i, dependency := range(lockable.dependencies) {
dependency_ids[i] = dependency.ID().String()
dependency_ids := make([]string, len(lockable.Dependencies))
dep_n := 0
for id, _ := range(lockable.Dependencies) {
dependency_ids[dep_n] = id.String()
dep_n++
}
owner_id := ""
if lockable.owner != nil {
owner_id = lockable.owner.ID().String()
if lockable.Owner != nil {
owner_id = lockable.Owner.ID().String()
}
locks_held := map[string]string{}
for lockable_id, node := range(lockable.locks_held) {
for lockable_id, node := range(lockable.LocksHeld) {
if node == nil {
locks_held[lockable_id.String()] = ""
} else {
@ -98,11 +115,11 @@ func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON {
}
}
node_json := NewGraphNodeJSON(&lockable.GraphNode)
node_json := NewSimpleNodeJSON(&lockable.SimpleNode)
return SimpleLockableJSON{
GraphNodeJSON: node_json,
Name: lockable.name,
return LockableJSON{
SimpleNodeJSON: node_json,
Name: lockable.Name,
Owner: owner_id,
Dependencies: dependency_ids,
Requirements: requirement_ids,
@ -110,114 +127,29 @@ func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON {
}
}
func (lockable * SimpleLockable) Name() string {
return lockable.name
}
func (lockable * SimpleLockable) RecordUnlock(l Lockable) Lockable {
func (lockable *Lockable) RecordUnlock(l LockableNode) LockableNode {
lockable_id := l.ID()
last_owner, exists := lockable.locks_held[lockable_id]
last_owner, exists := lockable.LocksHeld[lockable_id]
if exists == false {
panic("Attempted to take a get the original lock holder of a lockable we don't own")
}
delete(lockable.locks_held, lockable_id)
delete(lockable.LocksHeld, lockable_id)
return last_owner
}
func (lockable * SimpleLockable) RecordLock(l Lockable, last_owner Lockable) {
func (lockable *Lockable) RecordLock(l LockableNode, last_owner LockableNode) {
lockable_id := l.ID()
_, exists := lockable.locks_held[lockable_id]
_, exists := lockable.LocksHeld[lockable_id]
if exists == true {
panic("Attempted to lock a lockable we're already holding(lock cycle)")
}
lockable.locks_held[lockable_id] = last_owner
}
// Nothing can take a lock from a simple lockable
func (lockable * SimpleLockable) AllowedToTakeLock(l Lockable, new_owner Lockable) bool {
return false
}
func (lockable * SimpleLockable) Owner() Lockable {
return lockable.owner
}
func (lockable * SimpleLockable) SetOwner(owner Lockable) {
lockable.owner = owner
}
func (lockable * SimpleLockable) Requirements() []Lockable {
return lockable.requirements
}
func (lockable * SimpleLockable) AddRequirement(requirement Lockable) {
if requirement == nil {
panic("Will not connect nil to the DAG")
}
lockable.requirements = append(lockable.requirements, requirement)
}
func (lockable * SimpleLockable) Dependencies() []Lockable {
return lockable.dependencies
}
func (lockable * SimpleLockable) AddDependency(dependency Lockable) {
if dependency == nil {
panic("Will not connect nil to the DAG")
}
lockable.dependencies = append(lockable.dependencies, dependency)
}
func (lockable * SimpleLockable) RemoveDependency(dependency Lockable) {
idx := -1
for i, dep := range(lockable.dependencies) {
if dep.ID() == dependency.ID() {
idx = i
break
}
}
if idx == -1 {
panic(fmt.Sprintf("%s is not a dependency of %s", dependency.ID(), lockable.Name()))
}
dep_len := len(lockable.dependencies)
lockable.dependencies[idx] = lockable.dependencies[dep_len-1]
lockable.dependencies = lockable.dependencies[0:(dep_len-1)]
}
func (lockable * SimpleLockable) RemoveRequirement(requirement Lockable) {
idx := -1
for i, req := range(lockable.requirements) {
if req.ID() == requirement.ID() {
idx = i
break
}
}
if idx == -1 {
panic(fmt.Sprintf("%s is not a requirement of %s", requirement.ID(), lockable.Name()))
}
req_len := len(lockable.requirements)
lockable.requirements[idx] = lockable.requirements[req_len-1]
lockable.requirements = lockable.requirements[0:(req_len-1)]
}
func (lockable * SimpleLockable) CanLock(new_owner Lockable) error {
return nil
}
func (lockable * SimpleLockable) CanUnlock(new_owner Lockable) error {
return nil
lockable.LocksHeld[lockable_id] = last_owner
}
// Assumed that lockable is already locked for signal
func (lockable * SimpleLockable) Process(context *StateContext, princ Node, signal GraphSignal) error {
err := lockable.GraphNode.Process(context, princ, signal)
func (lockable *Lockable) Process(context *StateContext, signal GraphSignal) error {
err := lockable.SimpleNode.Process(context, signal)
if err != nil {
return err
}
@ -227,26 +159,26 @@ func (lockable * SimpleLockable) Process(context *StateContext, princ Node, sign
err = UseStates(context, lockable,
NewLockInfo(lockable, []string{"dependencies", "owner"}), func(context *StateContext) error {
owner_sent := false
for _, dependency := range(lockable.dependencies) {
for _, dependency := range(lockable.Dependencies) {
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID())
Signal(context, dependency, lockable, signal)
if lockable.owner != nil {
if dependency.ID() == lockable.owner.ID() {
if lockable.Owner != nil {
if dependency.ID() == lockable.Owner.ID() {
owner_sent = true
}
}
}
if lockable.owner != nil && owner_sent == false {
if lockable.owner.ID() != lockable.ID() {
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID())
return Signal(context, lockable.owner, lockable, signal)
if lockable.Owner != nil && owner_sent == false {
if lockable.Owner.ID() != lockable.ID() {
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.Owner.ID())
return Signal(context, lockable.Owner, lockable, signal)
}
}
return nil
})
case Down:
err = UseStates(context, lockable, NewLockInfo(lockable, []string{"requirements"}), func(context *StateContext) error {
for _, requirement := range(lockable.requirements) {
for _, requirement := range(lockable.Requirements) {
err := Signal(context, requirement, lockable, signal)
if err != nil {
return err
@ -265,13 +197,13 @@ func (lockable * SimpleLockable) Process(context *StateContext, princ Node, sign
// Removes requirement as a requirement from lockable
// Continues the write context with princ, getting requirents for lockable and dependencies for requirement
// Assumes that an active write context exists with princ locked so that princ's state can be used in checks
func UnlinkLockables(context *StateContext, princ Node, lockable Lockable, requirement Lockable) error {
func UnlinkLockables(context *StateContext, princ Node, lockable LockableNode, requirement LockableNode) error {
return UpdateStates(context, princ, LockMap{
lockable.ID(): LockInfo{Node: lockable, Resources: []string{"requirements"}},
requirement.ID(): LockInfo{Node: requirement, Resources: []string{"dependencies"}},
}, func(context *StateContext) error {
var found Node = nil
for _, req := range(lockable.Requirements()) {
for _, req := range(lockable.LockableHandle().Requirements) {
if requirement.ID() == req.ID() {
found = req
break
@ -282,8 +214,8 @@ func UnlinkLockables(context *StateContext, princ Node, lockable Lockable, requi
return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID(), lockable.ID())
}
requirement.RemoveDependency(lockable)
lockable.RemoveRequirement(requirement)
delete(requirement.LockableHandle().Dependencies, lockable.ID())
delete(lockable.LockableHandle().Requirements, requirement.ID())
return nil
})
@ -291,10 +223,11 @@ func UnlinkLockables(context *StateContext, princ Node, lockable Lockable, requi
// Link requirements as requirements to lockable
// Continues the wrtie context with princ, getting requirements for lockable and dependencies for requirements
func LinkLockables(context *StateContext, princ Node, lockable Lockable, requirements []Lockable) error {
if lockable == nil {
func LinkLockables(context *StateContext, princ Node, lockable_node LockableNode, requirements []LockableNode) error {
if lockable_node == nil {
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements")
}
lockable := lockable_node.LockableHandle()
if len(requirements) == 0 {
return nil
@ -323,8 +256,10 @@ func LinkLockables(context *StateContext, princ Node, lockable Lockable, require
), func(context *StateContext) error {
// Check that all the requirements can be added
// If the lockable is already locked, need to lock this resource as well before we can add it
for _, requirement := range(requirements) {
for _, req := range(requirements) {
for _, requirement_node := range(requirements) {
requirement := requirement_node.LockableHandle()
for _, req_node := range(requirements) {
req := req_node.LockableHandle()
if req.ID() == requirement.ID() {
continue
}
@ -339,22 +274,23 @@ func LinkLockables(context *StateContext, princ Node, lockable Lockable, require
if checkIfRequirement(context, requirement, lockable) == true {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID())
}
if lockable.Owner() == nil {
if lockable.Owner == nil {
// If the new owner isn't locked, we can add the requirement
} else if requirement.Owner() == nil {
} else if requirement.Owner == nil {
// if the new requirement isn't already locked but the owner is, the requirement needs to be locked first
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID(), requirement.ID())
} else {
// If the new requirement is already locked and the owner is already locked, their owners need to match
if requirement.Owner().ID() != lockable.Owner().ID() {
if requirement.Owner.ID() != lockable.Owner.ID() {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID(), lockable.ID())
}
}
}
// Update the states of the requirements
for _, requirement := range(requirements) {
requirement.AddDependency(lockable)
lockable.AddRequirement(requirement)
for _, requirement_node := range(requirements) {
requirement := requirement_node.LockableHandle()
requirement.Dependencies[lockable.ID()] = lockable_node
lockable.Requirements[lockable.ID()] = requirement_node
context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID(), lockable.ID())
}
@ -364,8 +300,8 @@ func LinkLockables(context *StateContext, princ Node, lockable Lockable, require
}
// Must be called withing update context
func checkIfRequirement(context *StateContext, r Lockable, cur Lockable) bool {
for _, c := range(cur.Requirements()) {
func checkIfRequirement(context *StateContext, r LockableNode, cur LockableNode) bool {
for _, c := range(cur.LockableHandle().Requirements) {
if c.ID() == r.ID() {
return true
}
@ -385,7 +321,7 @@ func checkIfRequirement(context *StateContext, r Lockable, cur Lockable) bool {
// Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error
// Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call
func LockLockables(context *StateContext, to_lock []Lockable, new_owner Lockable) error {
func LockLockables(context *StateContext, to_lock map[NodeID]LockableNode, new_owner_node LockableNode) error {
if to_lock == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: no list provided")
}
@ -396,44 +332,41 @@ func LockLockables(context *StateContext, to_lock []Lockable, new_owner Lockable
}
}
if new_owner == nil {
if new_owner_node == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks")
}
new_owner := new_owner_node.LockableHandle()
// Called with no requirements to lock, success
if len(to_lock) == 0 {
return nil
}
return UpdateStates(context, new_owner, NewLockMap(
LockList(to_lock, []string{"lock"}),
LockListM(to_lock, []string{"lock"}),
NewLockInfo(new_owner, nil),
), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements
for _, req := range(to_lock) {
for _, req_node := range(to_lock) {
req := req_node.LockableHandle()
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID())
// Check custom lock conditions
err := req.CanLock(new_owner)
if err != nil {
return err
}
// If req is alreay locked, check that we can pass the lock
if req.Owner() != nil {
owner := req.Owner()
if req.Owner != nil {
owner := req.Owner
if owner.ID() == new_owner.ID() {
continue
} else {
err := UpdateStates(context, new_owner, NewLockInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){
return LockLockables(context, req.Requirements(), req)
return LockLockables(context, req.Requirements, req)
})
if err != nil {
return err
}
}
} else {
err := LockLockables(context, req.Requirements(), req)
err := LockLockables(context, req.Requirements, req)
if err != nil {
return err
}
@ -441,19 +374,20 @@ func LockLockables(context *StateContext, to_lock []Lockable, new_owner Lockable
}
// At this point state modification will be started, so no errors can be returned
for _, req := range(to_lock) {
old_owner := req.Owner()
for _, req_node := range(to_lock) {
req := req_node.LockableHandle()
old_owner := req.Owner
// If the lockable was previously unowned, update the state
if old_owner == nil {
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID(), req.ID())
req.SetOwner(new_owner)
req.Owner = new_owner_node
new_owner.RecordLock(req, old_owner)
// Otherwise if the new owner already owns it, no need to update state
} else if old_owner.ID() == new_owner.ID() {
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID(), req.ID())
// Otherwise update the state
} else {
req.SetOwner(new_owner)
req.Owner = new_owner
new_owner.RecordLock(req, old_owner)
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID(), req.ID(), old_owner.ID())
}
@ -463,7 +397,7 @@ func LockLockables(context *StateContext, to_lock []Lockable, new_owner Lockable
}
func UnlockLockables(context *StateContext, to_unlock []Lockable, old_owner Lockable) error {
func UnlockLockables(context *StateContext, to_unlock map[NodeID]LockableNode, old_owner_node LockableNode) error {
if to_unlock == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided")
}
@ -474,48 +408,46 @@ func UnlockLockables(context *StateContext, to_unlock []Lockable, old_owner Lock
}
}
if old_owner == nil {
if old_owner_node == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks")
}
old_owner := old_owner_node.LockableHandle()
// Called with no requirements to unlock, success
if len(to_unlock) == 0 {
return nil
}
return UpdateStates(context, old_owner, NewLockMap(
LockList(to_unlock, []string{"lock"}),
LockListM(to_unlock, []string{"lock"}),
NewLockInfo(old_owner, nil),
), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements
for _, req := range(to_unlock) {
for _, req_node := range(to_unlock) {
req := req_node.LockableHandle()
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID())
// Check if the owner is correct
if req.Owner() != nil {
if req.Owner().ID() != old_owner.ID() {
if req.Owner != nil {
if req.Owner.ID() != old_owner.ID() {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID(), old_owner.ID())
}
} else {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID())
}
// Check custom unlock conditions
err := req.CanUnlock(old_owner)
if err != nil {
return err
}
err = UnlockLockables(context, req.Requirements(), req)
err := UnlockLockables(context, req.Requirements, req)
if err != nil {
return err
}
}
// At this point state modification will be started, so no errors can be returned
for _, req := range(to_unlock) {
for _, req_node := range(to_unlock) {
req := req_node.LockableHandle()
new_owner := old_owner.RecordUnlock(req)
req.SetOwner(new_owner)
req.Owner = new_owner
if new_owner == nil {
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID(), req.ID())
} else {
@ -527,18 +459,18 @@ func UnlockLockables(context *StateContext, to_unlock []Lockable, old_owner Lock
})
}
// Load function for SimpleLockable
func LoadSimpleLockable(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) {
var j SimpleLockableJSON
// Load function for Lockable
func LoadLockable(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) {
var j LockableJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
lockable := NewSimpleLockable(id, j.Name)
lockable := NewLockable(id, j.Name)
nodes[id] = &lockable
err = RestoreSimpleLockable(ctx, &lockable, j, nodes)
err = RestoreLockable(ctx, &lockable, j, nodes)
if err != nil {
return nil, err
}
@ -546,19 +478,19 @@ func LoadSimpleLockable(ctx *Context, id NodeID, data []byte, nodes NodeMap) (No
return &lockable, nil
}
func NewSimpleLockable(id NodeID, name string) SimpleLockable {
return SimpleLockable{
GraphNode: NewGraphNode(id),
name: name,
owner: nil,
requirements: []Lockable{},
dependencies: []Lockable{},
locks_held: map[NodeID]Lockable{},
func NewLockable(id NodeID, name string) Lockable {
return Lockable{
SimpleNode: NewSimpleNode(id),
Name: name,
Owner: nil,
Requirements: map[NodeID]LockableNode{},
Dependencies: map[NodeID]LockableNode{},
LocksHeld: map[NodeID]LockableNode{},
}
}
// Helper function to load links when loading a struct that embeds SimpleLockable
func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSON, nodes NodeMap) error {
// Helper function to load links when loading a struct that embeds Lockable
func RestoreLockable(ctx * Context, lockable *Lockable, j LockableJSON, nodes NodeMap) error {
if j.Owner != "" {
owner_id, err := ParseID(j.Owner)
if err != nil {
@ -568,11 +500,11 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO
if err != nil {
return err
}
owner, ok := owner_node.(Lockable)
owner, ok := owner_node.(LockableNode)
if ok == false {
return fmt.Errorf("%s is not a Lockable", j.Owner)
}
lockable.SetOwner(owner)
lockable.Owner = owner
}
for _, dep_str := range(j.Dependencies) {
@ -584,11 +516,11 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO
if err != nil {
return err
}
dep, ok := dep_node.(Lockable)
dep, ok := dep_node.(LockableNode)
if ok == false {
return fmt.Errorf("%+v is not a Lockable as expected", dep_node)
}
lockable.AddDependency(dep)
lockable.Dependencies[dep_id] = dep
}
for _, req_str := range(j.Requirements) {
@ -600,11 +532,11 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO
if err != nil {
return err
}
req, ok := req_node.(Lockable)
req, ok := req_node.(LockableNode)
if ok == false {
return fmt.Errorf("%+v is not a Lockable as expected", req_node)
}
lockable.AddRequirement(req)
lockable.Requirements[req_id] = req
}
for l_id_str, h_str := range(j.LocksHeld) {
@ -613,12 +545,12 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO
if err != nil {
return err
}
l_l, ok := l.(Lockable)
l_l, ok := l.(LockableNode)
if ok == false {
return fmt.Errorf("%s is not a Lockable", l.ID())
}
var h_l Lockable = nil
var h_l LockableNode
if h_str != "" {
h_id, err := ParseID(h_str)
if err != nil {
@ -628,7 +560,7 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO
if err != nil {
return err
}
h, ok := h_node.(Lockable)
h, ok := h_node.(LockableNode)
if ok == false {
return err
}
@ -637,5 +569,5 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO
lockable.RecordLock(l_l, h_l)
}
return RestoreGraphNode(ctx, lockable, j.GraphNodeJSON, nodes)
return RestoreSimpleNode(ctx, &lockable.SimpleNode, j.SimpleNodeJSON, nodes)
}

@ -44,7 +44,7 @@ func KeyID(pub *ecdsa.PublicKey) NodeID {
// Types are how nodes are associated with structs at runtime(and from the DB)
type NodeType string
func (node_type NodeType) Hash() uint64 {
hash := sha512.Sum512([]byte(node_type))
hash := sha512.Sum512([]byte(fmt.Sprintf("NODE: %s", node_type)))
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
}
@ -54,120 +54,87 @@ func RandID() NodeID {
return NodeID(uuid.New())
}
// A Node represents data that can be read by multiple goroutines and written to by one, with a unique ID attached, and a method to process updates(including propagating them to connected nodes)
// RegisterChannel and UnregisterChannel are used to connect arbitrary listeners to the node
type Node interface {
// State Locking interface
sync.Locker
RLock()
RUnlock()
// Serialize the Node for the database
Serialize() ([]byte, error)
// Nodes have an ID, type, and ACL policies
ID() NodeID
Type() NodeType
Policies() map[NodeID]Policy
AddPolicy(Policy) error
RemovePolicy(Policy) error
// Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate
Process(context *StateContext, princ Node, signal GraphSignal) error
// Register a channel to receive updates sent to the node
RegisterChannel(id NodeID, listener chan GraphSignal)
// Unregister a channel from receiving updates sent to the node
UnregisterChannel(id NodeID)
Serialize() ([]byte, error)
LockState(write bool)
UnlockState(write bool)
Process(context *StateContext, signal GraphSignal) error
Policies() []Policy
}
// A GraphNode is an implementation of a Node that can be embedded into more complex structures
type GraphNode struct {
sync.RWMutex
listeners_lock sync.Mutex
type SimpleNode struct {
id NodeID
listeners map[NodeID]chan GraphSignal
state_mutex sync.RWMutex
policies map[NodeID]Policy
}
type GraphNodeJSON struct {
Policies []string `json:"policies"`
func NewSimpleNode(id NodeID) SimpleNode {
return SimpleNode{
id: id,
policies: map[NodeID]Policy{},
}
}
func (node * GraphNode) Policies() map[NodeID]Policy {
return node.policies
type SimpleNodeJSON struct {
Policies []string `json:"policies"`
}
func (node * GraphNode) Serialize() ([]byte, error) {
node_json := NewGraphNodeJSON(node)
return json.MarshalIndent(&node_json, "", " ")
func (node *SimpleNode) Process(context *StateContext, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "SIMPLE_NODE_SIGNAL: %s - %+v", node.id, signal)
return nil
}
func Allowed(context *StateContext, policies map[NodeID]Policy, node Node, resource string, action string, princ Node) error {
if princ == nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return fmt.Errorf("nil is not allowed to perform any actions")
}
if node.ID() == princ.ID() {
return nil
}
for _, policy := range(policies) {
if policy.Allows(node, resource, action, princ) == true {
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return nil
}
}
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return fmt.Errorf("%s is not allowed to perform %s.%s on %s", princ.ID(), resource, action, node.ID())
func (node *SimpleNode) ID() NodeID {
return node.id
}
func (node *GraphNode) AddPolicy(policy Policy) error {
if policy == nil {
return fmt.Errorf("Cannot add nil as a policy")
}
_, exists := node.policies[policy.ID()]
if exists == true {
return fmt.Errorf("%s is already a policy for %s", policy.ID().String(), node.ID().String())
}
func (node *SimpleNode) Type() NodeType {
return NodeType("simple_node")
}
node.policies[policy.ID()] = policy
return nil
func (node *SimpleNode) Serialize() ([]byte, error) {
j := NewSimpleNodeJSON(node)
return json.MarshalIndent(&j, "", " ")
}
func (node *GraphNode) RemovePolicy(policy Policy) error {
if policy == nil {
return fmt.Errorf("Cannot add nil as a policy")
func (node *SimpleNode) LockState(write bool) {
if write == true {
node.state_mutex.Lock()
} else {
node.state_mutex.RLock()
}
}
_, exists := node.policies[policy.ID()]
if exists == false {
return fmt.Errorf("%s is not a policy for %s", policy.ID().String(), node.ID().String())
func (node *SimpleNode) UnlockState(write bool) {
if write == true {
node.state_mutex.Unlock()
} else {
node.state_mutex.RUnlock()
}
delete(node.policies, policy.ID())
return nil
}
func NewGraphNodeJSON(node *GraphNode) GraphNodeJSON {
policies := make([]string, len(node.policies))
func NewSimpleNodeJSON(node *SimpleNode) SimpleNodeJSON {
policy_ids := make([]string, len(node.policies))
i := 0
for _, policy := range(node.policies) {
policies[i] = policy.ID().String()
for id, _ := range(node.policies) {
policy_ids[i] = id.String()
i += 1
}
return GraphNodeJSON{
Policies: policies,
return SimpleNodeJSON{
Policies: policy_ids,
}
}
func RestoreGraphNode(ctx *Context, node Node, j GraphNodeJSON, nodes NodeMap) error {
func RestoreSimpleNode(ctx *Context, node *SimpleNode, j SimpleNodeJSON, nodes NodeMap) error {
for _, policy_str := range(j.Policies) {
policy_id, err := ParseID(policy_str)
if err != nil {
return err
}
policy_ptr, err := LoadNodeRecurse(ctx, policy_id, nodes)
if err != nil {
return err
@ -177,27 +144,60 @@ func RestoreGraphNode(ctx *Context, node Node, j GraphNodeJSON, nodes NodeMap) e
if ok == false {
return fmt.Errorf("%s is not a Policy", policy_id)
}
node.AddPolicy(policy)
node.policies[policy_id] = policy
}
return nil
}
func LoadGraphNode(ctx * Context, id NodeID, data []byte, nodes NodeMap)(Node, error) {
if len(data) > 0 {
return nil, fmt.Errorf("Attempted to load a graph_node with data %+v, should have been 0 length", string(data))
func LoadSimpleNode(ctx *Context, id NodeID, data []byte, nodes NodeMap)(Node, error) {
var j SimpleNodeJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
node := NewSimpleNode(id)
nodes[id] = &node
err = RestoreSimpleNode(ctx, &node, j, nodes)
if err != nil {
return nil, err
}
node := NewGraphNode(id)
return &node, nil
}
func (node * GraphNode) ID() NodeID {
return node.id
func (node *SimpleNode) Policies() []Policy {
ret := make([]Policy, len(node.policies))
i := 0
for _, policy := range(node.policies) {
ret[i] = policy
i += 1
}
return ret
}
func (node * GraphNode) Type() NodeType {
return NodeType("graph_node")
func Allowed(context *StateContext, policies []Policy, node Node, resource string, action string, princ Node) error {
if princ == nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return fmt.Errorf("nil is not allowed to perform any actions")
}
if node.ID() == princ.ID() {
return nil
}
for _, policy := range(policies) {
if policy.Allows(node, resource, action, princ) == true {
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return nil
}
}
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return fmt.Errorf("%s is not allowed to perform %s.%s on %s", princ.ID(), resource, action, node.ID())
}
// Propagate the signal to registered listeners, if a listener isn't ready to receive the update
// send it a notification that it was closed and then close it
func Signal(context *StateContext, node Node, princ Node, signal GraphSignal) error {
@ -211,75 +211,19 @@ func Signal(context *StateContext, node Node, princ Node, signal GraphSignal) er
return nil
}
return node.Process(context, princ, signal)
}
func (node * GraphNode) Process(context *StateContext, princ Node, signal GraphSignal) error {
node.listeners_lock.Lock()
defer node.listeners_lock.Unlock()
closed := []NodeID{}
for id, listener := range node.listeners {
context.Graph.Log.Logf("signal", "UPDATE_LISTENER %s: %s", node.ID(), id)
select {
case listener <- signal:
default:
context.Graph.Log.Logf("signal", "CLOSED_LISTENER %s: %s", node.ID(), id)
go func(node Node, listener chan GraphSignal) {
listener <- NewDirectSignal("listener_closed")
close(listener)
}(node, listener)
closed = append(closed, id)
}
}
for _, id := range(closed) {
delete(node.listeners, id)
}
return nil
return node.Process(context, signal)
}
func (node * GraphNode) RegisterChannel(id NodeID, listener chan GraphSignal) {
node.listeners_lock.Lock()
_, exists := node.listeners[id]
if exists == false {
node.listeners[id] = listener
}
node.listeners_lock.Unlock()
}
func (node * GraphNode) UnregisterChannel(id NodeID) {
node.listeners_lock.Lock()
_, exists := node.listeners[id]
if exists == false {
panic("Attempting to unregister non-registered listener")
} else {
delete(node.listeners, id)
}
node.listeners_lock.Unlock()
}
func AttachPolicies(ctx *Context, node Node, policies ...Policy) error {
func AttachPolicies(ctx *Context, node *SimpleNode, policies ...Policy) error {
context := NewWriteContext(ctx)
return UpdateStates(context, node, NewLockInfo(node, []string{"policies"}), func(context *StateContext) error {
for _, policy := range(policies) {
err := node.AddPolicy(policy)
if err != nil {
return err
}
node.policies[policy.ID()] = policy
}
return nil
})
}
func NewGraphNode(id NodeID) GraphNode {
return GraphNode{
id: id,
listeners: map[NodeID]chan GraphSignal{},
policies: map[NodeID]Policy{},
}
}
// Magic first four bytes of serialized DB content, stored big endian
const NODE_DB_MAGIC = 0x2491df14
// Total length of the node database header, has magic to verify and type_hash to map to load function
@ -458,6 +402,17 @@ func NewLockMap(requests ...LockMap) LockMap {
return reqs
}
func LockListM[K Node](m map[NodeID]K, resources[]string) LockMap {
reqs := LockMap{}
for _, node := range(m) {
reqs[node.ID()] = LockInfo{
Node: node,
Resources: resources,
}
}
return reqs
}
func LockList[K Node](list []K, resources []string) LockMap {
reqs := LockMap{}
for _, node := range(list) {
@ -565,7 +520,7 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
if princ_locked == false {
new_locks = append(new_locks, princ)
context.Graph.Log.Logf("mutex", "RLOCKING_PRINC %s", princ.ID().String())
princ.RLock()
princ.LockState(false)
}
princ_permissions, princ_exists := context.Permissions[princ.ID()]
@ -588,7 +543,7 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
if locked == false {
new_locks = append(new_locks, node)
context.Graph.Log.Logf("mutex", "RLOCKING %s", id.String())
node.RLock()
node.LockState(false)
}
}
@ -610,7 +565,7 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
if err != nil {
for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String())
n.RUnlock()
n.UnlockState(false)
}
return err
}
@ -632,7 +587,7 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
for _, node := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING %s", node.ID().String())
delete(context.Locked, node.ID())
node.RUnlock()
node.UnlockState(false)
}
return err
@ -661,7 +616,7 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
if princ_locked == false {
new_locks = append(new_locks, princ)
context.Graph.Log.Logf("mutex", "LOCKING_PRINC %s", princ.ID().String())
princ.Lock()
princ.LockState(true)
}
princ_permissions, princ_exists := context.Permissions[princ.ID()]
@ -684,7 +639,7 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
if locked == false {
new_locks = append(new_locks, node)
context.Graph.Log.Logf("mutex", "LOCKING %s", id.String())
node.Lock()
node.LockState(true)
}
}
@ -706,7 +661,7 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
if err != nil {
for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String())
n.Unlock()
n.UnlockState(true)
}
return err
}
@ -730,19 +685,10 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
}
for id, node := range(context.Locked) {
context.Graph.Log.Logf("mutex", "UNLOCKING %s", id.String())
node.Unlock()
node.UnlockState(true)
}
}
return err
}
// Create a new channel with a buffer the size of buffer, and register it to node with the id
func UpdateChannel(node Node, buffer int, id NodeID) chan GraphSignal {
if node == nil {
panic("Cannot get an update channel to nil")
}
new_listener := make(chan GraphSignal, buffer)
node.RegisterChannel(id, new_listener)
return new_listener
}

@ -44,12 +44,12 @@ func NewNodeActions(resource_actions NodeActions, wildcard_actions []string) Nod
}
type PerNodePolicy struct {
GraphNode
SimpleNode
Actions map[NodeID]NodeActions
}
type PerNodePolicyJSON struct {
GraphNodeJSON
SimpleNodeJSON
Actions map[string]map[string][]string `json:"actions"`
}
@ -64,7 +64,7 @@ func (policy *PerNodePolicy) Serialize() ([]byte, error) {
}
return json.MarshalIndent(&PerNodePolicyJSON{
GraphNodeJSON: NewGraphNodeJSON(&policy.GraphNode),
SimpleNodeJSON: NewSimpleNodeJSON(&policy.SimpleNode),
Actions: actions,
}, "", " ")
}
@ -75,7 +75,7 @@ func NewPerNodePolicy(id NodeID, actions map[NodeID]NodeActions) PerNodePolicy {
}
return PerNodePolicy{
GraphNode: NewGraphNode(id),
SimpleNode: NewSimpleNode(id),
Actions: actions,
}
}
@ -100,7 +100,7 @@ func LoadPerNodePolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Nod
policy := NewPerNodePolicy(id, actions)
nodes[id] = &policy
err = RestoreGraphNode(ctx, &policy.GraphNode, j.GraphNodeJSON, nodes)
err = RestoreSimpleNode(ctx, &policy.SimpleNode, j.SimpleNodeJSON, nodes)
if err != nil {
return nil, err
}
@ -122,12 +122,12 @@ func (policy *PerNodePolicy) Allows(node Node, resource string, action string, p
}
type SimplePolicy struct {
GraphNode
SimpleNode
Actions NodeActions
}
type SimplePolicyJSON struct {
GraphNodeJSON
SimpleNodeJSON
Actions map[string][]string `json:"actions"`
}
@ -137,7 +137,7 @@ func (policy *SimplePolicy) Type() NodeType {
func (policy *SimplePolicy) Serialize() ([]byte, error) {
return json.MarshalIndent(&SimplePolicyJSON{
GraphNodeJSON: NewGraphNodeJSON(&policy.GraphNode),
SimpleNodeJSON: NewSimpleNodeJSON(&policy.SimpleNode),
Actions: policy.Actions,
}, "", " ")
}
@ -148,7 +148,7 @@ func NewSimplePolicy(id NodeID, actions NodeActions) SimplePolicy {
}
return SimplePolicy{
GraphNode: NewGraphNode(id),
SimpleNode: NewSimpleNode(id),
Actions: actions,
}
}
@ -163,7 +163,7 @@ func LoadSimplePolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node
policy := NewSimplePolicy(id, j.Actions)
nodes[id] = &policy
err = RestoreGraphNode(ctx, &policy.GraphNode, j.GraphNodeJSON, nodes)
err = RestoreSimpleNode(ctx, &policy.SimpleNode, j.SimpleNodeJSON, nodes)
if err != nil {
return nil, err
}
@ -176,12 +176,12 @@ func (policy *SimplePolicy) Allows(node Node, resource string, action string, pr
}
type PerTagPolicy struct {
GraphNode
SimpleNode
Actions map[string]NodeActions
}
type PerTagPolicyJSON struct {
GraphNodeJSON
SimpleNodeJSON
Actions map[string]map[string][]string `json:"json"`
}
@ -196,7 +196,7 @@ func (policy *PerTagPolicy) Serialize() ([]byte, error) {
}
return json.MarshalIndent(&PerTagPolicyJSON{
GraphNodeJSON: NewGraphNodeJSON(&policy.GraphNode),
SimpleNodeJSON: NewSimpleNodeJSON(&policy.SimpleNode),
Actions: actions,
}, "", " ")
}
@ -207,7 +207,7 @@ func NewPerTagPolicy(id NodeID, actions map[string]NodeActions) PerTagPolicy {
}
return PerTagPolicy{
GraphNode: NewGraphNode(id),
SimpleNode: NewSimpleNode(id),
Actions: actions,
}
}
@ -227,7 +227,7 @@ func LoadPerTagPolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node
policy := NewPerTagPolicy(id, actions)
nodes[id] = &policy
err = RestoreGraphNode(ctx, &policy.GraphNode, j.GraphNodeJSON, nodes)
err = RestoreSimpleNode(ctx, &policy.SimpleNode, j.SimpleNodeJSON, nodes)
if err != nil {
return nil, err
}
@ -268,12 +268,12 @@ func NewDependencyPolicy(id NodeID, actions NodeActions) DependencyPolicy {
}
func (policy *DependencyPolicy) Allows(node Node, resource string, action string, principal Node) bool {
lockable, ok := node.(Lockable)
lockable, ok := node.(LockableNode)
if ok == false {
return false
}
for _, dep := range(lockable.Dependencies()) {
for _, dep := range(lockable.LockableHandle().Dependencies) {
if dep.ID() == principal.ID() {
return policy.Actions.Allows(resource, action)
}

@ -5,13 +5,12 @@ import (
"time"
"sync"
"errors"
"reflect"
"encoding/json"
)
// Assumed that thread is already locked for signal
func (thread *SimpleThread) Process(context *StateContext, princ Node, signal GraphSignal) error {
err := thread.SimpleLockable.Process(context, princ, signal)
func (thread *Thread) Process(context *StateContext, signal GraphSignal) error {
err := thread.Lockable.Process(context, signal)
if err != nil {
return err
}
@ -19,16 +18,16 @@ func (thread *SimpleThread) Process(context *StateContext, princ Node, signal Gr
switch signal.Direction() {
case Up:
err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error {
if thread.parent != nil {
return Signal(context, thread.parent, thread, signal)
if thread.Parent != nil {
return Signal(context, thread.Parent, thread, signal)
} else {
return nil
}
})
case Down:
err = UseStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error {
for _, child := range(thread.children) {
err := Signal(context, child, thread, signal)
for _, info := range(thread.Children) {
err := Signal(context, info.Child, thread, signal)
if err != nil {
return err
}
@ -44,136 +43,35 @@ func (thread *SimpleThread) Process(context *StateContext, princ Node, signal Gr
return err
}
thread.signal <- signal
thread.Chan <- signal
return nil
}
// Interface to represent any type of thread information
type ThreadInfo interface {
}
func (thread * SimpleThread) SetTimeout(timeout time.Time, action string) {
thread.timeout = timeout
thread.timeout_action = action
thread.timeout_chan = time.After(time.Until(timeout))
}
func (thread * SimpleThread) TimeoutAction() string {
return thread.timeout_action
}
func (thread * SimpleThread) State() string {
return thread.state_name
}
func (thread * SimpleThread) SetState(new_state string) error {
if new_state == "" {
return fmt.Errorf("Cannot set state to '' with SetState")
}
thread.state_name = new_state
return nil
}
func (thread * SimpleThread) Parent() Thread {
return thread.parent
}
func (thread * SimpleThread) SetParent(parent Thread) {
thread.parent = parent
}
func (thread * SimpleThread) Children() []Thread {
return thread.children
}
func (thread * SimpleThread) Child(id NodeID) Thread {
for _, child := range(thread.children) {
if child.ID() == id {
return child
}
}
return nil
}
func (thread * SimpleThread) ChildInfo(child NodeID) ThreadInfo {
return thread.child_info[child]
}
// Requires thread and childs thread to be locked for write
func UnlinkThreads(ctx * Context, thread Thread, child Thread) error {
var found Node = nil
for _, c := range(thread.Children()) {
if child.ID() == c.ID() {
found = c
break
}
}
if found == nil {
func UnlinkThreads(ctx * Context, node ThreadNode, child_node ThreadNode) error {
thread := node.ThreadHandle()
child := child_node.ThreadHandle()
_, is_child := thread.Children[child_node.ID()]
if is_child == false {
return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID(), thread.ID())
}
child.SetParent(nil)
thread.RemoveChild(child)
child.Parent = nil
delete(thread.Children, child.ID())
return nil
}
func (thread * SimpleThread) RemoveChild(child Thread) {
idx := -1
for i, c := range(thread.children) {
if c.ID() == child.ID() {
idx = i
break
}
}
if idx == -1 {
panic(fmt.Sprintf("%s is not a child of %s", child.ID(), thread.Name()))
}
child_len := len(thread.children)
thread.children[idx] = thread.children[child_len-1]
thread.children = thread.children[0:child_len-1]
}
func (thread * SimpleThread) AddChild(child Thread, info ThreadInfo) error {
if child == nil {
return fmt.Errorf("Will not connect nil to the thread tree")
}
_, exists := thread.child_info[child.ID()]
if exists == true {
return fmt.Errorf("Will not connect the same child twice")
}
if info == nil && thread.InfoType != nil {
return fmt.Errorf("nil info passed when expecting info")
} else if info != nil {
if reflect.TypeOf(info) != thread.InfoType {
return fmt.Errorf("info type mismatch, expecting %+v - %+v", thread.InfoType, reflect.TypeOf(info))
}
}
thread.children = append(thread.children, child)
thread.child_info[child.ID()] = info
return nil
}
func checkIfChild(context *StateContext, target Thread, cur Thread) bool {
for _, child := range(cur.Children()) {
if child.ID() == target.ID() {
func checkIfChild(context *StateContext, target ThreadNode, cur ThreadNode) bool {
for _, info := range(cur.ThreadHandle().Children) {
if info.Child.ID() == target.ID() {
return true
}
is_child := false
UpdateStates(context, cur, NewLockMap(
NewLockInfo(child, []string{"children"}),
NewLockInfo(info.Child, []string{"children"}),
), func(context *StateContext) error {
is_child = checkIfChild(context, target, child)
is_child = checkIfChild(context, target, info.Child)
return nil
})
if is_child {
@ -186,10 +84,12 @@ func checkIfChild(context *StateContext, target Thread, cur Thread) bool {
// Links child to parent with info as the associated info
// Continues the write context with princ, getting children for thread and parent for child
func LinkThreads(context *StateContext, princ Node, thread Thread, child Thread, info ThreadInfo) error {
if context == nil || thread == nil || child == nil {
func LinkThreads(context *StateContext, princ Node, thread_node ThreadNode, info ChildInfo) error {
if context == nil || thread_node == nil || info.Child == nil {
return fmt.Errorf("invalid input")
}
thread := thread_node.ThreadHandle()
child := info.Child.ThreadHandle()
if thread.ID() == child.ID() {
return fmt.Errorf("Will not link %s as a child of itself", thread.ID())
@ -199,7 +99,7 @@ func LinkThreads(context *StateContext, princ Node, thread Thread, child Thread,
child.ID(): LockInfo{Node: child, Resources: []string{"parent"}},
thread.ID(): LockInfo{Node: thread, Resources: []string{"children"}},
}, func(context *StateContext) error {
if child.Parent() != nil {
if child.Parent != nil {
return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID())
}
@ -211,60 +111,23 @@ func LinkThreads(context *StateContext, princ Node, thread Thread, child Thread,
return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID())
}
err := thread.AddChild(child, info)
if err != nil {
return fmt.Errorf("EVENT_LINK_ERR: error adding %s as child to %s: %e", child.ID(), thread.ID(), err)
}
child.SetParent(thread)
// TODO check for info types
if err != nil {
return err
}
thread.Children[child.ID()] = info
child.Parent = thread_node
return nil
})
}
type ThreadAction func(* Context, Thread)(string, error)
type ThreadAction func(*Context, ThreadNode)(string, error)
type ThreadActions map[string]ThreadAction
type ThreadHandler func(* Context, Thread, GraphSignal)(string, error)
type ThreadHandler func(*Context, ThreadNode, GraphSignal)(string, error)
type ThreadHandlers map[string]ThreadHandler
type Thread interface {
// All Threads are Lockables
Lockable
/// State Modification Functions
SetParent(parent Thread)
AddChild(child Thread, info ThreadInfo) error
RemoveChild(child Thread)
SetState(new_thread string) error
SetTimeout(end_time time.Time, action string)
/// State Reading Functions
Parent() Thread
Children() []Thread
Child(id NodeID) Thread
ChildInfo(child NodeID) ThreadInfo
State() string
TimeoutAction() string
/// Functions that dont read/write thread
// Deserialize the attribute map from json.Unmarshal
DeserializeInfo(ctx *Context, data []byte) (ThreadInfo, error)
SetActive(active bool) error
Action(action string) (ThreadAction, bool)
Handler(signal_type string) (ThreadHandler, bool)
// Internal timeout channel for thread
Timeout() <-chan time.Time
// Internal signal channel for thread
SignalChannel() <-chan GraphSignal
ClearTimeout()
ChildWaits() *sync.WaitGroup
}
type ParentInfo interface {
Parent() *ParentThreadInfo
type InfoType string
func (t InfoType) String() string {
return string(t)
}
// Data required by a parent thread to restore it's children
@ -274,10 +137,6 @@ type ParentThreadInfo struct {
RestoreAction string `json:"restore_action"`
}
func (info * ParentThreadInfo) Parent() *ParentThreadInfo{
return info
}
func NewParentThreadInfo(start bool, start_action string, restore_action string) ParentThreadInfo {
return ParentThreadInfo{
Start: start,
@ -286,83 +145,123 @@ func NewParentThreadInfo(start bool, start_action string, restore_action string)
}
}
type SimpleThread struct {
SimpleLockable
type ChildInfo struct {
Child ThreadNode
Infos map[InfoType]interface{}
}
actions ThreadActions
handlers ThreadHandlers
func NewChildInfo(child ThreadNode, infos map[InfoType]interface{}) ChildInfo {
if infos == nil {
infos = map[InfoType]interface{}{}
}
timeout_chan <-chan time.Time
signal chan GraphSignal
child_waits *sync.WaitGroup
active bool
active_lock *sync.Mutex
return ChildInfo{
Child: child,
Infos: infos,
}
}
type QueuedAction struct {
Timeout time.Time
Action string
}
type ThreadNode interface {
LockableNode
ThreadHandle() *Thread
}
type Thread struct {
Lockable
Actions ThreadActions
Handlers ThreadHandlers
TimeoutChan <-chan time.Time
Chan chan GraphSignal
ChildWaits sync.WaitGroup
Active bool
ActiveLock sync.Mutex
StateName string
Parent ThreadNode
Children map[NodeID]ChildInfo
InfoTypes []InfoType
TimeoutAction string
Timeout time.Time
state_name string
parent Thread
children []Thread
child_info map[NodeID] ThreadInfo
InfoType reflect.Type
timeout time.Time
timeout_action string
}
func (thread * SimpleThread) Type() NodeType {
func (thread *Thread) ThreadHandle() *Thread {
return thread
}
func (thread *Thread) Type() NodeType {
return NodeType("simple_thread")
}
func (thread * SimpleThread) Serialize() ([]byte, error) {
thread_json := NewSimpleThreadJSON(thread)
func (thread *Thread) Serialize() ([]byte, error) {
thread_json := NewThreadJSON(thread)
return json.MarshalIndent(&thread_json, "", " ")
}
func (thread * SimpleThread) SignalChannel() <-chan GraphSignal {
return thread.signal
func (thread *Thread) ChildList() []ThreadNode {
ret := make([]ThreadNode, len(thread.Children))
i := 0
for _, info := range(thread.Children) {
ret[i] = info.Child
i += 1
}
return ret
}
type SimpleThreadJSON struct {
type ThreadJSON struct {
Parent string `json:"parent"`
Children map[string]interface{} `json:"children"`
Children map[string]map[string]interface{} `json:"children"`
Timeout time.Time `json:"timeout"`
TimeoutAction string `json:"timeout_action"`
StateName string `json:"state_name"`
SimpleLockableJSON
LockableJSON
}
func NewSimpleThreadJSON(thread *SimpleThread) SimpleThreadJSON {
children := map[string]interface{}{}
for _, child := range(thread.children) {
children[child.ID().String()] = thread.child_info[child.ID()]
func NewThreadJSON(thread *Thread) ThreadJSON {
children := map[string]map[string]interface{}{}
for id, info := range(thread.Children) {
tmp := map[string]interface{}{}
for name, i := range(info.Infos) {
tmp[name.String()] = i
}
children[id.String()] = tmp
}
parent_id := ""
if thread.parent != nil {
parent_id = thread.parent.ID().String()
if thread.Parent != nil {
parent_id = thread.Parent.ID().String()
}
lockable_json := NewSimpleLockableJSON(&thread.SimpleLockable)
lockable_json := NewLockableJSON(&thread.Lockable)
return SimpleThreadJSON{
return ThreadJSON{
Parent: parent_id,
Children: children,
Timeout: thread.timeout,
TimeoutAction: thread.timeout_action,
StateName: thread.state_name,
SimpleLockableJSON: lockable_json,
Timeout: thread.Timeout,
TimeoutAction: thread.TimeoutAction,
StateName: thread.StateName,
LockableJSON: lockable_json,
}
}
func LoadSimpleThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) {
var j SimpleThreadJSON
func LoadThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) {
var j ThreadJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
thread := NewSimpleThread(id, j.Name, j.StateName, nil, BaseThreadActions, BaseThreadHandlers)
thread := NewThread(id, j.Name, j.StateName, nil, BaseThreadActions, BaseThreadHandlers)
nodes[id] = &thread
err = RestoreSimpleThread(ctx, &thread, j, nodes)
err = RestoreThread(ctx, &thread, j, nodes)
if err != nil {
return nil, err
}
@ -370,17 +269,10 @@ func LoadSimpleThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node
return &thread, nil
}
// SimpleThread has no associated info with children
func (thread * SimpleThread) DeserializeInfo(ctx *Context, data []byte) (ThreadInfo, error) {
if len(data) > 0 {
return nil, fmt.Errorf("SimpleThread expected to deserialize no info but got %d length data: %s", len(data), string(data))
}
return nil, nil
}
func RestoreSimpleThread(ctx *Context, thread Thread, j SimpleThreadJSON, nodes NodeMap) error {
func RestoreThread(ctx *Context, thread *Thread, j ThreadJSON, nodes NodeMap) error {
if j.TimeoutAction != "" {
thread.SetTimeout(j.Timeout, j.TimeoutAction)
thread.Timeout = j.Timeout
thread.TimeoutAction = j.TimeoutAction
}
if j.Parent != "" {
@ -392,11 +284,11 @@ func RestoreSimpleThread(ctx *Context, thread Thread, j SimpleThreadJSON, nodes
if err != nil {
return err
}
p_t, ok := p.(Thread)
p_t, ok := p.(ThreadNode)
if ok == false {
return err
}
thread.SetParent(p_t)
thread.Parent = p_t
}
for id_str, info_raw := range(j.Children) {
@ -404,63 +296,94 @@ func RestoreSimpleThread(ctx *Context, thread Thread, j SimpleThreadJSON, nodes
if err != nil {
return err
}
child_node, err := LoadNodeRecurse(ctx, id, nodes)
if err != nil {
return err
}
child_t, ok := child_node.(Thread)
child_t, ok := child_node.(ThreadNode)
if ok == false {
return fmt.Errorf("%+v is not a Thread as expected", child_node)
}
var info_ser []byte
if info_raw != nil {
info_ser, err = json.Marshal(info_raw)
if err != nil {
return err
}
}
parsed_info, err := thread.DeserializeInfo(ctx, info_ser)
parsed_info, err := DeserializeChildInfo(ctx, info_raw)
if err != nil {
return err
}
thread.AddChild(child_t, parsed_info)
thread.Children[id] = ChildInfo{child_t, parsed_info}
}
return RestoreLockable(ctx, &thread.Lockable, j.LockableJSON, nodes)
}
var deserializers = map[InfoType]func(interface{})(interface{}, error) {
}
func DeserializeChildInfo(ctx *Context, infos_raw map[string]interface{}) (map[InfoType]interface{}, error) {
ret := map[InfoType]interface{}{}
for type_str, info_raw := range(infos_raw) {
info_type := InfoType(type_str)
deserializer, exists := deserializers[info_type]
if exists == false {
return nil, fmt.Errorf("No deserializer for %s", info_type)
}
var err error
ret[info_type], err = deserializer(info_raw)
if err != nil {
return nil, err
}
}
return RestoreSimpleLockable(ctx, thread, j.SimpleLockableJSON, nodes)
return ret, nil
}
const THREAD_SIGNAL_BUFFER_SIZE = 128
func NewSimpleThread(id NodeID, name string, state_name string, info_type reflect.Type, actions ThreadActions, handlers ThreadHandlers) SimpleThread {
return SimpleThread{
SimpleLockable: NewSimpleLockable(id, name),
InfoType: info_type,
state_name: state_name,
signal: make(chan GraphSignal, THREAD_SIGNAL_BUFFER_SIZE),
children: []Thread{},
child_info: map[NodeID]ThreadInfo{},
actions: actions,
handlers: handlers,
child_waits: &sync.WaitGroup{},
active_lock: &sync.Mutex{},
func NewThread(id NodeID, name string, state_name string, info_types []InfoType, actions ThreadActions, handlers ThreadHandlers) Thread {
return Thread{
Lockable: NewLockable(id, name),
InfoTypes: info_types,
StateName: state_name,
Chan: make(chan GraphSignal, THREAD_SIGNAL_BUFFER_SIZE),
Children: map[NodeID]ChildInfo{},
Actions: actions,
Handlers: handlers,
}
}
func (thread *Thread) SetActive(active bool) error {
thread.ActiveLock.Lock()
defer thread.ActiveLock.Unlock()
if thread.Active == true && active == true {
return fmt.Errorf("%s is active, cannot set active", thread.ID())
} else if thread.Active == false && active == false {
return fmt.Errorf("%s is already inactive, canot set inactive", thread.ID())
}
thread.Active = active
return nil
}
func (thread *Thread) SetState(state string) error {
thread.StateName = state
return nil
}
// Requires the read permission of threads children
func FindChild(context *StateContext, princ Node, thread Thread, id NodeID) Thread {
if thread == nil {
func FindChild(context *StateContext, princ Node, node ThreadNode, id NodeID) ThreadNode {
if node == nil {
panic("cannot recurse through nil")
}
thread := node.ThreadHandle()
if id == thread.ID() {
return thread
}
for _, child := range thread.Children() {
var result Thread = nil
UseStates(context, princ, NewLockInfo(child, []string{"children"}), func(context *StateContext) error {
result = FindChild(context, princ, child, id)
for _, info := range thread.Children {
var result ThreadNode
UseStates(context, princ, NewLockInfo(info.Child, []string{"children"}), func(context *StateContext) error {
result = FindChild(context, princ, info.Child, id)
return nil
})
if result != nil {
@ -471,11 +394,11 @@ func FindChild(context *StateContext, princ Node, thread Thread, id NodeID) Thre
return nil
}
func ChildGo(ctx * Context, thread Thread, child Thread, first_action string) {
thread.ChildWaits().Add(1)
go func(child Thread) {
func ChildGo(ctx * Context, thread *Thread, child ThreadNode, first_action string) {
thread.ChildWaits.Add(1)
go func(child ThreadNode) {
ctx.Log.Logf("thread", "THREAD_START_CHILD: %s from %s", thread.ID(), child.ID())
defer thread.ChildWaits().Done()
defer thread.ChildWaits.Done()
err := ThreadLoop(ctx, child, first_action)
if err != nil {
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %e", child.ID(), err)
@ -486,8 +409,9 @@ func ChildGo(ctx * Context, thread Thread, child Thread, first_action string) {
}
// Main Loop for Threads, starts a write context, so cannot be called from a write or read context
func ThreadLoop(ctx * Context, thread Thread, first_action string) error {
func ThreadLoop(ctx * Context, node ThreadNode, first_action string) error {
// Start the thread, error if double-started
thread := node.ThreadHandle()
ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID(), first_action)
err := thread.SetActive(true)
if err != nil {
@ -496,14 +420,14 @@ func ThreadLoop(ctx * Context, thread Thread, first_action string) error {
}
next_action := first_action
for next_action != "" {
action, exists := thread.Action(next_action)
action, exists := thread.Actions[next_action]
if exists == false {
error_str := fmt.Sprintf("%s is not a valid action", next_action)
return errors.New(error_str)
}
ctx.Log.Logf("thread", "THREAD_ACTION: %s - %s", thread.ID(), next_action)
next_action, err = action(ctx, thread)
next_action, err = action(ctx, node)
if err != nil {
return err
}
@ -523,52 +447,8 @@ func ThreadLoop(ctx * Context, thread Thread, first_action string) error {
return nil
}
func (thread * SimpleThread) ChildWaits() *sync.WaitGroup {
return thread.child_waits
}
func (thread * SimpleThread) SetActive(active bool) error {
thread.active_lock.Lock()
defer thread.active_lock.Unlock()
if thread.active == true && active == true {
return fmt.Errorf("%s is active, cannot set active", thread.ID())
} else if thread.active == false && active == false {
return fmt.Errorf("%s is already inactive, canot set inactive", thread.ID())
}
thread.active = active
return nil
}
func (thread * SimpleThread) Action(action string) (ThreadAction, bool) {
action_fn, exists := thread.actions[action]
return action_fn, exists
}
func (thread * SimpleThread) Handler(signal_type string) (ThreadHandler, bool) {
handler, exists := thread.handlers[signal_type]
return handler, exists
}
func (thread * SimpleThread) Timeout() <-chan time.Time {
return thread.timeout_chan
}
func (thread * SimpleThread) ClearTimeout() {
thread.timeout_chan = nil
thread.timeout_action = ""
}
func (thread * SimpleThread) AllowedToTakeLock(new_owner Lockable, lockable Lockable) bool {
for _, child := range(thread.children) {
if new_owner.ID() == child.ID() {
return true
}
}
return false
}
func ThreadParentChildLinked(ctx *Context, thread Thread, signal GraphSignal) (string, error) {
func ThreadChildLinked(ctx *Context, node ThreadNode, signal GraphSignal) (string, error) {
thread := node.ThreadHandle()
ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %+v", signal)
context := NewWriteContext(ctx)
err := UpdateStates(context, thread, NewLockMap(
@ -579,18 +459,18 @@ func ThreadParentChildLinked(ctx *Context, thread Thread, signal GraphSignal) (s
ctx.Log.Logf("thread", "THREAD_NODE_LINKED_BAD_CAST")
return nil
}
info_if := thread.ChildInfo(sig.ID)
if info_if == nil {
info, exists := thread.Children[sig.ID]
if exists == false {
ctx.Log.Logf("thread", "THREAD_NODE_LINKED: %s is not a child of %s", sig.ID)
return nil
}
info_r, correct := info_if.(ParentInfo)
if correct == false {
ctx.Log.Logf("thread", "THREAD_NODE_LINKED_BAD_INFO_CAST")
parent_info, exists := info.Infos["parent"].(*ParentThreadInfo)
if exists == false {
panic("ran ThreadChildLinked from a thread that doesn't require 'parent' child info. library party foul")
}
info := info_r.Parent()
if info.Start == true {
ChildGo(ctx, thread, thread.Child(sig.ID), info.StartAction)
if parent_info.Start == true {
ChildGo(ctx, thread, info.Child, parent_info.StartAction)
}
return nil
})
@ -603,38 +483,30 @@ func ThreadParentChildLinked(ctx *Context, thread Thread, signal GraphSignal) (s
return "wait", nil
}
func ThreadParentStartChild(ctx *Context, thread Thread, signal GraphSignal) (string, error) {
ctx.Log.Logf("thread", "THREAD_START_CHILD")
// Helper function to start a child from a thread during a signal handler
// Starts a write context, so cannot be called from either a write or read context
func ThreadStartChild(ctx *Context, node ThreadNode, signal GraphSignal) (string, error) {
sig, ok := signal.(StartChildSignal)
if ok == false {
ctx.Log.Logf("thread", "THREAD_START_CHILD_BAD_SIGNAL: %+v", signal)
return "wait", nil
}
err := ThreadStartChild(ctx, thread, sig)
if err != nil {
ctx.Log.Logf("thread", "THREAD_START_CHILD_ERR: %s", err)
} else {
ctx.Log.Logf("thread", "THREAD_START_CHILD: %s", sig.ID.String())
}
return "wait", nil
}
thread := node.ThreadHandle()
// Helper function to start a child from a thread during a signal handler
// Starts a write context, so cannot be called from either a write or read context
func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) error {
context := NewWriteContext(ctx)
return UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error {
child := thread.Child(signal.ID)
if child == nil {
return fmt.Errorf("%s is not a child of %s", signal.ID, thread.ID())
return "wait", UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error {
info, exists:= thread.Children[sig.ID]
if exists == false {
return fmt.Errorf("%s is not a child of %s", sig.ID, thread.ID())
}
return UpdateStates(context, thread, NewLockInfo(child, []string{"start"}), func(context *StateContext) error {
return UpdateStates(context, thread, NewLockInfo(info.Child, []string{"start"}), func(context *StateContext) error {
info := thread.ChildInfo(signal.ID).(*ParentThreadInfo)
info.Start = true
ChildGo(ctx, thread, child, signal.Action)
parent_info, exists := info.Infos["parent"].(*ParentThreadInfo)
if exists == false {
return fmt.Errorf("Called ThreadStartChild from a thread that doesn't require parent child info")
}
parent_info.Start = true
ChildGo(ctx, thread, info.Child, sig.Action)
return nil
})
@ -643,18 +515,19 @@ func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) erro
// Helper function to restore threads that should be running from a parents restore action
// Starts a write context, so cannot be called from either a write or read context
func ThreadRestore(ctx * Context, thread Thread, start bool) error {
func ThreadRestore(ctx * Context, node ThreadNode, start bool) error {
thread := node.ThreadHandle()
context := NewWriteContext(ctx)
return UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error {
return UpdateStates(context, thread, LockList(thread.Children(), []string{"start"}), func(context *StateContext) error {
for _, child := range(thread.Children()) {
info := (thread.ChildInfo(child.ID())).(ParentInfo).Parent()
if info.Start == true && child.State() != "finished" {
ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID(), child.ID())
return UpdateStates(context, thread, LockList(thread.ChildList(), []string{"start"}), func(context *StateContext) error {
for _, info := range(thread.Children) {
parent_info := info.Infos["parent"].(*ParentThreadInfo)
if parent_info.Start == true && info.Child.ThreadHandle().StateName != "finished" {
ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID(), info.Child.ID())
if start == true {
ChildGo(ctx, thread, child, info.StartAction)
ChildGo(ctx, thread, info.Child, parent_info.StartAction)
} else {
ChildGo(ctx, thread, child, info.RestoreAction)
ChildGo(ctx, thread, info.Child, parent_info.RestoreAction)
}
}
}
@ -665,10 +538,12 @@ func ThreadRestore(ctx * Context, thread Thread, start bool) error {
// Helper function to be called during a threads start action, sets the thread state to started
// Starts a write context, so cannot be called from either a write or read context
func ThreadStart(ctx * Context, thread Thread) error {
// Returns "wait", nil on success, so the first return value can be ignored safely
func ThreadStart(ctx * Context, node ThreadNode) (string, error) {
thread := node.ThreadHandle()
context := NewWriteContext(ctx)
return UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error {
err := LockLockables(context, []Lockable{thread}, thread)
return "wait", UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error {
err := LockLockables(context, map[NodeID]LockableNode{thread.ID(): thread}, thread)
if err != nil {
return err
}
@ -676,39 +551,28 @@ func ThreadStart(ctx * Context, thread Thread) error {
})
}
func ThreadDefaultStart(ctx * Context, thread Thread) (string, error) {
ctx.Log.Logf("thread", "THREAD_DEFAULT_START: %s", thread.ID())
err := ThreadStart(ctx, thread)
if err != nil {
return "", err
}
return "wait", nil
}
func ThreadDefaultRestore(ctx * Context, thread Thread) (string, error) {
ctx.Log.Logf("thread", "THREAD_DEFAULT_RESTORE: %s", thread.ID())
return "wait", nil
}
func ThreadWait(ctx * Context, thread Thread) (string, error) {
ctx.Log.Logf("thread", "THREAD_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout())
func ThreadWait(ctx * Context, node ThreadNode) (string, error) {
thread := node.ThreadHandle()
ctx.Log.Logf("thread", "THREAD_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout)
for {
select {
case signal := <- thread.SignalChannel():
case signal := <- thread.Chan:
ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal)
signal_fn, exists := thread.Handler(signal.Type())
signal_fn, exists := thread.Handlers[signal.Type()]
if exists == true {
ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID(), signal.Type())
return signal_fn(ctx, thread, signal)
} else {
ctx.Log.Logf("thread", "THREAD_NOHANDLER: %s - %s", thread.ID(), signal.Type())
}
case <- thread.Timeout():
case <- thread.TimeoutChan:
timeout_action := ""
context := NewWriteContext(ctx)
err := UpdateStates(context, thread, NewLockMap(NewLockInfo(thread, []string{"timeout"})), func(context *StateContext) error {
timeout_action = thread.TimeoutAction()
thread.ClearTimeout()
timeout_action = thread.TimeoutAction
thread.TimeoutChan = nil
thread.TimeoutAction = ""
thread.Timeout = time.Time{}
return nil
})
if err != nil {
@ -720,26 +584,23 @@ func ThreadWait(ctx * Context, thread Thread) (string, error) {
}
}
func ThreadDefaultFinish(ctx *Context, thread Thread) (string, error) {
ctx.Log.Logf("thread", "THREAD_DEFAULT_FINISH: %s", thread.ID().String())
return "", ThreadFinish(ctx, thread)
}
func ThreadFinish(ctx *Context, thread Thread) error {
func ThreadFinish(ctx *Context, node ThreadNode) (string, error) {
thread := node.ThreadHandle()
context := NewWriteContext(ctx)
return UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error {
return "", UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error {
err := thread.SetState("finished")
if err != nil {
return err
}
return UnlockLockables(context, []Lockable{thread}, thread)
return UnlockLockables(context, map[NodeID]LockableNode{thread.ID(): thread}, thread)
})
}
var ThreadAbortedError = errors.New("Thread aborted by signal")
// Default thread action function for "abort", sends a signal and returns a ThreadAbortedError
func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
func ThreadAbort(ctx * Context, node ThreadNode, signal GraphSignal) (string, error) {
thread := node.ThreadHandle()
context := NewReadContext(ctx)
err := Signal(context, thread, thread, NewStatusSignal("aborted", thread.ID()))
if err != nil {
@ -749,38 +610,18 @@ func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, erro
}
// Default thread action for "stop", sends a signal and returns no error
func ThreadStop(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
func ThreadStop(ctx * Context, node ThreadNode, signal GraphSignal) (string, error) {
thread := node.ThreadHandle()
context := NewReadContext(ctx)
err := Signal(context, thread, thread, NewStatusSignal("stopped", thread.ID()))
return "finish", err
}
// Copy the default thread actions to a new ThreadActions map
func NewThreadActions() ThreadActions{
actions := ThreadActions{}
for k, v := range(BaseThreadActions) {
actions[k] = v
}
return actions
}
// Copy the defult thread handlers to a new ThreadAction map
func NewThreadHandlers() ThreadHandlers{
handlers := ThreadHandlers{}
for k, v := range(BaseThreadHandlers) {
handlers[k] = v
}
return handlers
}
// Default thread actions
var BaseThreadActions = ThreadActions{
"wait": ThreadWait,
"start": ThreadDefaultStart,
"finish": ThreadDefaultFinish,
"restore": ThreadDefaultRestore,
"start": ThreadStart,
"finish": ThreadFinish,
}
// Default thread signal handlers

@ -9,7 +9,7 @@ import (
)
type User struct {
SimpleLockable
Lockable
Granted time.Time
Pubkey *ecdsa.PublicKey
@ -18,7 +18,7 @@ type User struct {
}
type UserJSON struct {
SimpleLockableJSON
LockableJSON
Granted time.Time `json:"granted"`
Pubkey []byte `json:"pubkey"`
Shared []byte `json:"shared"`
@ -30,14 +30,14 @@ func (user *User) Type() NodeType {
}
func (user *User) Serialize() ([]byte, error) {
lockable_json := NewSimpleLockableJSON(&user.SimpleLockable)
lockable_json := NewLockableJSON(&user.Lockable)
pubkey, err := x509.MarshalPKIXPublicKey(user.Pubkey)
if err != nil {
return nil, err
}
return json.MarshalIndent(&UserJSON{
SimpleLockableJSON: lockable_json,
LockableJSON: lockable_json,
Granted: user.Granted,
Shared: user.Shared,
Pubkey: pubkey,
@ -68,7 +68,7 @@ func LoadUser(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error)
user := NewUser(j.Name, j.Granted, pubkey, j.Shared, j.Tags)
nodes[id] = &user
err = RestoreSimpleLockable(ctx, &user, j.SimpleLockableJSON, nodes)
err = RestoreLockable(ctx, &user.Lockable, j.LockableJSON, nodes)
if err != nil {
return nil, err
}
@ -79,7 +79,7 @@ func LoadUser(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error)
func NewUser(name string, granted time.Time, pubkey *ecdsa.PublicKey, shared []byte, tags []string) User {
id := KeyID(pubkey)
return User{
SimpleLockable: NewSimpleLockable(id, name),
Lockable: NewLockable(id, name),
Granted: granted,
Pubkey: pubkey,
Shared: shared,