Reworked use/update to require a read/write context be initialized before starting, still need to figure out if brittle locking is the solution to potential deadlock, and implement if so

graph-rework-2
noah metz 2023-07-23 17:57:47 -06:00
parent 575912d56f
commit 8fb0cbc982
8 changed files with 391 additions and 298 deletions

@ -204,7 +204,8 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt
ctx.Log.Logf("gql", "AUTHORIZING NEW USER %s - %s", key_id, shared) ctx.Log.Logf("gql", "AUTHORIZING NEW USER %s - %s", key_id, shared)
new_user := NewUser(fmt.Sprintf("GQL_USER %s", key_id.String()), time.Now(), remote_id, shared, []string{"gql"}) new_user := NewUser(fmt.Sprintf("GQL_USER %s", key_id.String()), time.Now(), remote_id, shared, []string{"gql"})
err := UpdateStates(ctx, server, NewLockMap(LockMap{ context := NewWriteContext(ctx)
err := UpdateStates(context, server, NewLockMap(LockMap{
server.ID(): LockInfo{ server.ID(): LockInfo{
Node: server, Node: server,
Resources: []string{"users"}, Resources: []string{"users"},
@ -213,7 +214,7 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt
Node: &new_user, Node: &new_user,
Resources: []string{""}, Resources: []string{""},
}, },
}), func(context *WriteContext) error { }), func(context *StateContext) error {
server.Users[key_id] = &new_user server.Users[key_id] = &new_user
return nil return nil
}) })
@ -873,9 +874,10 @@ var gql_actions ThreadActions = ThreadActions{
}(server) }(server)
err = UpdateStates(ctx, server, NewLockMap( context := NewWriteContext(ctx)
err = UpdateStates(context, server, NewLockMap(
NewLockInfo(server, []string{"http_server"}), NewLockInfo(server, []string{"http_server"}),
), func(context *WriteContext) error { ), func(context *StateContext) error {
server.tcp_listener = listener server.tcp_listener = listener
server.http_server = http_server server.http_server = http_server
return nil return nil
@ -885,9 +887,10 @@ var gql_actions ThreadActions = ThreadActions{
return "", err return "", err
} }
err = UseStates(ctx, server, NewLockMap( context = NewReadContext(ctx)
err = UseStates(context, server, NewLockMap(
NewLockInfo(server, []string{"signal"}), NewLockInfo(server, []string{"signal"}),
), func(context *ReadContext) error { ), func(context *StateContext) error {
return server.Signal(context, NewSignal("server_started")) return server.Signal(context, NewSignal("server_started"))
}) })
@ -897,14 +900,21 @@ var gql_actions ThreadActions = ThreadActions{
return "wait", nil return "wait", nil
}, },
"finish": func(ctx *Context, thread Thread) (string, error) {
server := thread.(*GQLThread)
server.http_server.Shutdown(context.TODO())
server.http_done.Wait()
return "", ThreadFinish(ctx, thread)
},
} }
var gql_handlers ThreadHandlers = ThreadHandlers{ var gql_handlers ThreadHandlers = ThreadHandlers{
"child_linked": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) { "child_linked": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
ctx.Log.Logf("gql", "GQL_THREAD_CHILD_ADDED: %+v", signal) ctx.Log.Logf("gql", "GQL_THREAD_CHILD_ADDED: %+v", signal)
err := UpdateStates(ctx, thread, NewLockMap( context := NewWriteContext(ctx)
err := UpdateStates(context, thread, NewLockMap(
NewLockInfo(thread, []string{"children"}), NewLockInfo(thread, []string{"children"}),
), func(context *WriteContext) error { ), func(context *StateContext) error {
sig, ok := signal.(IDSignal) sig, ok := signal.(IDSignal)
if ok == false { if ok == false {
ctx.Log.Logf("gql", "GQL_THREAD_NODE_LINKED_BAD_CAST") ctx.Log.Logf("gql", "GQL_THREAD_NODE_LINKED_BAD_CAST")
@ -945,19 +955,7 @@ var gql_handlers ThreadHandlers = ThreadHandlers{
return "wait", nil return "wait", nil
}, },
"abort": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) { "abort": ThreadAbort,
ctx.Log.Logf("gql", "GQL_ABORT") "cancel": ThreadCancel,
server := thread.(*GQLThread)
server.http_server.Shutdown(context.TODO())
server.http_done.Wait()
return ThreadAbort(ctx, thread, signal)
},
"cancel": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
ctx.Log.Logf("gql", "GQL_CANCEL")
server := thread.(*GQLThread)
server.http_server.Shutdown(context.TODO())
server.http_done.Wait()
return ThreadCancel(ctx, thread, signal)
},
} }

@ -29,14 +29,15 @@ var GQLMutationAbort = NewField(func()*graphql.Field {
} }
var node Node = nil var node Node = nil
err = UseStates(ctx.Context, ctx.User, NewLockMap( context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockMap(
NewLockInfo(ctx.Server, []string{"children"}), NewLockInfo(ctx.Server, []string{"children"}),
), func(context *ReadContext) (error){ ), func(context *StateContext) (error){
node = FindChild(context, ctx.User, ctx.Server, id) node = FindChild(context, ctx.User, ctx.Server, id)
if node == nil { if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", id) return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
} }
return UseMoreStates(context, ctx.User, NewLockMap(NewLockInfo(node, []string{"signal"})), func(context *ReadContext) error { return UseStates(context, ctx.User, NewLockInfo(node, []string{"signal"}), func(context *StateContext) error {
return node.Signal(context, AbortSignal) return node.Signal(context, AbortSignal)
}) })
}) })
@ -88,9 +89,10 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{
} }
var signal GraphSignal var signal GraphSignal
err = UseStates(ctx.Context, ctx.User, NewLockMap( context := NewWriteContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockMap(
NewLockInfo(ctx.Server, []string{"children"}), NewLockInfo(ctx.Server, []string{"children"}),
), func(context *ReadContext) error { ), func(context *StateContext) error {
node := FindChild(context, ctx.User, ctx.Server, parent_id) node := FindChild(context, ctx.User, ctx.Server, parent_id)
if node == nil { if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id) return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id)
@ -101,7 +103,7 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{
return err return err
} }
return UseMoreStates(context, ctx.User, NewLockMap(NewLockInfo(node, []string{"start_child", "signal"})), func(context *ReadContext) error { return UseStates(context, ctx.User, NewLockMap(NewLockInfo(node, []string{"start_child", "signal"})), func(context *StateContext) error {
signal = NewStartChildSignal(child_id, action) signal = NewStartChildSignal(child_id, action)
return node.Signal(context, signal) return node.Signal(context, signal)
}) })

@ -42,24 +42,13 @@ func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) {
return id, nil return id, nil
} }
// TODO: think about what permissions should be needed to read ID, and if there's ever a case where they haven't already been granted
func GQLNodeID(p graphql.ResolveParams) (interface{}, error) { func GQLNodeID(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
node, ok := p.Source.(Node) node, ok := p.Source.(Node)
if ok == false || node == nil { if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Node") return nil, fmt.Errorf("Failed to cast source to Node")
} }
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"id"})), func(context *ReadContext) error {
return nil
})
if err != nil {
return nil, err
}
return node.ID(), nil return node.ID(), nil
} }
@ -76,7 +65,8 @@ func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) {
listen := "" listen := ""
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"listen"})), func(context *ReadContext) error { context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"listen"}), func(context *StateContext) error {
listen = node.Listen listen = node.Listen
return nil return nil
}) })
@ -100,7 +90,8 @@ func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) {
} }
var parent Thread = nil var parent Thread = nil
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"parent"})), func(context *ReadContext) error { context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"parent"}), func(context *StateContext) error {
parent = node.Parent() parent = node.Parent()
return nil return nil
}) })
@ -124,7 +115,8 @@ func GQLThreadState(p graphql.ResolveParams) (interface{}, error) {
} }
var state string var state string
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"state"})), func(context *ReadContext) error { context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"state"}), func(context *StateContext) error {
state = node.State() state = node.State()
return nil return nil
}) })
@ -148,7 +140,8 @@ func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) {
} }
var children []Thread = nil var children []Thread = nil
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"children"})), func(context *ReadContext) error { context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"children"}), func(context *StateContext) error {
children = node.Children() children = node.Children()
return nil return nil
}) })
@ -172,7 +165,8 @@ func GQLLockableName(p graphql.ResolveParams) (interface{}, error) {
} }
name := "" name := ""
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"name"})), func(context *ReadContext) error { context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"name"}), func(context *StateContext) error {
name = node.Name() name = node.Name()
return nil return nil
}) })
@ -196,7 +190,8 @@ func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) {
} }
var requirements []Lockable = nil var requirements []Lockable = nil
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"requirements"})), func(context *ReadContext) error { context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"requirements"}), func(context *StateContext) error {
requirements = node.Requirements() requirements = node.Requirements()
return nil return nil
}) })
@ -220,7 +215,8 @@ func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) {
} }
var dependencies []Lockable = nil var dependencies []Lockable = nil
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"dependencies"})), func(context *ReadContext) error { context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"dependencies"}), func(context *StateContext) error {
dependencies = node.Dependencies() dependencies = node.Dependencies()
return nil return nil
}) })
@ -244,7 +240,8 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) {
} }
var owner Node = nil var owner Node = nil
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"owner"})), func(context *ReadContext) error { context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"owner"}), func(context *StateContext) error {
owner = node.Owner() owner = node.Owner()
return nil return nil
}) })
@ -268,7 +265,8 @@ func GQLThreadUsers(p graphql.ResolveParams) (interface{}, error) {
} }
var users []*User var users []*User
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"users"})), func(context *ReadContext) error { context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"users"}), func(context *StateContext) error {
users = make([]*User, len(node.Users)) users = make([]*User, len(node.Users))
i := 0 i := 0
for _, user := range(node.Users) { for _, user := range(node.Users) {

@ -18,48 +18,8 @@ import (
"encoding/base64" "encoding/base64"
) )
func TestGQLThread(t * testing.T) {
ctx := logTestContext(t, []string{})
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil)
gql_t := &gql_t_r
t1_r := NewSimpleThread(RandID(), "Test thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
t1 := &t1_r
t2_r := NewSimpleThread(RandID(), "Test thread 2", "init", nil, BaseThreadActions, BaseThreadHandlers)
t2 := &t2_r
err = UpdateStates(ctx, gql_t, NewLockMap(
LockList([]Node{t1, t2}, []string{"parent"}),
NewLockInfo(gql_t, []string{"children"}),
), func(context *WriteContext) error {
i1 := NewParentThreadInfo(true, "start", "restore")
err := LinkThreads(context, gql_t, gql_t, t1, &i1)
if err != nil {
return err
}
i2 := NewParentThreadInfo(false, "start", "restore")
return LinkThreads(context, gql_t, gql_t, t2, &i2)
})
fatalErr(t, err)
go func(thread Thread){
time.Sleep(10*time.Millisecond)
err := UseStates(ctx, thread, NewLockInfo(thread, []string{"signal"}), func(context *ReadContext) error {
return thread.Signal(context, CancelSignal)
})
fatalErr(t, err)
}(gql_t)
err = ThreadLoop(ctx, gql_t, "start")
fatalErr(t, err)
}
func TestGQLDBLoad(t * testing.T) { func TestGQLDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{}) ctx := logTestContext(t, []string{"policy", "mutex"})
l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
l1 := &l1_r l1 := &l1_r
@ -84,9 +44,10 @@ func TestGQLDBLoad(t * testing.T) {
gql := &gql_r gql := &gql_r
info := NewParentThreadInfo(true, "start", "restore") info := NewParentThreadInfo(true, "start", "restore")
err = UpdateStates(ctx, gql, NewLockMap( context := NewWriteContext(ctx)
err = UpdateStates(context, gql, NewLockMap(
NewLockInfo(gql, []string{"policies", "users"}), NewLockInfo(gql, []string{"policies", "users"}),
), func(context *WriteContext) error { ), func(context *StateContext) error {
err := gql.AddPolicy(p1) err := gql.AddPolicy(p1)
if err != nil { if err != nil {
return err return err
@ -102,7 +63,8 @@ func TestGQLDBLoad(t * testing.T) {
}) })
fatalErr(t, err) fatalErr(t, err)
err = UseStates(ctx, gql, NewLockInfo(gql, []string{"signal"}), func(context *ReadContext) error { context = NewReadContext(ctx)
err = UseStates(context, gql, NewLockInfo(gql, []string{"signal"}), func(context *StateContext) error {
err := gql.Signal(context, NewStatusSignal("child_linked", t1.ID())) err := gql.Signal(context, NewStatusSignal("child_linked", t1.ID()))
if err != nil { if err != nil {
return nil return nil
@ -122,7 +84,8 @@ func TestGQLDBLoad(t * testing.T) {
(*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_aborted", 100*time.Millisecond, "Didn't receive thread_abort from t1 on t1") (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_aborted", 100*time.Millisecond, "Didn't receive thread_abort from t1 on t1")
err = UseStates(ctx, gql, LockList([]Node{gql, u1}, nil), func(context *ReadContext) error { context = NewReadContext(ctx)
err = UseStates(context, gql, LockList([]Node{gql, u1}, nil), func(context *StateContext) error {
ser1, err := gql.Serialize() ser1, err := gql.Serialize()
ser2, err := u1.Serialize() ser2, err := u1.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser1) ctx.Log.Logf("test", "\n%s\n\n", ser1)
@ -135,14 +98,15 @@ func TestGQLDBLoad(t * testing.T) {
var t1_loaded *SimpleThread = nil var t1_loaded *SimpleThread = nil
var update_channel_2 chan GraphSignal var update_channel_2 chan GraphSignal
err = UseStates(ctx, gql, NewLockInfo(gql_loaded, []string{"users", "children"}), func(context *ReadContext) error { context = NewReadContext(ctx)
err = UseStates(context, gql, NewLockInfo(gql_loaded, []string{"users", "children"}), func(context *StateContext) error {
ser, err := gql_loaded.Serialize() ser, err := gql_loaded.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser) ctx.Log.Logf("test", "\n%s\n\n", ser)
u_loaded := gql_loaded.(*GQLThread).Users[u1.ID()] u_loaded := gql_loaded.(*GQLThread).Users[u1.ID()]
child := gql_loaded.(Thread).Children()[0].(*SimpleThread) child := gql_loaded.(Thread).Children()[0].(*SimpleThread)
t1_loaded = child t1_loaded = child
update_channel_2 = UpdateChannel(t1_loaded, 10, NodeID{}) update_channel_2 = UpdateChannel(t1_loaded, 10, NodeID{})
err = UseMoreStates(context, gql, NewLockInfo(u_loaded, nil), func(context *ReadContext) error { err = UseStates(context, gql, NewLockInfo(u_loaded, nil), func(context *StateContext) error {
ser, err := u_loaded.Serialize() ser, err := u_loaded.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser) ctx.Log.Logf("test", "\n%s\n\n", ser)
return err return err
@ -162,7 +126,7 @@ func TestGQLDBLoad(t * testing.T) {
} }
func TestGQLAuth(t * testing.T) { func TestGQLAuth(t * testing.T) {
ctx := logTestContext(t, []string{"test", "gql"}) ctx := logTestContext(t, []string{"policy", "mutex"})
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err) fatalErr(t, err)
@ -173,14 +137,16 @@ func TestGQLAuth(t * testing.T) {
gql_t := &gql_t_r gql_t := &gql_t_r
// p1 not written to DB, TODO: update write to follow links maybe // p1 not written to DB, TODO: update write to follow links maybe
err = UpdateStates(ctx, gql_t, NewLockInfo(gql_t, []string{"policies"}), func(context *WriteContext) error { context := NewWriteContext(ctx)
err = UpdateStates(context, gql_t, NewLockInfo(gql_t, []string{"policies"}), func(context *StateContext) error {
return gql_t.AddPolicy(p1) return gql_t.AddPolicy(p1)
}) })
done := make(chan error, 1) done := make(chan error, 1)
var update_channel chan GraphSignal var update_channel chan GraphSignal
err = UseStates(ctx, gql_t, NewLockInfo(gql_t, nil), func(context *ReadContext) error { context = NewReadContext(ctx)
err = UseStates(context, gql_t, NewLockInfo(gql_t, nil), func(context *StateContext) error {
update_channel = UpdateChannel(gql_t, 10, NodeID{}) update_channel = UpdateChannel(gql_t, 10, NodeID{})
return nil return nil
}) })
@ -194,7 +160,8 @@ func TestGQLAuth(t * testing.T) {
case <-done: case <-done:
ctx.Log.Logf("test", "DONE") ctx.Log.Logf("test", "DONE")
} }
err := UseStates(ctx, gql_t, NewLockInfo(gql_t, []string{"signal}"}), func(context *ReadContext) error { context := NewReadContext(ctx)
err := UseStates(context, gql_t, NewLockInfo(gql_t, []string{"signal}"}), func(context *StateContext) error {
return thread.Signal(context, CancelSignal) return thread.Signal(context, CancelSignal)
}) })
fatalErr(t, err) fatalErr(t, err)

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"time" "time"
"runtime/pprof" "runtime/pprof"
"runtime/debug"
"os" "os"
badger "github.com/dgraph-io/badger/v3" badger "github.com/dgraph-io/badger/v3"
) )
@ -63,7 +64,7 @@ func testContext(t * testing.T) * Context {
func fatalErr(t * testing.T, err error) { func fatalErr(t * testing.T, err error) {
if err != nil { if err != nil {
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) debug.PrintStack()
t.Fatal(err) t.Fatal(err)
} }
} }

@ -216,7 +216,7 @@ func (lockable * SimpleLockable) CanUnlock(new_owner Lockable) error {
} }
// Assumed that lockable is already locked for signal // Assumed that lockable is already locked for signal
func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal) error { func (lockable * SimpleLockable) Signal(context *StateContext, signal GraphSignal) error {
err := lockable.GraphNode.Signal(context, signal) err := lockable.GraphNode.Signal(context, signal)
if err != nil { if err != nil {
return err return err
@ -224,10 +224,10 @@ func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal
switch signal.Direction() { switch signal.Direction() {
case Up: case Up:
err = UseMoreStates(context, lockable, NewLockMap( err = UseStates(context, lockable, NewLockMap(
NewLockInfo(lockable, []string{"dependencies", "owner"}), NewLockInfo(lockable, []string{"dependencies", "owner"}),
LockList(lockable.requirements, []string{"signal"}), LockList(lockable.requirements, []string{"signal"}),
), func(context *ReadContext) error { ), func(context *StateContext) error {
owner_sent := false owner_sent := false
for _, dependency := range(lockable.dependencies) { for _, dependency := range(lockable.dependencies) {
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID()) context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID())
@ -241,7 +241,7 @@ func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal
if lockable.owner != nil && owner_sent == false { if lockable.owner != nil && owner_sent == false {
if lockable.owner.ID() != lockable.ID() { if lockable.owner.ID() != lockable.ID() {
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID()) context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID())
return UseMoreStates(context, lockable, NewLockMap(NewLockInfo(lockable.owner, []string{"signal"})), func(context *ReadContext) error { return UseStates(context, lockable, NewLockMap(NewLockInfo(lockable.owner, []string{"signal"})), func(context *StateContext) error {
return lockable.owner.Signal(context, signal) return lockable.owner.Signal(context, signal)
}) })
} }
@ -249,10 +249,10 @@ func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal
return nil return nil
}) })
case Down: case Down:
err = UseMoreStates(context, lockable, NewLockMap( err = UseStates(context, lockable, NewLockMap(
NewLockInfo(lockable, []string{"requirements"}), NewLockInfo(lockable, []string{"requirements"}),
LockList(lockable.requirements, []string{"signal"}), LockList(lockable.requirements, []string{"signal"}),
), func(context *ReadContext) error { ), func(context *StateContext) error {
for _, requirement := range(lockable.requirements) { for _, requirement := range(lockable.requirements) {
err := requirement.Signal(context, signal) err := requirement.Signal(context, signal)
if err != nil { if err != nil {
@ -270,8 +270,13 @@ func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal
} }
// Removes requirement as a requirement from lockable // Removes requirement as a requirement from lockable
// Requires lockable and requirement be locked for write // Continues the write context with princ, getting requirents for lockable and dependencies for requirement
func UnlinkLockables(ctx * Context, lockable Lockable, requirement Lockable) error { // Assumes that an active write context exists with princ locked so that princ's state can be used in checks
func UnlinkLockables(context *StateContext, princ Node, lockable Lockable, requirement Lockable) error {
return UpdateStates(context, princ, LockMap{
lockable.ID(): LockInfo{Node: lockable, Resources: []string{"requirements"}},
requirement.ID(): LockInfo{Node: requirement, Resources: []string{"dependencies"}},
}, func(context *StateContext) error {
var found Node = nil var found Node = nil
for _, req := range(lockable.Requirements()) { for _, req := range(lockable.Requirements()) {
if requirement.ID() == req.ID() { if requirement.ID() == req.ID() {
@ -288,11 +293,13 @@ func UnlinkLockables(ctx * Context, lockable Lockable, requirement Lockable) err
lockable.RemoveRequirement(requirement) lockable.RemoveRequirement(requirement)
return nil return nil
})
} }
// Link requirements as requirements to lockable // Link requirements as requirements to lockable
// Requires lockable and requirements to be locked for write, nodes passed because requirement check recursively locks // Continues the wrtie context with princ, getting requirements for lockable and dependencies for requirements
func LinkLockables(context *WriteContext, princ Node, lockable Lockable, requirements []Lockable) error { // Assumes that an active write context exists with princ locked so that princ's state can be used in checks
func LinkLockables(context *StateContext, princ Node, lockable Lockable, requirements []Lockable) error {
if lockable == nil { if lockable == nil {
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements") return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements")
} }
@ -318,10 +325,10 @@ func LinkLockables(context *WriteContext, princ Node, lockable Lockable, require
found[requirement.ID()] = true found[requirement.ID()] = true
} }
return UpdateMoreStates(context, princ, NewLockMap( return UpdateStates(context, princ, NewLockMap(
NewLockInfo(lockable, []string{"requirements"}), NewLockInfo(lockable, []string{"requirements"}),
LockList(requirements, []string{"dependencies"}), LockList(requirements, []string{"dependencies"}),
), func(context *WriteContext) error { ), func(context *StateContext) error {
// Check that all the requirements can be added // Check that all the requirements can be added
// If the lockable is already locked, need to lock this resource as well before we can add it // If the lockable is already locked, need to lock this resource as well before we can add it
for _, requirement := range(requirements) { for _, requirement := range(requirements) {
@ -365,13 +372,13 @@ func LinkLockables(context *WriteContext, princ Node, lockable Lockable, require
} }
// Must be called withing update context // Must be called withing update context
func checkIfRequirement(context *WriteContext, r Lockable, cur Lockable) bool { func checkIfRequirement(context *StateContext, r Lockable, cur Lockable) bool {
for _, c := range(cur.Requirements()) { for _, c := range(cur.Requirements()) {
if c.ID() == r.ID() { if c.ID() == r.ID() {
return true return true
} }
is_requirement := false is_requirement := false
UpdateMoreStates(context, cur, NewLockMap(NewLockInfo(c, []string{"requirements"})), func(context *WriteContext) error { UpdateStates(context, cur, NewLockMap(NewLockInfo(c, []string{"requirements"})), func(context *StateContext) error {
is_requirement = checkIfRequirement(context, cur, c) is_requirement = checkIfRequirement(context, cur, c)
return nil return nil
}) })
@ -386,7 +393,7 @@ func checkIfRequirement(context *WriteContext, r Lockable, cur Lockable) bool {
// Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error // Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error
// Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call // Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call
func LockLockables(context *WriteContext, to_lock []Lockable, new_owner Lockable) error { func LockLockables(context *StateContext, to_lock []Lockable, new_owner Lockable) error {
if to_lock == nil { if to_lock == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: no list provided") return fmt.Errorf("LOCKABLE_LOCK_ERR: no list provided")
} }
@ -406,10 +413,10 @@ func LockLockables(context *WriteContext, to_lock []Lockable, new_owner Lockable
return nil return nil
} }
return UpdateMoreStates(context, new_owner, NewLockMap( return UpdateStates(context, new_owner, NewLockMap(
LockList(to_lock, []string{"lock"}), LockList(to_lock, []string{"lock"}),
NewLockInfo(new_owner, []string{}), NewLockInfo(new_owner, []string{}),
), func(context *WriteContext) error { ), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements // First loop is to check that the states can be locked, and locks all requirements
for _, req := range(to_lock) { for _, req := range(to_lock) {
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID())
@ -426,7 +433,7 @@ func LockLockables(context *WriteContext, to_lock []Lockable, new_owner Lockable
if owner.ID() == new_owner.ID() { if owner.ID() == new_owner.ID() {
continue continue
} else { } else {
err := UpdateMoreStates(context, new_owner, NewLockMap(NewLockInfo(owner, []string{"take_lock"})), func(context *WriteContext)(error){ err := UpdateStates(context, new_owner, NewLockMap(NewLockInfo(owner, []string{"take_lock"})), func(context *StateContext)(error){
return LockLockables(context, req.Requirements(), req) return LockLockables(context, req.Requirements(), req)
}) })
if err != nil { if err != nil {
@ -464,7 +471,7 @@ func LockLockables(context *WriteContext, to_lock []Lockable, new_owner Lockable
} }
func UnlockLockables(context *WriteContext, to_unlock []Lockable, old_owner Lockable) error { func UnlockLockables(context *StateContext, to_unlock []Lockable, old_owner Lockable) error {
if to_unlock == nil { if to_unlock == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided") return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided")
} }
@ -484,10 +491,10 @@ func UnlockLockables(context *WriteContext, to_unlock []Lockable, old_owner Lock
return nil return nil
} }
return UpdateMoreStates(context, old_owner, NewLockMap( return UpdateStates(context, old_owner, NewLockMap(
LockList(to_unlock, []string{"lock"}), LockList(to_unlock, []string{"lock"}),
NewLockInfo(old_owner, []string{}), NewLockInfo(old_owner, []string{}),
), func(context *WriteContext) error { ), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements // First loop is to check that the states can be locked, and locks all requirements
for _, req := range(to_unlock) { for _, req := range(to_unlock) {
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID())

@ -74,7 +74,7 @@ type Node interface {
RemovePolicy(Policy) error RemovePolicy(Policy) error
// Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate // Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate
Signal(context *ReadContext, signal GraphSignal) error Signal(context *StateContext, signal GraphSignal) error
// Register a channel to receive updates sent to the node // Register a channel to receive updates sent to the node
RegisterChannel(id NodeID, listener chan GraphSignal) RegisterChannel(id NodeID, listener chan GraphSignal)
// Unregister a channel from receiving updates sent to the node // Unregister a channel from receiving updates sent to the node
@ -193,7 +193,7 @@ func (node * GraphNode) Type() NodeType {
// Propagate the signal to registered listeners, if a listener isn't ready to receive the update // Propagate the signal to registered listeners, if a listener isn't ready to receive the update
// send it a notification that it was closed and then close it // send it a notification that it was closed and then close it
func (node * GraphNode) Signal(context *ReadContext, signal GraphSignal) error { func (node * GraphNode) Signal(context *StateContext, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String()) context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String())
node.listeners_lock.Lock() node.listeners_lock.Lock()
defer node.listeners_lock.Unlock() defer node.listeners_lock.Unlock()
@ -293,20 +293,18 @@ func getNodeBytes(node Node) ([]byte, error) {
} }
// Write multiple nodes to the database in a single transaction // Write multiple nodes to the database in a single transaction
func WriteNodes(context *WriteContext) error { func WriteNodes(context *StateContext) error {
if context == nil { err := ValidateStateContext(context, "write", true)
return fmt.Errorf("Cannot write nil to DB") if err != nil {
} return err
if context.Locked == nil {
return fmt.Errorf("Cannot write nil map to DB")
} }
context.Graph.Log.Logf("db", "DB_WRITES: %d", len(context.Locked)) context.Graph.Log.Logf("db", "DB_WRITES: %d", len(context.Locked))
serialized_bytes := make([][]byte, len(context.Locked)) serialized_bytes := make([][]byte, len(context.Locked))
serialized_ids := make([][]byte, len(context.Locked)) serialized_ids := make([][]byte, len(context.Locked))
i := 0 i := 0
for _, lock := range(context.Locked) { for _, node := range(context.Locked) {
node := lock.Node
node_bytes, err := getNodeBytes(node) node_bytes, err := getNodeBytes(node)
context.Graph.Log.Logf("db", "DB_WRITE: %+v", node) context.Graph.Log.Logf("db", "DB_WRITE: %+v", node)
if err != nil { if err != nil {
@ -321,7 +319,7 @@ func WriteNodes(context *WriteContext) error {
i++ i++
} }
err := context.Graph.DB.Update(func(txn *badger.Txn) error { return context.Graph.DB.Update(func(txn *badger.Txn) error {
for i, id := range(serialized_ids) { for i, id := range(serialized_ids) {
err := txn.Set(id, serialized_bytes[i]) err := txn.Set(id, serialized_bytes[i])
if err != nil { if err != nil {
@ -330,8 +328,6 @@ func WriteNodes(context *WriteContext) error {
} }
return nil return nil
}) })
return err
} }
// Get the bytes associates with `id` from the database after unwrapping the header, or error // Get the bytes associates with `id` from the database after unwrapping the header, or error
@ -450,17 +446,54 @@ type LockInfo struct {
type LockMap map[NodeID]LockInfo type LockMap map[NodeID]LockInfo
type ReadContext struct { type StateContext struct {
Type string
Graph *Context Graph *Context
Locked LockMap Permissions map[NodeID]LockMap
Locked NodeMap
Started bool
Finished bool
} }
type ReadFn func(*ReadContext)(error)
type WriteContext struct { func ValidateStateContext(context *StateContext, Type string, Finished bool) error {
Graph *Context if context == nil {
Locked LockMap return fmt.Errorf("context is nil")
}
if context.Finished != Finished {
return fmt.Errorf("context in wrong Finished state")
}
if context.Type != Type {
return fmt.Errorf("%s is not a %s context", context.Type, Type)
}
if context.Locked == nil || context.Graph == nil || context.Permissions == nil {
return fmt.Errorf("context is not initialized correctly")
}
return nil
}
func NewReadContext(ctx *Context) *StateContext {
return &StateContext{
Type: "read",
Graph: ctx,
Permissions: map[NodeID]LockMap{},
Locked: NodeMap{},
Started: false,
Finished: false,
}
}
func NewWriteContext(ctx *Context) *StateContext {
return &StateContext{
Type: "write",
Graph: ctx,
Permissions: map[NodeID]LockMap{},
Locked: NodeMap{},
Started: false,
Finished: false,
}
} }
type WriteFn func(*WriteContext)(error)
type StateFn func(*StateContext)(error)
func del[K comparable](list []K, val K) []K { func del[K comparable](list []K, val K) []K {
idx := -1 idx := -1
@ -478,146 +511,211 @@ func del[K comparable](list []K, val K) []K {
return list[:len(list)-1] return list[:len(list)-1]
} }
// Start a read context for node under ctx for the resources specified in init_nodes, then run nodes_fn // Add nodes to an existing read context and call nodes_fn with new_nodes locked for read
func UseStates(ctx *Context, node Node, nodes LockMap, read_fn ReadFn) error { // Check that the node has read permissions for the nodes, then add them to the read context and call nodes_fn with the nodes locked for read
context := &ReadContext{ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn StateFn) error {
Graph: ctx, if princ == nil || new_nodes == nil || state_fn == nil {
Locked: LockMap{}, return fmt.Errorf("nil passed to UseStates")
} }
return UseMoreStates(context, node, nodes, read_fn)
err := ValidateStateContext(context, "read", false)
if err != nil {
return err
} }
// Add nodes to an existing read context and call nodes_fn with new_nodes locked for read final := false
// Check that the node has read permissions for the nodes, then add them to the read context and call nodes_fn with the nodes locked for read if context.Started == false {
func UseMoreStates(context *ReadContext, node Node, new_nodes LockMap, read_fn ReadFn) error { context.Started = true
locked_nodes := []Node{} final = true
}
new_locks := []Node{}
_, princ_locked := context.Locked[princ.ID()]
if princ_locked == false {
new_locks = append(new_locks, princ)
context.Graph.Log.Logf("mutex", "RLOCKING_PRINC %s", princ.ID().String())
princ.RLock()
}
princ_permissions, princ_exists := context.Permissions[princ.ID()]
new_permissions := LockMap{} new_permissions := LockMap{}
if princ_exists == true {
for id, info := range(princ_permissions) {
new_permissions[id] = info
}
}
for _, request := range(new_nodes) { for _, request := range(new_nodes) {
id := request.Node.ID() node := request.Node
new_permissions[id] = LockInfo{Node: request.Node, Resources: []string{}} if node == nil {
return fmt.Errorf("node in request list is nil")
}
id := node.ID()
if id != princ.ID() {
_, locked := context.Locked[id]
if locked == false {
new_locks = append(new_locks, node)
context.Graph.Log.Logf("mutex", "RLOCKING %s", id.String())
node.RLock()
}
}
node_permissions, node_exists := new_permissions[id]
if node_exists == false {
node_permissions = LockInfo{Node: node, Resources: []string{}}
}
for _, resource := range(request.Resources) { for _, resource := range(request.Resources) {
// If the permission for this resource is already granted, continue
current_permissions, exists := context.Locked[id]
if exists == true {
already_granted := false already_granted := false
for _, r := range(current_permissions.Resources) { for _, granted := range(node_permissions.Resources) {
if r == resource { if resource == granted {
already_granted = true already_granted = true
break
}
}
if already_granted == true {
continue
} }
} }
err := request.Node.Allowed("read", resource, node) if already_granted == false {
err := node.Allowed("read", resource, princ)
if err != nil { if err != nil {
return err context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.write", id.String(), resource)
for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String())
n.RUnlock()
} }
return err
tmp := new_permissions[id]
tmp.Resources = append(tmp.Resources, resource)
new_permissions[id] = tmp
} }
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.write", id.String(), resource)
req_perms, exists := new_permissions[id]
if exists == true {
cur_perms, already_locked := context.Locked[id]
if already_locked == false {
request.Node.RLock()
context.Locked[id] = req_perms
locked_nodes = append(locked_nodes, request.Node)
} else { } else {
cur_perms.Resources = append(cur_perms.Resources, req_perms.Resources...) context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.write", id.String(), resource)
} }
} }
new_permissions[id] = node_permissions
} }
err := read_fn(context) for _, node := range(new_locks) {
context.Locked[node.ID()] = node
for _, request := range(new_permissions) {
cur_perms := context.Locked[request.Node.ID()]
new_perms := cur_perms.Resources
for _, resource := range(cur_perms.Resources) {
new_perms = del(new_perms, resource)
}
cur_perms.Resources = new_perms
context.Locked[request.Node.ID()] = cur_perms
} }
for _, node := range(locked_nodes) { context.Permissions[princ.ID()] = new_permissions
err = state_fn(context)
context.Permissions[princ.ID()] = princ_permissions
for _, node := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING %s", node.ID().String())
delete(context.Locked, node.ID()) delete(context.Locked, node.ID())
node.RUnlock() node.RUnlock()
} }
if final == true {
context.Finished = true
}
return err return err
} }
// Initiate a write context for nodes and call nodes_fn with nodes locked for read // Add nodes to an existing write context and call nodes_fn with nodes locked for read
func UpdateStates(ctx *Context, node Node, nodes LockMap, write_fn WriteFn) error { // If context is nil
context := &WriteContext{ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn StateFn) error {
Graph: ctx, if princ == nil || new_nodes == nil || state_fn == nil {
Locked: LockMap{}, return fmt.Errorf("nil passed to UpdateStates")
} }
err := UpdateMoreStates(context, node, nodes, write_fn)
if err == nil { err := ValidateStateContext(context, "write", false)
err = WriteNodes(context) if err != nil {
return err
} }
for _, lock := range(context.Locked) { final := false
lock.Node.Unlock() if context.Started == false {
context.Started = true
final = true
} }
return err new_locks := []Node{}
_, princ_locked := context.Locked[princ.ID()]
if princ_locked == false {
new_locks = append(new_locks, princ)
context.Graph.Log.Logf("mutex", "LOCKING_PRINC %s", princ.ID().String())
princ.Lock()
} }
// Add nodes to an existing write context and call nodes_fn with nodes locked for read princ_permissions, princ_exists := context.Permissions[princ.ID()]
func UpdateMoreStates(context *WriteContext, node Node, new_nodes LockMap, write_fn WriteFn) error {
new_permissions := LockMap{} new_permissions := LockMap{}
if princ_exists == true {
for id, info := range(princ_permissions) {
new_permissions[id] = info
}
}
for _, request := range(new_nodes) { for _, request := range(new_nodes) {
id := request.Node.ID() node := request.Node
new_permissions[id] = LockInfo{Node: request.Node, Resources: []string{}} if node == nil {
return fmt.Errorf("node in request list is nil")
}
id := node.ID()
if id != princ.ID() {
_, locked := context.Locked[id]
if locked == false {
new_locks = append(new_locks, node)
context.Graph.Log.Logf("mutex", "LOCKING %s", id.String())
node.Lock()
}
}
node_permissions, node_exists := new_permissions[id]
if node_exists == false {
node_permissions = LockInfo{Node: node, Resources: []string{}}
}
for _, resource := range(request.Resources) { for _, resource := range(request.Resources) {
current_permissions, exists := context.Locked[id]
if exists == true {
already_granted := false already_granted := false
for _, r := range(current_permissions.Resources) { for _, granted := range(node_permissions.Resources) {
if r == resource { if resource == granted {
already_granted = true already_granted = true
break
}
}
if already_granted == true {
continue
} }
} }
err := request.Node.Allowed("write", resource, node) if already_granted == false {
err := node.Allowed("write", resource, princ)
if err != nil { if err != nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.write", id.String(), resource)
for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String())
n.Unlock()
}
return err return err
} }
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.write", id.String(), resource)
} else {
context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.write", id.String(), resource)
}
}
new_permissions[id] = node_permissions
}
tmp := new_permissions[id] for _, node := range(new_locks) {
tmp.Resources = append(tmp.Resources, resource) context.Locked[node.ID()] = node
new_permissions[id] = tmp
} }
req_perms, exists := new_permissions[id] context.Permissions[princ.ID()] = new_permissions
if exists == true {
cur_perms, already_locked := context.Locked[id] err = state_fn(context)
if already_locked == false {
request.Node.Lock() if final == true {
context.Locked[id] = req_perms context.Finished = true
} else { if err == nil {
cur_perms.Resources = append(cur_perms.Resources, req_perms.Resources...) err = WriteNodes(context)
context.Locked[id] = cur_perms
} }
for id, node := range(context.Locked) {
context.Graph.Log.Logf("mutex", "UNLOCKING %s", id.String())
node.Unlock()
} }
} }
return write_fn(context) return err
} }
// Create a new channel with a buffer the size of buffer, and register it to node with the id // Create a new channel with a buffer the size of buffer, and register it to node with the id

@ -10,7 +10,7 @@ import (
) )
// Assumed that thread is already locked for signal // Assumed that thread is already locked for signal
func (thread *SimpleThread) Signal(context *ReadContext, signal GraphSignal) error { func (thread *SimpleThread) Signal(context *StateContext, signal GraphSignal) error {
err := thread.SimpleLockable.Signal(context, signal) err := thread.SimpleLockable.Signal(context, signal)
if err != nil { if err != nil {
return err return err
@ -18,9 +18,9 @@ func (thread *SimpleThread) Signal(context *ReadContext, signal GraphSignal) err
switch signal.Direction() { switch signal.Direction() {
case Up: case Up:
err = UseMoreStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *ReadContext) error { err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error {
if thread.parent != nil { if thread.parent != nil {
return UseMoreStates(context, thread, NewLockInfo(thread.parent, []string{"signal"}), func(context *ReadContext) error { return UseStates(context, thread, NewLockInfo(thread.parent, []string{"signal"}), func(context *StateContext) error {
return thread.parent.Signal(context, signal) return thread.parent.Signal(context, signal)
}) })
} else { } else {
@ -28,10 +28,10 @@ func (thread *SimpleThread) Signal(context *ReadContext, signal GraphSignal) err
} }
}) })
case Down: case Down:
err = UseMoreStates(context, thread, NewLockMap( err = UseStates(context, thread, NewLockMap(
NewLockInfo(thread, []string{"children"}), NewLockInfo(thread, []string{"children"}),
LockList(thread.children, []string{"signal"}), LockList(thread.children, []string{"signal"}),
), func(context *ReadContext) error { ), func(context *StateContext) error {
for _, child := range(thread.children) { for _, child := range(thread.children) {
err := child.Signal(context, signal) err := child.Signal(context, signal)
if err != nil { if err != nil {
@ -169,15 +169,15 @@ func (thread * SimpleThread) AddChild(child Thread, info ThreadInfo) error {
return nil return nil
} }
func checkIfChild(context *WriteContext, target Thread, cur Thread) bool { func checkIfChild(context *StateContext, target Thread, cur Thread) bool {
for _, child := range(cur.Children()) { for _, child := range(cur.Children()) {
if child.ID() == target.ID() { if child.ID() == target.ID() {
return true return true
} }
is_child := false is_child := false
UpdateMoreStates(context, cur, NewLockMap( UpdateStates(context, cur, NewLockMap(
NewLockInfo(child, []string{"children"}), NewLockInfo(child, []string{"children"}),
), func(context *WriteContext) error { ), func(context *StateContext) error {
is_child = checkIfChild(context, target, child) is_child = checkIfChild(context, target, child)
return nil return nil
}) })
@ -189,7 +189,9 @@ func checkIfChild(context *WriteContext, target Thread, cur Thread) bool {
return false return false
} }
func LinkThreads(context *WriteContext, princ Node, thread Thread, child Thread, info ThreadInfo) error { // Links child to parent with info as the associated info
// Continues the write context with princ, getting children for thread and parent for child
func LinkThreads(context *StateContext, princ Node, thread Thread, child Thread, info ThreadInfo) error {
if context == nil || thread == nil || child == nil { if context == nil || thread == nil || child == nil {
return fmt.Errorf("invalid input") return fmt.Errorf("invalid input")
} }
@ -198,7 +200,10 @@ func LinkThreads(context *WriteContext, princ Node, thread Thread, child Thread,
return fmt.Errorf("Will not link %s as a child of itself", thread.ID()) return fmt.Errorf("Will not link %s as a child of itself", thread.ID())
} }
return UpdateMoreStates(context, princ, LockList([]Node{child, thread}, []string{"parent", "children"}), func(context *WriteContext) error { return UpdateStates(context, princ, LockMap{
child.ID(): LockInfo{Node: child, Resources: []string{"parent"}},
thread.ID(): LockInfo{Node: thread, Resources: []string{"children"}},
}, func(context *StateContext) error {
if child.Parent() != nil { if child.Parent() != nil {
return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID()) return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID())
} }
@ -449,7 +454,7 @@ func NewSimpleThread(id NodeID, name string, state_name string, info_type reflec
} }
// Requires the read permission of threads children // Requires the read permission of threads children
func FindChild(context *ReadContext, princ Node, thread Thread, id NodeID) Thread { func FindChild(context *StateContext, princ Node, thread Thread, id NodeID) Thread {
if thread == nil { if thread == nil {
panic("cannot recurse through nil") panic("cannot recurse through nil")
} }
@ -459,7 +464,7 @@ func FindChild(context *ReadContext, princ Node, thread Thread, id NodeID) Threa
for _, child := range thread.Children() { for _, child := range thread.Children() {
var result Thread = nil var result Thread = nil
UseMoreStates(context, princ, NewLockInfo(child, []string{"children"}), func(context *ReadContext) error { UseStates(context, princ, NewLockInfo(child, []string{"children"}), func(context *StateContext) error {
result = FindChild(context, princ, child, id) result = FindChild(context, princ, child, id)
return nil return nil
}) })
@ -485,7 +490,7 @@ func ChildGo(ctx * Context, thread Thread, child Thread, first_action string) {
}(child) }(child)
} }
// Main Loop for Threads // Main Loop for Threads, starts a write context, so cannot be called from a write or read context
func ThreadLoop(ctx * Context, thread Thread, first_action string) error { func ThreadLoop(ctx * Context, thread Thread, first_action string) error {
// Start the thread, error if double-started // Start the thread, error if double-started
ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID(), first_action) ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID(), first_action)
@ -515,18 +520,8 @@ func ThreadLoop(ctx * Context, thread Thread, first_action string) error {
return err return err
} }
err = UpdateStates(ctx, thread, NewLockInfo(thread, []string{"state"}), func(context *WriteContext) error {
err := thread.SetState("finished")
if err != nil {
return err
}
return UnlockLockables(context, []Lockable{thread}, thread)
})
if err != nil {
ctx.Log.Logf("thread", "THREAD_LOOP_UNLOCK_ERR: %e", err)
return err
}
ctx.Log.Logf("thread", "THREAD_LOOP_DONE: %s", thread.ID()) ctx.Log.Logf("thread", "THREAD_LOOP_DONE: %s", thread.ID())
@ -578,13 +573,16 @@ func (thread * SimpleThread) AllowedToTakeLock(new_owner Lockable, lockable Lock
return false return false
} }
// Helper function to start a child from a thread during a signal handler
// Starts a write context, so cannot be called from either a write or read context
func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) error { func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) error {
return UpdateStates(ctx, thread, NewLockInfo(thread, []string{"children"}), func(context *WriteContext) error { context := NewWriteContext(ctx)
return UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error {
child := thread.Child(signal.ID) child := thread.Child(signal.ID)
if child == nil { if child == nil {
return fmt.Errorf("%s is not a child of %s", signal.ID, thread.ID()) return fmt.Errorf("%s is not a child of %s", signal.ID, thread.ID())
} }
return UpdateMoreStates(context, thread, NewLockInfo(child, []string{"start"}), func(context *WriteContext) error { return UpdateStates(context, thread, NewLockInfo(child, []string{"start"}), func(context *StateContext) error {
info := thread.ChildInfo(signal.ID).(*ParentThreadInfo) info := thread.ChildInfo(signal.ID).(*ParentThreadInfo)
info.Start = true info.Start = true
@ -595,11 +593,14 @@ func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) erro
}) })
} }
// Helper function to restore threads that should be running from a parents restore action
// Starts a write context, so cannot be called from either a write or read context
func ThreadRestore(ctx * Context, thread Thread) { func ThreadRestore(ctx * Context, thread Thread) {
UpdateStates(ctx, thread, NewLockMap( context := NewWriteContext(ctx)
UpdateStates(context, thread, NewLockMap(
NewLockInfo(thread, []string{"children"}), NewLockInfo(thread, []string{"children"}),
LockList(thread.Children(), []string{"start"}), LockList(thread.Children(), []string{"start"}),
), func(context *WriteContext)(error) { ), func(context *StateContext)(error) {
for _, child := range(thread.Children()) { for _, child := range(thread.Children()) {
should_run := (thread.ChildInfo(child.ID())).(ParentInfo).Parent() should_run := (thread.ChildInfo(child.ID())).(ParentInfo).Parent()
ctx.Log.Logf("thread", "THREAD_RESTORE: %s -> %s: %+v", thread.ID(), child.ID(), should_run) ctx.Log.Logf("thread", "THREAD_RESTORE: %s -> %s: %+v", thread.ID(), child.ID(), should_run)
@ -612,18 +613,15 @@ func ThreadRestore(ctx * Context, thread Thread) {
}) })
} }
// Helper function to be called during a threads start action, sets the thread state to started
// Starts a write context, so cannot be called from either a write or read context
func ThreadStart(ctx * Context, thread Thread) error { func ThreadStart(ctx * Context, thread Thread) error {
return UpdateStates(ctx, thread, NewLockInfo(thread, []string{"start", "lock"}), func(context *WriteContext) error { context := NewWriteContext(ctx)
owner_id := NodeID{} return UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error {
if thread.Owner() != nil {
owner_id = thread.Owner().ID()
}
if owner_id != thread.ID() {
err := LockLockables(context, []Lockable{thread}, thread) err := LockLockables(context, []Lockable{thread}, thread)
if err != nil { if err != nil {
return err return err
} }
}
return thread.SetState("started") return thread.SetState("started")
}) })
} }
@ -657,7 +655,8 @@ func ThreadWait(ctx * Context, thread Thread) (string, error) {
} }
case <- thread.Timeout(): case <- thread.Timeout():
timeout_action := "" timeout_action := ""
err := UpdateStates(ctx, thread, NewLockMap(NewLockInfo(thread, []string{"timeout"})), func(context *WriteContext) error { context := NewWriteContext(ctx)
err := UpdateStates(context, thread, NewLockMap(NewLockInfo(thread, []string{"timeout"})), func(context *StateContext) error {
timeout_action = thread.TimeoutAction() timeout_action = thread.TimeoutAction()
thread.ClearTimeout() thread.ClearTimeout()
return nil return nil
@ -671,27 +670,46 @@ func ThreadWait(ctx * Context, thread Thread) (string, error) {
} }
} }
func ThreadDefaultFinish(ctx *Context, thread Thread) (string, error) {
ctx.Log.Logf("thread", "THREAD_DEFAULT_FINISH: %s", thread.ID().String())
return "", ThreadFinish(ctx, thread)
}
func ThreadFinish(ctx *Context, thread Thread) error {
context := NewWriteContext(ctx)
return UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error {
err := thread.SetState("finished")
if err != nil {
return err
}
return UnlockLockables(context, []Lockable{thread}, thread)
})
}
var ThreadAbortedError = errors.New("Thread aborted by signal") var ThreadAbortedError = errors.New("Thread aborted by signal")
// Default thread abort is to return a ThreadAbortedError // Default thread action function for "abort", sends a signal and returns a ThreadAbortedError
func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) { func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
err := UseStates(ctx, thread, NewLockInfo(thread, []string{"signal"}), func(context *ReadContext) error { context := NewReadContext(ctx)
err := UseStates(context, thread, NewLockInfo(thread, []string{"signal"}), func(context *StateContext) error {
return thread.Signal(context, NewStatusSignal("aborted", thread.ID())) return thread.Signal(context, NewStatusSignal("aborted", thread.ID()))
}) })
if err != nil { if err != nil {
return "", err return "", err
} }
return "", ThreadAbortedError return "finish", ThreadAbortedError
} }
// Default thread cancel is to finish the thread // Default thread action for "cancel", sends a signal and returns no error
func ThreadCancel(ctx * Context, thread Thread, signal GraphSignal) (string, error) { func ThreadCancel(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
err := UseStates(ctx, thread, NewLockInfo(thread, []string{"signal"}), func(context *ReadContext) error { context := NewReadContext(ctx)
err := UseStates(context, thread, NewLockInfo(thread, []string{"signal"}), func(context *StateContext) error {
return thread.Signal(context, NewSignal("cancelled")) return thread.Signal(context, NewSignal("cancelled"))
}) })
return "", err return "finish", err
} }
// Copy the default thread actions to a new ThreadActions map
func NewThreadActions() ThreadActions{ func NewThreadActions() ThreadActions{
actions := ThreadActions{} actions := ThreadActions{}
for k, v := range(BaseThreadActions) { for k, v := range(BaseThreadActions) {
@ -701,6 +719,7 @@ func NewThreadActions() ThreadActions{
return actions return actions
} }
// Copy the defult thread handlers to a new ThreadAction map
func NewThreadHandlers() ThreadHandlers{ func NewThreadHandlers() ThreadHandlers{
handlers := ThreadHandlers{} handlers := ThreadHandlers{}
for k, v := range(BaseThreadHandlers) { for k, v := range(BaseThreadHandlers) {
@ -710,12 +729,15 @@ func NewThreadHandlers() ThreadHandlers{
return handlers return handlers
} }
// Default thread actions
var BaseThreadActions = ThreadActions{ var BaseThreadActions = ThreadActions{
"wait": ThreadWait, "wait": ThreadWait,
"start": ThreadDefaultStart, "start": ThreadDefaultStart,
"finish": ThreadDefaultFinish,
"restore": ThreadDefaultRestore, "restore": ThreadDefaultRestore,
} }
// Default thread signal handlers
var BaseThreadHandlers = ThreadHandlers{ var BaseThreadHandlers = ThreadHandlers{
"abort": ThreadAbort, "abort": ThreadAbort,
"cancel": ThreadCancel, "cancel": ThreadCancel,