Fixed starting of gql server to correctly start children that should start

graph-rework-2
noah metz 2023-07-23 19:04:04 -06:00
parent 8fb0cbc982
commit 054fe3c0ec
5 changed files with 111 additions and 64 deletions

@ -809,7 +809,10 @@ var gql_actions ThreadActions = ThreadActions{
"restore": func(ctx * Context, thread Thread) (string, error) { "restore": func(ctx * Context, thread Thread) (string, error) {
// Start all the threads that should be "started" // Start all the threads that should be "started"
ctx.Log.Logf("gql", "GQL_THREAD_RESTORE: %s", thread.ID()) ctx.Log.Logf("gql", "GQL_THREAD_RESTORE: %s", thread.ID())
ThreadRestore(ctx, thread) err := ThreadRestore(ctx, thread)
if err != nil {
return "", err
}
return "start_server", nil return "start_server", nil
}, },
"start": func(ctx * Context, thread Thread) (string, error) { "start": func(ctx * Context, thread Thread) (string, error) {
@ -819,6 +822,21 @@ var gql_actions ThreadActions = ThreadActions{
return "", err return "", err
} }
// Start all the threads that have "Start" as true
context := NewWriteContext(ctx)
err = UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error {
return UpdateStates(context, thread, LockList(thread.Children(), []string{"start"}), func(context *StateContext) error {
for _, child := range(thread.Children()) {
info := thread.ChildInfo(child.ID()).(ParentInfo).Parent()
if info.Start == true {
ctx.Log.Logf("thread", "THREAD_START_CHILD: %s -> %s", thread.ID(), child.ID())
ChildGo(ctx, thread, child, info.StartAction)
}
}
return nil
})
})
return "start_server", nil return "start_server", nil
}, },
"start_server": func(ctx * Context, thread Thread) (string, error) { "start_server": func(ctx * Context, thread Thread) (string, error) {
@ -956,6 +974,6 @@ var gql_handlers ThreadHandlers = ThreadHandlers{
return "wait", nil return "wait", nil
}, },
"abort": ThreadAbort, "abort": ThreadAbort,
"cancel": ThreadCancel, "stop": ThreadStop,
} }

@ -3,7 +3,6 @@ package graphvent
import ( import (
"testing" "testing"
"time" "time"
"errors"
"net" "net"
"net/http" "net/http"
"io" "io"
@ -19,13 +18,17 @@ import (
) )
func TestGQLDBLoad(t * testing.T) { func TestGQLDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{"policy", "mutex"}) ctx := logTestContext(t, []string{"test", "signal", "thread"})
l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
l1 := &l1_r l1 := &l1_r
ctx.Log.Logf("test", "L1_ID: %s", l1.ID().String())
t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
t1 := &t1_r t1 := &t1_r
update_channel := UpdateChannel(t1, 10, NodeID{}) ctx.Log.Logf("test", "T1_ID: %s", t1.ID().String())
listen_id := RandID()
ctx.Log.Logf("test", "LISTENER_ID: %s", listen_id.String())
update_channel := UpdateChannel(t1, 10, listen_id)
u1_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) u1_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err) fatalErr(t, err)
@ -34,28 +37,52 @@ func TestGQLDBLoad(t * testing.T) {
u1_r := NewUser("Test User", time.Now(), &u1_key.PublicKey, u1_shared, []string{"gql"}) u1_r := NewUser("Test User", time.Now(), &u1_key.PublicKey, u1_shared, []string{"gql"})
u1 := &u1_r u1 := &u1_r
ctx.Log.Logf("test", "U1_ID: %s", u1.ID().String())
p1_r := NewSimplePolicy(RandID(), NewNodeActions(nil, []string{"enumerate"}))
p1 := &p1_r
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err) fatalErr(t, err)
gql_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) gql_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil)
gql := &gql_r gql := &gql_r
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID().String())
// Policy to allow gql to perform all action on all resources
p1_r := NewPerNodePolicy(RandID(), map[NodeID]NodeActions{
gql.ID(): NewNodeActions(nil, []string{"*"}),
})
p1 := &p1_r
p2_r := NewSimplePolicy(RandID(), NewNodeActions(NodeActions{
"signal": []string{"read"},
}, nil))
p2 := &p2_r
info := NewParentThreadInfo(true, "start", "restore")
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
err = UpdateStates(context, gql, LockMap{
p1.ID(): LockInfo{p1, nil},
p2.ID(): LockInfo{p2, nil},
}, func(context *StateContext) error {
return nil
})
fatalErr(t, err)
ctx.Log.Logf("test", "P1_ID: %s", p1.ID().String())
ctx.Log.Logf("test", "P2_ID: %s", p2.ID().String())
err = AttachPolicies(ctx, gql, p1, p2)
fatalErr(t, err)
err = AttachPolicies(ctx, l1, p1, p2)
fatalErr(t, err)
err = AttachPolicies(ctx, t1, p1, p2)
fatalErr(t, err)
err = AttachPolicies(ctx, u1, p1, p2)
fatalErr(t, err)
info := NewParentThreadInfo(true, "start", "restore")
context = NewWriteContext(ctx)
err = UpdateStates(context, gql, NewLockMap( err = UpdateStates(context, gql, NewLockMap(
NewLockInfo(gql, []string{"policies", "users"}), NewLockInfo(gql, []string{"users"}),
), func(context *StateContext) error { ), func(context *StateContext) error {
err := gql.AddPolicy(p1)
if err != nil {
return err
}
gql.Users[KeyID(&u1_key.PublicKey)] = u1 gql.Users[KeyID(&u1_key.PublicKey)] = u1
err = LinkThreads(context, gql, gql, t1, &info) err := LinkThreads(context, gql, gql, t1, &info)
if err != nil { if err != nil {
return err return err
} }
@ -69,20 +96,14 @@ func TestGQLDBLoad(t * testing.T) {
if err != nil { if err != nil {
return nil return nil
} }
return gql.Signal(context, CancelSignal) return gql.Signal(context, StopSignal)
}) })
fatalErr(t, err) fatalErr(t, err)
err = ThreadLoop(ctx, gql, "start") err = ThreadLoop(ctx, gql, "start")
if errors.Is(err, ThreadAbortedError) { fatalErr(t, err)
ctx.Log.Logf("test", "Main thread aborted by signal: %s", err)
} else if err != nil{
fatalErr(t, err)
} else {
ctx.Log.Logf("test", "Main thread cancelled by signal")
}
(*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_aborted", 100*time.Millisecond, "Didn't receive thread_abort from t1 on t1") (*GraphTester)(t).WaitForValue(ctx, update_channel, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel")
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = UseStates(context, gql, LockList([]Node{gql, u1}, nil), func(context *StateContext) error { err = UseStates(context, gql, LockList([]Node{gql, u1}, nil), func(context *StateContext) error {
@ -105,28 +126,24 @@ func TestGQLDBLoad(t * testing.T) {
u_loaded := gql_loaded.(*GQLThread).Users[u1.ID()] u_loaded := gql_loaded.(*GQLThread).Users[u1.ID()]
child := gql_loaded.(Thread).Children()[0].(*SimpleThread) child := gql_loaded.(Thread).Children()[0].(*SimpleThread)
t1_loaded = child t1_loaded = child
update_channel_2 = UpdateChannel(t1_loaded, 10, NodeID{}) update_channel_2 = UpdateChannel(t1_loaded, 10, RandID())
err = UseStates(context, gql, NewLockInfo(u_loaded, nil), func(context *StateContext) error { err = UseStates(context, gql, NewLockInfo(u_loaded, nil), func(context *StateContext) error {
ser, err := u_loaded.Serialize() ser, err := u_loaded.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser) ctx.Log.Logf("test", "\n%s\n\n", ser)
return err return err
}) })
gql_loaded.Signal(context, AbortSignal) gql_loaded.Signal(context, StopSignal)
return err return err
}) })
err = ThreadLoop(ctx, gql_loaded.(Thread), "restore") err = ThreadLoop(ctx, gql_loaded.(Thread), "start")
if errors.Is(err, ThreadAbortedError) { fatalErr(t, err)
ctx.Log.Logf("test", "Main thread aborted by signal: %s", err) (*GraphTester)(t).WaitForValue(ctx, update_channel_2, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2")
} else {
fatalErr(t, err)
}
(*GraphTester)(t).WaitForValue(ctx, update_channel_2, "thread_aborted", 100*time.Millisecond, "Didn't received thread_aborted on t1_loaded from t1_loaded")
} }
func TestGQLAuth(t * testing.T) { func TestGQLAuth(t * testing.T) {
ctx := logTestContext(t, []string{"policy", "mutex"}) ctx := logTestContext(t, []string{"policy"})
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err) fatalErr(t, err)
@ -162,7 +179,7 @@ func TestGQLAuth(t * testing.T) {
} }
context := NewReadContext(ctx) context := NewReadContext(ctx)
err := UseStates(context, gql_t, NewLockInfo(gql_t, []string{"signal}"}), func(context *StateContext) error { err := UseStates(context, gql_t, NewLockInfo(gql_t, []string{"signal}"}), func(context *StateContext) error {
return thread.Signal(context, CancelSignal) return thread.Signal(context, StopSignal)
}) })
fatalErr(t, err) fatalErr(t, err)
}(done, gql_t) }(done, gql_t)

@ -108,7 +108,7 @@ func (node *GraphNode) Allowed(action string, resource string, principal Node) e
return nil return nil
} }
for _, policy := range(node.policies) { for _, policy := range(node.policies) {
if policy.Allows(action, resource, principal) == true { if policy.Allows(resource, action, principal) == true {
return nil return nil
} }
} }
@ -200,11 +200,11 @@ func (node * GraphNode) Signal(context *StateContext, signal GraphSignal) error
closed := []NodeID{} closed := []NodeID{}
for id, listener := range node.listeners { for id, listener := range node.listeners {
context.Graph.Log.Logf("signal", "UPDATE_LISTENER %s: %p", node.ID(), listener) context.Graph.Log.Logf("signal", "UPDATE_LISTENER %s: %s", node.ID(), id)
select { select {
case listener <- signal: case listener <- signal:
default: default:
context.Graph.Log.Logf("signal", "CLOSED_LISTENER %s: %p", node.ID(), listener) context.Graph.Log.Logf("signal", "CLOSED_LISTENER %s: %s", node.ID(), id)
go func(node Node, listener chan GraphSignal) { go func(node Node, listener chan GraphSignal) {
listener <- NewDirectSignal("listener_closed") listener <- NewDirectSignal("listener_closed")
close(listener) close(listener)
@ -239,6 +239,19 @@ func (node * GraphNode) UnregisterChannel(id NodeID) {
node.listeners_lock.Unlock() node.listeners_lock.Unlock()
} }
func AttachPolicies(ctx *Context, node Node, policies ...Policy) error {
context := NewWriteContext(ctx)
return UpdateStates(context, node, NewLockInfo(node, []string{"policies"}), func(context *StateContext) error {
for _, policy := range(policies) {
err := node.AddPolicy(policy)
if err != nil {
return err
}
}
return nil
})
}
func NewGraphNode(id NodeID) GraphNode { func NewGraphNode(id NodeID) GraphNode {
return GraphNode{ return GraphNode{
id: id, id: id,
@ -577,16 +590,16 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
if already_granted == false { if already_granted == false {
err := node.Allowed("read", resource, princ) err := node.Allowed("read", resource, princ)
if err != nil { if err != nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.write", id.String(), resource) context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.read", princ.ID().String(), id.String(), resource)
for _, n := range(new_locks) { for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String()) context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String())
n.RUnlock() n.RUnlock()
} }
return err return err
} }
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.write", id.String(), resource) context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.read", princ.ID().String(), id.String(), resource)
} else { } else {
context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.write", id.String(), resource) context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.%s.read", princ.ID().String(), id.String(), resource)
} }
} }
new_permissions[id] = node_permissions new_permissions[id] = node_permissions
@ -681,16 +694,16 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
if already_granted == false { if already_granted == false {
err := node.Allowed("write", resource, princ) err := node.Allowed("write", resource, princ)
if err != nil { if err != nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.write", id.String(), resource) context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.write", princ.ID().String(), id.String(), resource)
for _, n := range(new_locks) { for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String()) context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String())
n.Unlock() n.Unlock()
} }
return err return err
} }
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.write", id.String(), resource) context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.write", princ.ID().String(), id.String(), resource)
} else { } else {
context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.write", id.String(), resource) context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.%s.write", princ.ID().String(), id.String(), resource)
} }
} }
new_permissions[id] = node_permissions new_permissions[id] = node_permissions

@ -62,7 +62,7 @@ func NewDirectSignal(_type string) BaseSignal {
} }
var AbortSignal = NewBaseSignal("abort", Down) var AbortSignal = NewBaseSignal("abort", Down)
var CancelSignal = NewBaseSignal("cancel", Down) var StopSignal = NewBaseSignal("stop", Down)
type IDSignal struct { type IDSignal struct {
BaseSignal BaseSignal

@ -595,21 +595,20 @@ func ThreadStartChild(ctx *Context, thread Thread, signal StartChildSignal) erro
// Helper function to restore threads that should be running from a parents restore action // Helper function to restore threads that should be running from a parents restore action
// Starts a write context, so cannot be called from either a write or read context // Starts a write context, so cannot be called from either a write or read context
func ThreadRestore(ctx * Context, thread Thread) { func ThreadRestore(ctx * Context, thread Thread) error {
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
UpdateStates(context, thread, NewLockMap( return UpdateStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error {
NewLockInfo(thread, []string{"children"}), return UpdateStates(context, thread, LockList(thread.Children(), []string{"start"}), func(context *StateContext) error {
LockList(thread.Children(), []string{"start"}), for _, child := range(thread.Children()) {
), func(context *StateContext)(error) { should_run := (thread.ChildInfo(child.ID())).(ParentInfo).Parent()
for _, child := range(thread.Children()) { ctx.Log.Logf("thread", "THREAD_RESTORE: %s -> %s: %+v", thread.ID(), child.ID(), should_run)
should_run := (thread.ChildInfo(child.ID())).(ParentInfo).Parent() if should_run.Start == true && child.State() != "finished" {
ctx.Log.Logf("thread", "THREAD_RESTORE: %s -> %s: %+v", thread.ID(), child.ID(), should_run) ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID(), child.ID())
if should_run.Start == true && child.State() != "finished" { ChildGo(ctx, thread, child, should_run.RestoreAction)
ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID(), child.ID()) }
ChildGo(ctx, thread, child, should_run.RestoreAction)
} }
} return nil
return nil })
}) })
} }
@ -697,14 +696,14 @@ func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, erro
if err != nil { if err != nil {
return "", err return "", err
} }
return "finish", ThreadAbortedError return "", ThreadAbortedError
} }
// Default thread action for "cancel", sends a signal and returns no error // Default thread action for "stop", sends a signal and returns no error
func ThreadCancel(ctx * Context, thread Thread, signal GraphSignal) (string, error) { func ThreadStop(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
context := NewReadContext(ctx) context := NewReadContext(ctx)
err := UseStates(context, thread, NewLockInfo(thread, []string{"signal"}), func(context *StateContext) error { err := UseStates(context, thread, NewLockInfo(thread, []string{"signal"}), func(context *StateContext) error {
return thread.Signal(context, NewSignal("cancelled")) return thread.Signal(context, NewSignal("stopped"))
}) })
return "finish", err return "finish", err
} }
@ -740,5 +739,5 @@ var BaseThreadActions = ThreadActions{
// Default thread signal handlers // Default thread signal handlers
var BaseThreadHandlers = ThreadHandlers{ var BaseThreadHandlers = ThreadHandlers{
"abort": ThreadAbort, "abort": ThreadAbort,
"cancel": ThreadCancel, "stop": ThreadStop,
} }