Reworked use/update to require a read/write context be initialized before starting, still need to figure out if brittle locking is the solution to potential deadlock, and implement if so

graph-rework-2
noah metz 2023-07-23 17:57:47 -06:00
parent 575912d56f
commit 8fb0cbc982
8 changed files with 391 additions and 298 deletions

@ -204,7 +204,8 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt
ctx.Log.Logf("gql", "AUTHORIZING NEW USER %s - %s", key_id, shared)
new_user := NewUser(fmt.Sprintf("GQL_USER %s", key_id.String()), time.Now(), remote_id, shared, []string{"gql"})
err := UpdateStates(ctx, server, NewLockMap(LockMap{
context := NewWriteContext(ctx)
err := UpdateStates(context, server, NewLockMap(LockMap{
server.ID(): LockInfo{
Node: server,
Resources: []string{"users"},
@ -213,7 +214,7 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt
Node: &new_user,
Resources: []string{""},
},
}), func(context *WriteContext) error {
}), func(context *StateContext) error {
server.Users[key_id] = &new_user
return nil
})
@ -873,9 +874,10 @@ var gql_actions ThreadActions = ThreadActions{
}(server)
err = UpdateStates(ctx, server, NewLockMap(
context := NewWriteContext(ctx)
err = UpdateStates(context, server, NewLockMap(
NewLockInfo(server, []string{"http_server"}),
), func(context *WriteContext) error {
), func(context *StateContext) error {
server.tcp_listener = listener
server.http_server = http_server
return nil
@ -885,9 +887,10 @@ var gql_actions ThreadActions = ThreadActions{
return "", err
}
err = UseStates(ctx, server, NewLockMap(
context = NewReadContext(ctx)
err = UseStates(context, server, NewLockMap(
NewLockInfo(server, []string{"signal"}),
), func(context *ReadContext) error {
), func(context *StateContext) error {
return server.Signal(context, NewSignal("server_started"))
})
@ -897,14 +900,21 @@ var gql_actions ThreadActions = ThreadActions{
return "wait", nil
},
"finish": func(ctx *Context, thread Thread) (string, error) {
server := thread.(*GQLThread)
server.http_server.Shutdown(context.TODO())
server.http_done.Wait()
return "", ThreadFinish(ctx, thread)
},
}
var gql_handlers ThreadHandlers = ThreadHandlers{
"child_linked": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
ctx.Log.Logf("gql", "GQL_THREAD_CHILD_ADDED: %+v", signal)
err := UpdateStates(ctx, thread, NewLockMap(
context := NewWriteContext(ctx)
err := UpdateStates(context, thread, NewLockMap(
NewLockInfo(thread, []string{"children"}),
), func(context *WriteContext) error {
), func(context *StateContext) error {
sig, ok := signal.(IDSignal)
if ok == false {
ctx.Log.Logf("gql", "GQL_THREAD_NODE_LINKED_BAD_CAST")
@ -945,19 +955,7 @@ var gql_handlers ThreadHandlers = ThreadHandlers{
return "wait", nil
},
"abort": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
ctx.Log.Logf("gql", "GQL_ABORT")
server := thread.(*GQLThread)
server.http_server.Shutdown(context.TODO())
server.http_done.Wait()
return ThreadAbort(ctx, thread, signal)
},
"cancel": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
ctx.Log.Logf("gql", "GQL_CANCEL")
server := thread.(*GQLThread)
server.http_server.Shutdown(context.TODO())
server.http_done.Wait()
return ThreadCancel(ctx, thread, signal)
},
"abort": ThreadAbort,
"cancel": ThreadCancel,
}

@ -29,14 +29,15 @@ var GQLMutationAbort = NewField(func()*graphql.Field {
}
var node Node = nil
err = UseStates(ctx.Context, ctx.User, NewLockMap(
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockMap(
NewLockInfo(ctx.Server, []string{"children"}),
), func(context *ReadContext) (error){
), func(context *StateContext) (error){
node = FindChild(context, ctx.User, ctx.Server, id)
if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
}
return UseMoreStates(context, ctx.User, NewLockMap(NewLockInfo(node, []string{"signal"})), func(context *ReadContext) error {
return UseStates(context, ctx.User, NewLockInfo(node, []string{"signal"}), func(context *StateContext) error {
return node.Signal(context, AbortSignal)
})
})
@ -88,9 +89,10 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{
}
var signal GraphSignal
err = UseStates(ctx.Context, ctx.User, NewLockMap(
context := NewWriteContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockMap(
NewLockInfo(ctx.Server, []string{"children"}),
), func(context *ReadContext) error {
), func(context *StateContext) error {
node := FindChild(context, ctx.User, ctx.Server, parent_id)
if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id)
@ -101,7 +103,7 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{
return err
}
return UseMoreStates(context, ctx.User, NewLockMap(NewLockInfo(node, []string{"start_child", "signal"})), func(context *ReadContext) error {
return UseStates(context, ctx.User, NewLockMap(NewLockInfo(node, []string{"start_child", "signal"})), func(context *StateContext) error {
signal = NewStartChildSignal(child_id, action)
return node.Signal(context, signal)
})

@ -42,24 +42,13 @@ func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) {
return id, nil
}
// TODO: think about what permissions should be needed to read ID, and if there's ever a case where they haven't already been granted
func GQLNodeID(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
node, ok := p.Source.(Node)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Node")
}
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"id"})), func(context *ReadContext) error {
return nil
})
if err != nil {
return nil, err
}
return node.ID(), nil
}
@ -76,7 +65,8 @@ func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) {
listen := ""
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"listen"})), func(context *ReadContext) error {
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"listen"}), func(context *StateContext) error {
listen = node.Listen
return nil
})
@ -100,7 +90,8 @@ func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) {
}
var parent Thread = nil
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"parent"})), func(context *ReadContext) error {
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"parent"}), func(context *StateContext) error {
parent = node.Parent()
return nil
})
@ -124,7 +115,8 @@ func GQLThreadState(p graphql.ResolveParams) (interface{}, error) {
}
var state string
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"state"})), func(context *ReadContext) error {
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"state"}), func(context *StateContext) error {
state = node.State()
return nil
})
@ -148,7 +140,8 @@ func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) {
}
var children []Thread = nil
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"children"})), func(context *ReadContext) error {
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"children"}), func(context *StateContext) error {
children = node.Children()
return nil
})
@ -172,7 +165,8 @@ func GQLLockableName(p graphql.ResolveParams) (interface{}, error) {
}
name := ""
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"name"})), func(context *ReadContext) error {
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"name"}), func(context *StateContext) error {
name = node.Name()
return nil
})
@ -196,7 +190,8 @@ func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) {
}
var requirements []Lockable = nil
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"requirements"})), func(context *ReadContext) error {
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"requirements"}), func(context *StateContext) error {
requirements = node.Requirements()
return nil
})
@ -220,7 +215,8 @@ func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) {
}
var dependencies []Lockable = nil
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"dependencies"})), func(context *ReadContext) error {
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"dependencies"}), func(context *StateContext) error {
dependencies = node.Dependencies()
return nil
})
@ -244,7 +240,8 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) {
}
var owner Node = nil
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"owner"})), func(context *ReadContext) error {
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"owner"}), func(context *StateContext) error {
owner = node.Owner()
return nil
})
@ -268,7 +265,8 @@ func GQLThreadUsers(p graphql.ResolveParams) (interface{}, error) {
}
var users []*User
err = UseStates(ctx.Context, ctx.User, NewLockMap(NewLockInfo(node, []string{"users"})), func(context *ReadContext) error {
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"users"}), func(context *StateContext) error {
users = make([]*User, len(node.Users))
i := 0
for _, user := range(node.Users) {

@ -18,48 +18,8 @@ import (
"encoding/base64"
)
func TestGQLThread(t * testing.T) {
ctx := logTestContext(t, []string{})
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil)
gql_t := &gql_t_r
t1_r := NewSimpleThread(RandID(), "Test thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
t1 := &t1_r
t2_r := NewSimpleThread(RandID(), "Test thread 2", "init", nil, BaseThreadActions, BaseThreadHandlers)
t2 := &t2_r
err = UpdateStates(ctx, gql_t, NewLockMap(
LockList([]Node{t1, t2}, []string{"parent"}),
NewLockInfo(gql_t, []string{"children"}),
), func(context *WriteContext) error {
i1 := NewParentThreadInfo(true, "start", "restore")
err := LinkThreads(context, gql_t, gql_t, t1, &i1)
if err != nil {
return err
}
i2 := NewParentThreadInfo(false, "start", "restore")
return LinkThreads(context, gql_t, gql_t, t2, &i2)
})
fatalErr(t, err)
go func(thread Thread){
time.Sleep(10*time.Millisecond)
err := UseStates(ctx, thread, NewLockInfo(thread, []string{"signal"}), func(context *ReadContext) error {
return thread.Signal(context, CancelSignal)
})
fatalErr(t, err)
}(gql_t)
err = ThreadLoop(ctx, gql_t, "start")
fatalErr(t, err)
}
func TestGQLDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{})
ctx := logTestContext(t, []string{"policy", "mutex"})
l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
l1 := &l1_r
@ -84,9 +44,10 @@ func TestGQLDBLoad(t * testing.T) {
gql := &gql_r
info := NewParentThreadInfo(true, "start", "restore")
err = UpdateStates(ctx, gql, NewLockMap(
context := NewWriteContext(ctx)
err = UpdateStates(context, gql, NewLockMap(
NewLockInfo(gql, []string{"policies", "users"}),
), func(context *WriteContext) error {
), func(context *StateContext) error {
err := gql.AddPolicy(p1)
if err != nil {
return err
@ -102,7 +63,8 @@ func TestGQLDBLoad(t * testing.T) {
})
fatalErr(t, err)
err = UseStates(ctx, gql, NewLockInfo(gql, []string{"signal"}), func(context *ReadContext) error {
context = NewReadContext(ctx)
err = UseStates(context, gql, NewLockInfo(gql, []string{"signal"}), func(context *StateContext) error {
err := gql.Signal(context, NewStatusSignal("child_linked", t1.ID()))
if err != nil {
return nil
@ -122,7 +84,8 @@ func TestGQLDBLoad(t * testing.T) {
(*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_aborted", 100*time.Millisecond, "Didn't receive thread_abort from t1 on t1")
err = UseStates(ctx, gql, LockList([]Node{gql, u1}, nil), func(context *ReadContext) error {
context = NewReadContext(ctx)
err = UseStates(context, gql, LockList([]Node{gql, u1}, nil), func(context *StateContext) error {
ser1, err := gql.Serialize()
ser2, err := u1.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser1)
@ -135,14 +98,15 @@ func TestGQLDBLoad(t * testing.T) {
var t1_loaded *SimpleThread = nil
var update_channel_2 chan GraphSignal
err = UseStates(ctx, gql, NewLockInfo(gql_loaded, []string{"users", "children"}), func(context *ReadContext) error {
context = NewReadContext(ctx)
err = UseStates(context, gql, NewLockInfo(gql_loaded, []string{"users", "children"}), func(context *StateContext) error {
ser, err := gql_loaded.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser)
u_loaded := gql_loaded.(*GQLThread).Users[u1.ID()]
child := gql_loaded.(Thread).Children()[0].(*SimpleThread)
t1_loaded = child
update_channel_2 = UpdateChannel(t1_loaded, 10, NodeID{})
err = UseMoreStates(context, gql, NewLockInfo(u_loaded, nil), func(context *ReadContext) error {
err = UseStates(context, gql, NewLockInfo(u_loaded, nil), func(context *StateContext) error {
ser, err := u_loaded.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser)
return err
@ -162,7 +126,7 @@ func TestGQLDBLoad(t * testing.T) {
}
func TestGQLAuth(t * testing.T) {
ctx := logTestContext(t, []string{"test", "gql"})
ctx := logTestContext(t, []string{"policy", "mutex"})
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
@ -173,14 +137,16 @@ func TestGQLAuth(t * testing.T) {
gql_t := &gql_t_r
// p1 not written to DB, TODO: update write to follow links maybe
err = UpdateStates(ctx, gql_t, NewLockInfo(gql_t, []string{"policies"}), func(context *WriteContext) error {
context := NewWriteContext(ctx)
err = UpdateStates(context, gql_t, NewLockInfo(gql_t, []string{"policies"}), func(context *StateContext) error {
return gql_t.AddPolicy(p1)
})
done := make(chan error, 1)
var update_channel chan GraphSignal
err = UseStates(ctx, gql_t, NewLockInfo(gql_t, nil), func(context *ReadContext) error {
context = NewReadContext(ctx)
err = UseStates(context, gql_t, NewLockInfo(gql_t, nil), func(context *StateContext) error {
update_channel = UpdateChannel(gql_t, 10, NodeID{})
return nil
})
@ -194,7 +160,8 @@ func TestGQLAuth(t * testing.T) {
case <-done:
ctx.Log.Logf("test", "DONE")
}
err := UseStates(ctx, gql_t, NewLockInfo(gql_t, []string{"signal}"}), func(context *ReadContext) error {
context := NewReadContext(ctx)
err := UseStates(context, gql_t, NewLockInfo(gql_t, []string{"signal}"}), func(context *StateContext) error {
return thread.Signal(context, CancelSignal)
})
fatalErr(t, err)

@ -5,6 +5,7 @@ import (
"fmt"
"time"
"runtime/pprof"
"runtime/debug"
"os"
badger "github.com/dgraph-io/badger/v3"
)
@ -63,7 +64,7 @@ func testContext(t * testing.T) * Context {
func fatalErr(t * testing.T, err error) {
if err != nil {
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
debug.PrintStack()
t.Fatal(err)
}
}

@ -216,7 +216,7 @@ func (lockable * SimpleLockable) CanUnlock(new_owner Lockable) error {
}
// Assumed that lockable is already locked for signal
func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal) error {
func (lockable * SimpleLockable) Signal(context *StateContext, signal GraphSignal) error {
err := lockable.GraphNode.Signal(context, signal)
if err != nil {
return err
@ -224,10 +224,10 @@ func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal
switch signal.Direction() {
case Up:
err = UseMoreStates(context, lockable, NewLockMap(
err = UseStates(context, lockable, NewLockMap(
NewLockInfo(lockable, []string{"dependencies", "owner"}),
LockList(lockable.requirements, []string{"signal"}),
), func(context *ReadContext) error {
), func(context *StateContext) error {
owner_sent := false
for _, dependency := range(lockable.dependencies) {
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID())
@ -241,7 +241,7 @@ func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal
if lockable.owner != nil && owner_sent == false {
if lockable.owner.ID() != lockable.ID() {
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID())
return UseMoreStates(context, lockable, NewLockMap(NewLockInfo(lockable.owner, []string{"signal"})), func(context *ReadContext) error {
return UseStates(context, lockable, NewLockMap(NewLockInfo(lockable.owner, []string{"signal"})), func(context *StateContext) error {
return lockable.owner.Signal(context, signal)
})
}
@ -249,10 +249,10 @@ func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal
return nil
})
case Down:
err = UseMoreStates(context, lockable, NewLockMap(
err = UseStates(context, lockable, NewLockMap(
NewLockInfo(lockable, []string{"requirements"}),
LockList(lockable.requirements, []string{"signal"}),
), func(context *ReadContext) error {
), func(context *StateContext) error {
for _, requirement := range(lockable.requirements) {
err := requirement.Signal(context, signal)
if err != nil {
@ -270,8 +270,13 @@ func (lockable * SimpleLockable) Signal(context *ReadContext, signal GraphSignal
}
// Removes requirement as a requirement from lockable
// Requires lockable and requirement be locked for write
func UnlinkLockables(ctx * Context, lockable Lockable, requirement Lockable) error {
// Continues the write context with princ, getting requirents for lockable and dependencies for requirement
// Assumes that an active write context exists with princ locked so that princ's state can be used in checks
func UnlinkLockables(context *StateContext, princ Node, lockable Lockable, requirement Lockable) error {
return UpdateStates(context, princ, LockMap{
lockable.ID(): LockInfo{Node: lockable, Resources: []string{"requirements"}},
requirement.ID(): LockInfo{Node: requirement, Resources: []string{"dependencies"}},
}, func(context *StateContext) error {
var found Node = nil
for _, req := range(lockable.Requirements()) {
if requirement.ID() == req.ID() {
@ -288,11 +293,13 @@ func UnlinkLockables(ctx * Context, lockable Lockable, requirement Lockable) err
lockable.RemoveRequirement(requirement)
return nil
})
}
// Link requirements as requirements to lockable
// Requires lockable and requirements to be locked for write, nodes passed because requirement check recursively locks
func LinkLockables(context *WriteContext, princ Node, lockable Lockable, requirements []Lockable) error {
// Continues the wrtie context with princ, getting requirements for lockable and dependencies for requirements
// Assumes that an active write context exists with princ locked so that princ's state can be used in checks
func LinkLockables(context *StateContext, princ Node, lockable Lockable, requirements []Lockable) error {
if lockable == nil {
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements")
}
@ -318,10 +325,10 @@ func LinkLockables(context *WriteContext, princ Node, lockable Lockable, require
found[requirement.ID()] = true
}
return UpdateMoreStates(context, princ, NewLockMap(
return UpdateStates(context, princ, NewLockMap(
NewLockInfo(lockable, []string{"requirements"}),
LockList(requirements, []string{"dependencies"}),
), func(context *WriteContext) error {
), func(context *StateContext) error {
// Check that all the requirements can be added
// If the lockable is already locked, need to lock this resource as well before we can add it
for _, requirement := range(requirements) {
@ -365,13 +372,13 @@ func LinkLockables(context *WriteContext, princ Node, lockable Lockable, require
}
// Must be called withing update context
func checkIfRequirement(context *WriteContext, r Lockable, cur Lockable) bool {
func checkIfRequirement(context *StateContext, r Lockable, cur Lockable) bool {
for _, c := range(cur.Requirements()) {
if c.ID() == r.ID() {
return true
}
is_requirement := false
UpdateMoreStates(context, cur, NewLockMap(NewLockInfo(c, []string{"requirements"})), func(context *WriteContext) error {
UpdateStates(context, cur, NewLockMap(NewLockInfo(c, []string{"requirements"})), func(context *StateContext) error {
is_requirement = checkIfRequirement(context, cur, c)
return nil
})
@ -386,7 +393,7 @@ func checkIfRequirement(context *WriteContext, r Lockable, cur Lockable) bool {
// Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error
// Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call
func LockLockables(context *WriteContext, to_lock []Lockable, new_owner Lockable) error {
func LockLockables(context *StateContext, to_lock []Lockable, new_owner Lockable) error {
if to_lock == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: no list provided")
}
@ -406,10 +413,10 @@ func LockLockables(context *WriteContext, to_lock []Lockable, new_owner Lockable
return nil
}
return UpdateMoreStates(context, new_owner, NewLockMap(
return UpdateStates(context, new_owner, NewLockMap(
LockList(to_lock, []string{"lock"}),
NewLockInfo(new_owner, []string{}),
), func(context *WriteContext) error {
), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements
for _, req := range(to_lock) {
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID())
@ -426,7 +433,7 @@ func LockLockables(context *WriteContext, to_lock []Lockable, new_owner Lockable
if owner.ID() == new_owner.ID() {
continue
} else {
err := UpdateMoreStates(context, new_owner, NewLockMap(NewLockInfo(owner, []string{"take_lock"})), func(context *WriteContext)(error){
err := UpdateStates(context, new_owner, NewLockMap(NewLockInfo(owner, []string{"take_lock"})), func(context *StateContext)(error){
return LockLockables(context, req.Requirements(), req)
})
if err != nil {
@ -464,7 +471,7 @@ func LockLockables(context *WriteContext, to_lock []Lockable, new_owner Lockable
}
func UnlockLockables(context *WriteContext, to_unlock []Lockable, old_owner Lockable) error {
func UnlockLockables(context *StateContext, to_unlock []Lockable, old_owner Lockable) error {
if to_unlock == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided")
}
@ -484,10 +491,10 @@ func UnlockLockables(context *WriteContext, to_unlock []Lockable, old_owner Lock
return nil
}
return UpdateMoreStates(context, old_owner, NewLockMap(
return UpdateStates(context, old_owner, NewLockMap(
LockList(to_unlock, []string{"lock"}),
NewLockInfo(old_owner, []string{}),
), func(context *WriteContext) error {
), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements
for _, req := range(to_unlock) {
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID())

@ -74,7 +74,7 @@ type Node interface {
RemovePolicy(Policy) error
// Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate
Signal(context *ReadContext, signal GraphSignal) error
Signal(context *StateContext, signal GraphSignal) error
// Register a channel to receive updates sent to the node
RegisterChannel(id NodeID, listener chan GraphSignal)
// Unregister a channel from receiving updates sent to the node
@ -193,7 +193,7 @@ func (node * GraphNode) Type() NodeType {
// Propagate the signal to registered listeners, if a listener isn't ready to receive the update
// send it a notification that it was closed and then close it
func (node * GraphNode) Signal(context *ReadContext, signal GraphSignal) error {
func (node * GraphNode) Signal(context *StateContext, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String())
node.listeners_lock.Lock()
defer node.listeners_lock.Unlock()
@ -293,20 +293,18 @@ func getNodeBytes(node Node) ([]byte, error) {
}
// Write multiple nodes to the database in a single transaction
func WriteNodes(context *WriteContext) error {
if context == nil {
return fmt.Errorf("Cannot write nil to DB")
}
if context.Locked == nil {
return fmt.Errorf("Cannot write nil map to DB")
func WriteNodes(context *StateContext) error {
err := ValidateStateContext(context, "write", true)
if err != nil {
return err
}
context.Graph.Log.Logf("db", "DB_WRITES: %d", len(context.Locked))
serialized_bytes := make([][]byte, len(context.Locked))
serialized_ids := make([][]byte, len(context.Locked))
i := 0
for _, lock := range(context.Locked) {
node := lock.Node
for _, node := range(context.Locked) {
node_bytes, err := getNodeBytes(node)
context.Graph.Log.Logf("db", "DB_WRITE: %+v", node)
if err != nil {
@ -321,7 +319,7 @@ func WriteNodes(context *WriteContext) error {
i++
}
err := context.Graph.DB.Update(func(txn *badger.Txn) error {
return context.Graph.DB.Update(func(txn *badger.Txn) error {
for i, id := range(serialized_ids) {
err := txn.Set(id, serialized_bytes[i])
if err != nil {
@ -330,8 +328,6 @@ func WriteNodes(context *WriteContext) error {
}
return nil
})
return err
}
// Get the bytes associates with `id` from the database after unwrapping the header, or error
@ -450,17 +446,54 @@ type LockInfo struct {
type LockMap map[NodeID]LockInfo
type ReadContext struct {
type StateContext struct {
Type string
Graph *Context
Locked LockMap
Permissions map[NodeID]LockMap
Locked NodeMap
Started bool
Finished bool
}
type ReadFn func(*ReadContext)(error)
type WriteContext struct {
Graph *Context
Locked LockMap
func ValidateStateContext(context *StateContext, Type string, Finished bool) error {
if context == nil {
return fmt.Errorf("context is nil")
}
if context.Finished != Finished {
return fmt.Errorf("context in wrong Finished state")
}
if context.Type != Type {
return fmt.Errorf("%s is not a %s context", context.Type, Type)
}
if context.Locked == nil || context.Graph == nil || context.Permissions == nil {
return fmt.Errorf("context is not initialized correctly")
}
return nil
}
type WriteFn func(*WriteContext)(error)
func NewReadContext(ctx *Context) *StateContext {
return &StateContext{
Type: "read",
Graph: ctx,
Permissions: map[NodeID]LockMap{},
Locked: NodeMap{},
Started: false,
Finished: false,
}
}
func NewWriteContext(ctx *Context) *StateContext {
return &StateContext{
Type: "write",
Graph: ctx,
Permissions: map[NodeID]LockMap{},
Locked: NodeMap{},
Started: false,
Finished: false,
}
}
type StateFn func(*StateContext)(error)
func del[K comparable](list []K, val K) []K {
idx := -1
@ -478,146 +511,211 @@ func del[K comparable](list []K, val K) []K {
return list[:len(list)-1]
}
// Start a read context for node under ctx for the resources specified in init_nodes, then run nodes_fn
func UseStates(ctx *Context, node Node, nodes LockMap, read_fn ReadFn) error {
context := &ReadContext{
Graph: ctx,
Locked: LockMap{},
}
return UseMoreStates(context, node, nodes, read_fn)
}
// Add nodes to an existing read context and call nodes_fn with new_nodes locked for read
// Check that the node has read permissions for the nodes, then add them to the read context and call nodes_fn with the nodes locked for read
func UseMoreStates(context *ReadContext, node Node, new_nodes LockMap, read_fn ReadFn) error {
locked_nodes := []Node{}
func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn StateFn) error {
if princ == nil || new_nodes == nil || state_fn == nil {
return fmt.Errorf("nil passed to UseStates")
}
err := ValidateStateContext(context, "read", false)
if err != nil {
return err
}
final := false
if context.Started == false {
context.Started = true
final = true
}
new_locks := []Node{}
_, princ_locked := context.Locked[princ.ID()]
if princ_locked == false {
new_locks = append(new_locks, princ)
context.Graph.Log.Logf("mutex", "RLOCKING_PRINC %s", princ.ID().String())
princ.RLock()
}
princ_permissions, princ_exists := context.Permissions[princ.ID()]
new_permissions := LockMap{}
if princ_exists == true {
for id, info := range(princ_permissions) {
new_permissions[id] = info
}
}
for _, request := range(new_nodes) {
id := request.Node.ID()
new_permissions[id] = LockInfo{Node: request.Node, Resources: []string{}}
node := request.Node
if node == nil {
return fmt.Errorf("node in request list is nil")
}
id := node.ID()
if id != princ.ID() {
_, locked := context.Locked[id]
if locked == false {
new_locks = append(new_locks, node)
context.Graph.Log.Logf("mutex", "RLOCKING %s", id.String())
node.RLock()
}
}
node_permissions, node_exists := new_permissions[id]
if node_exists == false {
node_permissions = LockInfo{Node: node, Resources: []string{}}
}
for _, resource := range(request.Resources) {
// If the permission for this resource is already granted, continue
current_permissions, exists := context.Locked[id]
if exists == true {
already_granted := false
for _, r := range(current_permissions.Resources) {
if r == resource {
for _, granted := range(node_permissions.Resources) {
if resource == granted {
already_granted = true
break
}
}
if already_granted == true {
continue
}
}
err := request.Node.Allowed("read", resource, node)
if already_granted == false {
err := node.Allowed("read", resource, princ)
if err != nil {
return err
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.write", id.String(), resource)
for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String())
n.RUnlock()
}
tmp := new_permissions[id]
tmp.Resources = append(tmp.Resources, resource)
new_permissions[id] = tmp
return err
}
req_perms, exists := new_permissions[id]
if exists == true {
cur_perms, already_locked := context.Locked[id]
if already_locked == false {
request.Node.RLock()
context.Locked[id] = req_perms
locked_nodes = append(locked_nodes, request.Node)
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.write", id.String(), resource)
} else {
cur_perms.Resources = append(cur_perms.Resources, req_perms.Resources...)
context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.write", id.String(), resource)
}
}
new_permissions[id] = node_permissions
}
err := read_fn(context)
for _, request := range(new_permissions) {
cur_perms := context.Locked[request.Node.ID()]
new_perms := cur_perms.Resources
for _, resource := range(cur_perms.Resources) {
new_perms = del(new_perms, resource)
}
cur_perms.Resources = new_perms
context.Locked[request.Node.ID()] = cur_perms
for _, node := range(new_locks) {
context.Locked[node.ID()] = node
}
for _, node := range(locked_nodes) {
context.Permissions[princ.ID()] = new_permissions
err = state_fn(context)
context.Permissions[princ.ID()] = princ_permissions
for _, node := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING %s", node.ID().String())
delete(context.Locked, node.ID())
node.RUnlock()
}
if final == true {
context.Finished = true
}
return err
}
// Initiate a write context for nodes and call nodes_fn with nodes locked for read
func UpdateStates(ctx *Context, node Node, nodes LockMap, write_fn WriteFn) error {
context := &WriteContext{
Graph: ctx,
Locked: LockMap{},
// Add nodes to an existing write context and call nodes_fn with nodes locked for read
// If context is nil
func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn StateFn) error {
if princ == nil || new_nodes == nil || state_fn == nil {
return fmt.Errorf("nil passed to UpdateStates")
}
err := UpdateMoreStates(context, node, nodes, write_fn)
if err == nil {
err = WriteNodes(context)
err := ValidateStateContext(context, "write", false)
if err != nil {
return err
}
for _, lock := range(context.Locked) {
lock.Node.Unlock()
final := false
if context.Started == false {
context.Started = true
final = true
}
return err
}
new_locks := []Node{}
_, princ_locked := context.Locked[princ.ID()]
if princ_locked == false {
new_locks = append(new_locks, princ)
context.Graph.Log.Logf("mutex", "LOCKING_PRINC %s", princ.ID().String())
princ.Lock()
}
// Add nodes to an existing write context and call nodes_fn with nodes locked for read
func UpdateMoreStates(context *WriteContext, node Node, new_nodes LockMap, write_fn WriteFn) error {
princ_permissions, princ_exists := context.Permissions[princ.ID()]
new_permissions := LockMap{}
if princ_exists == true {
for id, info := range(princ_permissions) {
new_permissions[id] = info
}
}
for _, request := range(new_nodes) {
id := request.Node.ID()
new_permissions[id] = LockInfo{Node: request.Node, Resources: []string{}}
node := request.Node
if node == nil {
return fmt.Errorf("node in request list is nil")
}
id := node.ID()
if id != princ.ID() {
_, locked := context.Locked[id]
if locked == false {
new_locks = append(new_locks, node)
context.Graph.Log.Logf("mutex", "LOCKING %s", id.String())
node.Lock()
}
}
node_permissions, node_exists := new_permissions[id]
if node_exists == false {
node_permissions = LockInfo{Node: node, Resources: []string{}}
}
for _, resource := range(request.Resources) {
current_permissions, exists := context.Locked[id]
if exists == true {
already_granted := false
for _, r := range(current_permissions.Resources) {
if r == resource {
for _, granted := range(node_permissions.Resources) {
if resource == granted {
already_granted = true
break
}
}
if already_granted == true {
continue
}
}
err := request.Node.Allowed("write", resource, node)
if already_granted == false {
err := node.Allowed("write", resource, princ)
if err != nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.write", id.String(), resource)
for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String())
n.Unlock()
}
return err
}
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.write", id.String(), resource)
} else {
context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.write", id.String(), resource)
}
}
new_permissions[id] = node_permissions
}
tmp := new_permissions[id]
tmp.Resources = append(tmp.Resources, resource)
new_permissions[id] = tmp
for _, node := range(new_locks) {
context.Locked[node.ID()] = node
}
req_perms, exists := new_permissions[id]
if exists == true {
cur_perms, already_locked := context.Locked[id]
if already_locked == false {
request.Node.Lock()
context.Locked[id] = req_perms
} else {
cur_perms.Resources = append(cur_perms.Resources, req_perms.Resources...)
context.Locked[id] = cur_perms
context.Permissions[princ.ID()] = new_permissions
err = state_fn(context)
if final == true {
context.Finished = true
if err == nil {
err = WriteNodes(context)
}
for id, node := range(context.Locked) {
context.Graph.Log.Logf("mutex", "UNLOCKING %s", id.String())
node.Unlock()
}
}
return write_fn(context)
return err
}
// Create a new channel with a buffer the size of buffer, and register it to node with the id

@ -10,7 +10,7 @@ import (
)
// Assumed that thread is already locked for signal
func (thread *SimpleThread) Signal(context *ReadContext, signal GraphSignal) error {
func (thread *SimpleThread) Signal(context *StateContext, signal GraphSignal) error {
err := thread.SimpleLockable.Signal(context, signal)
if err != nil {
return err
@ -18,9 +18,9 @@ func (thread *SimpleThread) Signal(context *ReadContext, signal GraphSignal) err
switch signal.Direction() {
case Up:
err = UseMoreStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *ReadContext) error {
err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error {
if thread.parent != nil {
return UseMoreStates(context, thread, NewLockInfo(thread.parent, []string{"signal"}), func(context *ReadContext) error {
return UseStates(context, thread, NewLockInfo(thread.parent, []string{"signal"}), func(context *StateContext) error {
return thread.parent.Signal(context, signal)
})
} else {
@ -28,10 +28,10 @@ func (thread *SimpleThread) Signal(context *ReadContext, signal GraphSignal) err
}
})
case Down:
err = UseMoreStates(context, thread, NewLockMap(
err = UseStates(context, thread, NewLockMap(
NewLockInfo(thread, []string{"children"}),
LockList(thread.children, []string{"signal"}),
), func(context *ReadContext) error {
), func(context *StateContext) error {
for _, child := range(thread.children) {
err := child.Signal(context, signal)
if err != nil {
@ -169,15 +169,15 @@ func (thread * SimpleThread) AddChild(child Thread, info ThreadInfo) error {
return nil
}
func checkIfChild(context *WriteContext, target Thread, cur Thread) bool {
func checkIfChild(context *StateContext, target Thread, cur Thread) bool {
for _, child := range(cur.Children()) {
if child.ID() == target.ID() {
return true
}
is_child := false
UpdateMoreStates(context, cur, NewLockMap(
UpdateStates(context, cur, NewLockMap(
NewLockInfo(child, []string{"children"}),
), func(context *WriteContext) error {
), func(context *StateContext) error {
is_child = checkIfChild(context, target, child)
return nil
})
@ -189,7 +189,9 @@ func checkIfChild(context *WriteContext, target Thread, cur Thread) bool {
return false
}
func LinkThreads(context *WriteContext, princ Node, thread Thread, child Thread, info ThreadInfo) error {
// Links child to parent with info as the associated info
// Continues the write context with princ, getting children for thread and parent for child
func LinkThreads(context *StateContext, princ Node, thread Thread, child Thread, info ThreadInfo) error {
if context == nil || thread == nil || child == nil {
return fmt.Errorf("invalid input")
}
@ -198,7 +200,10 @@ func LinkThreads(context *WriteContext, princ Node, thread Thread, child Thread,
return fmt.Errorf("Will not link %s as a child of itself", thread.ID())
}
return UpdateMoreStates(context, princ, LockList([]Node{child, thread}, []string{"parent", "children"}), func(context *WriteContext) error {
return UpdateStates(context, princ, LockMap{
child.ID(): LockInfo{Node: child, Resources: []string{"parent"}},
thread.ID(): LockInfo{Node: thread, Resources: []string{"children"}},
}, func(context *StateContext) error {
if child.Parent() != nil {
return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID())
}
@ -449,7 +454,7 @@ func NewSimpleThread(id NodeID, name string, state_name string, info_type reflec
}
// Requires the read permission of threads children
func FindChild(context *ReadContext, princ Node, thread Thread, id NodeID) Thread {
func FindChild(context *StateContext, princ Node, thread Thread, id NodeID) Thread {
if thread == nil {
panic("cannot recurse through nil")
}
@ -459,7 +464,7 @@ func FindChild(context *ReadContext, princ Node, thread Thread, id NodeID) Threa
for _, child := range thread.Children() {
var result Thread = nil
UseMoreStates(context, princ, NewLockInfo(child, []string{"children"}), func(context *ReadContext) error {
UseStates(context, princ, NewLockInfo(child, []string{"children"}), func(context *StateContext) error {
result = FindChild(context, princ, child, id)
return nil
})
@ -485,7 +490,7 @@ func ChildGo(ctx * Context, thread Thread, child Thread, first_action string) {
}(child)
}
// Main Loop for Threads
// Main Loop for Threads, starts a write context, so cannot be called from a write or read context
func ThreadLoop(ctx * Context, thread Thread, first_action string) error {
// Start the thread, error if double-started
ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID(), first_action)
@ -515,18 +520,8 @@ func ThreadLoop(ctx * Context, thread Thread, first_action string) error {
return err
}
err = UpdateStates(ctx, thread, NewLockInfo(thread, []string{"state"}), func(context *WriteContext) error {
err := thread.SetState("finished")
if err != nil {
return err
}
return UnlockLockables(context, []Lockable{thread}, thread)
})
if err != nil {
ctx.Log.Logf("thread", "THREAD_LOOP_UNLOCK_ERR: %e", err)
return err
}
ctx.Log.Logf("thread", "THREAD_LOOP_DONE: %s", thread.ID())
@ -578,13 +573,16 @@ func (thread * SimpleThread) AllowedToTakeLock(new_owner Lockable, lockable Lock
return false
}
// Helper function to start a child from a thread during a signal handler
// Starts a write context, so cannot be called from either a write or read context
func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) error {
return UpdateStates(ctx, thread, NewLockInfo(thread, []string{"children"}), func(context *WriteContext) error {
context := NewWriteContext(ctx)
return UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error {
child := thread.Child(signal.ID)
if child == nil {
return fmt.Errorf("%s is not a child of %s", signal.ID, thread.ID())
}
return UpdateMoreStates(context, thread, NewLockInfo(child, []string{"start"}), func(context *WriteContext) error {
return UpdateStates(context, thread, NewLockInfo(child, []string{"start"}), func(context *StateContext) error {
info := thread.ChildInfo(signal.ID).(*ParentThreadInfo)
info.Start = true
@ -595,11 +593,14 @@ func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) erro
})
}
// Helper function to restore threads that should be running from a parents restore action
// Starts a write context, so cannot be called from either a write or read context
func ThreadRestore(ctx * Context, thread Thread) {
UpdateStates(ctx, thread, NewLockMap(
context := NewWriteContext(ctx)
UpdateStates(context, thread, NewLockMap(
NewLockInfo(thread, []string{"children"}),
LockList(thread.Children(), []string{"start"}),
), func(context *WriteContext)(error) {
), func(context *StateContext)(error) {
for _, child := range(thread.Children()) {
should_run := (thread.ChildInfo(child.ID())).(ParentInfo).Parent()
ctx.Log.Logf("thread", "THREAD_RESTORE: %s -> %s: %+v", thread.ID(), child.ID(), should_run)
@ -612,18 +613,15 @@ func ThreadRestore(ctx * Context, thread Thread) {
})
}
// Helper function to be called during a threads start action, sets the thread state to started
// Starts a write context, so cannot be called from either a write or read context
func ThreadStart(ctx * Context, thread Thread) error {
return UpdateStates(ctx, thread, NewLockInfo(thread, []string{"start", "lock"}), func(context *WriteContext) error {
owner_id := NodeID{}
if thread.Owner() != nil {
owner_id = thread.Owner().ID()
}
if owner_id != thread.ID() {
context := NewWriteContext(ctx)
return UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error {
err := LockLockables(context, []Lockable{thread}, thread)
if err != nil {
return err
}
}
return thread.SetState("started")
})
}
@ -657,7 +655,8 @@ func ThreadWait(ctx * Context, thread Thread) (string, error) {
}
case <- thread.Timeout():
timeout_action := ""
err := UpdateStates(ctx, thread, NewLockMap(NewLockInfo(thread, []string{"timeout"})), func(context *WriteContext) error {
context := NewWriteContext(ctx)
err := UpdateStates(context, thread, NewLockMap(NewLockInfo(thread, []string{"timeout"})), func(context *StateContext) error {
timeout_action = thread.TimeoutAction()
thread.ClearTimeout()
return nil
@ -671,27 +670,46 @@ func ThreadWait(ctx * Context, thread Thread) (string, error) {
}
}
func ThreadDefaultFinish(ctx *Context, thread Thread) (string, error) {
ctx.Log.Logf("thread", "THREAD_DEFAULT_FINISH: %s", thread.ID().String())
return "", ThreadFinish(ctx, thread)
}
func ThreadFinish(ctx *Context, thread Thread) error {
context := NewWriteContext(ctx)
return UpdateStates(context, thread, NewLockInfo(thread, []string{"state"}), func(context *StateContext) error {
err := thread.SetState("finished")
if err != nil {
return err
}
return UnlockLockables(context, []Lockable{thread}, thread)
})
}
var ThreadAbortedError = errors.New("Thread aborted by signal")
// Default thread abort is to return a ThreadAbortedError
// Default thread action function for "abort", sends a signal and returns a ThreadAbortedError
func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
err := UseStates(ctx, thread, NewLockInfo(thread, []string{"signal"}), func(context *ReadContext) error {
context := NewReadContext(ctx)
err := UseStates(context, thread, NewLockInfo(thread, []string{"signal"}), func(context *StateContext) error {
return thread.Signal(context, NewStatusSignal("aborted", thread.ID()))
})
if err != nil {
return "", err
}
return "", ThreadAbortedError
return "finish", ThreadAbortedError
}
// Default thread cancel is to finish the thread
// Default thread action for "cancel", sends a signal and returns no error
func ThreadCancel(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
err := UseStates(ctx, thread, NewLockInfo(thread, []string{"signal"}), func(context *ReadContext) error {
context := NewReadContext(ctx)
err := UseStates(context, thread, NewLockInfo(thread, []string{"signal"}), func(context *StateContext) error {
return thread.Signal(context, NewSignal("cancelled"))
})
return "", err
return "finish", err
}
// Copy the default thread actions to a new ThreadActions map
func NewThreadActions() ThreadActions{
actions := ThreadActions{}
for k, v := range(BaseThreadActions) {
@ -701,6 +719,7 @@ func NewThreadActions() ThreadActions{
return actions
}
// Copy the defult thread handlers to a new ThreadAction map
func NewThreadHandlers() ThreadHandlers{
handlers := ThreadHandlers{}
for k, v := range(BaseThreadHandlers) {
@ -710,12 +729,15 @@ func NewThreadHandlers() ThreadHandlers{
return handlers
}
// Default thread actions
var BaseThreadActions = ThreadActions{
"wait": ThreadWait,
"start": ThreadDefaultStart,
"finish": ThreadDefaultFinish,
"restore": ThreadDefaultRestore,
}
// Default thread signal handlers
var BaseThreadHandlers = ThreadHandlers{
"abort": ThreadAbort,
"cancel": ThreadCancel,