Moved status signals to StatusSignal type with status string as type, so they can be ACLd

graph-rework-2
noah metz 2023-07-23 21:14:15 -06:00
parent d56245c5cf
commit dfa420757b
9 changed files with 72 additions and 108 deletions

@ -906,12 +906,7 @@ var gql_actions ThreadActions = ThreadActions{
} }
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = UseStates(context, server, NewLockMap( err = server.Signal(context, server, NewStatusSignal("server_started", server.ID()))
NewLockInfo(server, []string{"signal"}),
), func(context *StateContext) error {
return server.Signal(context, NewSignal("server_started"))
})
if err != nil { if err != nil {
return "", err return "", err
} }

@ -18,11 +18,6 @@ var GQLMutationAbort = NewField(func()*graphql.Field {
return nil, err return nil, err
} }
err = ctx.Server.Allowed("signal", "", ctx.User)
if err != nil {
return nil, err
}
id, err := ExtractID(p, "id") id, err := ExtractID(p, "id")
if err != nil { if err != nil {
return nil, err return nil, err
@ -37,9 +32,7 @@ var GQLMutationAbort = NewField(func()*graphql.Field {
if node == nil { if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", id) return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
} }
return UseStates(context, ctx.User, NewLockInfo(node, []string{"signal"}), func(context *StateContext) error { return node.Signal(context, ctx.User, AbortSignal)
return node.Signal(context, AbortSignal)
})
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -98,15 +91,8 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{
return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id) return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id)
} }
err := node.Allowed("signal", "", ctx.User) signal = NewStartChildSignal(child_id, action)
if err != nil { return node.Signal(context, ctx.User, signal)
return err
}
return UseStates(context, ctx.User, NewLockMap(NewLockInfo(node, []string{"start_child", "signal"})), func(context *StateContext) error {
signal = NewStartChildSignal(child_id, action)
return node.Signal(context, signal)
})
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -11,11 +11,6 @@ var GQLQuerySelf = &graphql.Field{
return nil, err return nil, err
} }
err = ctx.Server.Allowed("read", "", ctx.User)
if err != nil {
return nil, err
}
return ctx.Server, nil return ctx.Server, nil
}, },
} }
@ -28,11 +23,6 @@ var GQLQueryUser = &graphql.Field{
return nil, err return nil, err
} }
err = ctx.User.Allowed("read", "", ctx.User)
if err != nil {
return nil, err
}
return ctx.User, nil return ctx.User, nil
}, },
} }

@ -18,7 +18,7 @@ import (
) )
func TestGQLDBLoad(t * testing.T) { func TestGQLDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{"test", "signal", "thread"}) ctx := logTestContext(t, []string{"test", "signal", "policy", "thread"})
l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
l1 := &l1_r l1 := &l1_r
ctx.Log.Logf("test", "L1_ID: %s", l1.ID().String()) ctx.Log.Logf("test", "L1_ID: %s", l1.ID().String())
@ -91,19 +91,16 @@ func TestGQLDBLoad(t * testing.T) {
fatalErr(t, err) fatalErr(t, err)
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = UseStates(context, gql, NewLockInfo(gql, []string{"signal"}), func(context *StateContext) error { err = gql.Signal(context, gql, NewStatusSignal("child_linked", t1.ID()))
err := gql.Signal(context, NewStatusSignal("child_linked", t1.ID())) fatalErr(t, err)
if err != nil { context = NewReadContext(ctx)
return nil err = gql.Signal(context, gql, StopSignal)
}
return gql.Signal(context, StopSignal)
})
fatalErr(t, err) fatalErr(t, err)
err = ThreadLoop(ctx, gql, "start") err = ThreadLoop(ctx, gql, "start")
fatalErr(t, err) fatalErr(t, err)
(*GraphTester)(t).WaitForValue(ctx, update_channel, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel") (*GraphTester)(t).WaitForStatus(ctx, update_channel, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel")
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = UseStates(context, gql, LockList([]Node{gql, u1}, nil), func(context *StateContext) error { err = UseStates(context, gql, LockList([]Node{gql, u1}, nil), func(context *StateContext) error {
@ -132,13 +129,13 @@ func TestGQLDBLoad(t * testing.T) {
ctx.Log.Logf("test", "\n%s\n\n", ser) ctx.Log.Logf("test", "\n%s\n\n", ser)
return err return err
}) })
gql_loaded.Signal(context, StopSignal) gql_loaded.Signal(context, gql_loaded, StopSignal)
return err return err
}) })
err = ThreadLoop(ctx, gql_loaded.(Thread), "start") err = ThreadLoop(ctx, gql_loaded.(Thread), "start")
fatalErr(t, err) fatalErr(t, err)
(*GraphTester)(t).WaitForValue(ctx, update_channel_2, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2") (*GraphTester)(t).WaitForStatus(ctx, update_channel_2, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2")
} }
@ -178,14 +175,12 @@ func TestGQLAuth(t * testing.T) {
ctx.Log.Logf("test", "DONE") ctx.Log.Logf("test", "DONE")
} }
context := NewReadContext(ctx) context := NewReadContext(ctx)
err := UseStates(context, gql_t, NewLockInfo(gql_t, []string{"signal}"}), func(context *StateContext) error { err := thread.Signal(context, thread, StopSignal)
return thread.Signal(context, StopSignal)
})
fatalErr(t, err) fatalErr(t, err)
}(done, gql_t) }(done, gql_t)
go func(thread Thread){ go func(thread Thread){
(*GraphTester)(t).WaitForValue(ctx, update_channel, "server_started", 100*time.Millisecond, "Server didn't start") (*GraphTester)(t).WaitForStatus(ctx, update_channel, "server_started", 100*time.Millisecond, "Server didn't start")
port := gql_t.tcp_listener.Addr().(*net.TCPAddr).Port port := gql_t.tcp_listener.Addr().(*net.TCPAddr).Port
ctx.Log.Logf("test", "GQL_PORT: %d", port) ctx.Log.Logf("test", "GQL_PORT: %d", port)

@ -13,7 +13,7 @@ import (
type GraphTester testing.T type GraphTester testing.T
const listner_timeout = 50 * time.Millisecond const listner_timeout = 50 * time.Millisecond
func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, signal_type string, timeout time.Duration, str string) GraphSignal { func (t * GraphTester) WaitForStatus(ctx * Context, listener chan GraphSignal, status string, timeout time.Duration, str string) GraphSignal {
timeout_channel := time.After(timeout) timeout_channel := time.After(timeout)
for true { for true {
select { select {
@ -22,8 +22,16 @@ func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, si
ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener) ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener)
t.Fatal(str) t.Fatal(str)
} }
if signal.Type() == signal_type { if signal.Type() == "status" {
return signal sig, ok := signal.(StatusSignal)
if ok == true {
if sig.Status == status {
return signal
}
ctx.Log.Logf("test", "Different status received: %s", sig.Status)
} else {
ctx.Log.Logf("test", "Failed to cast status to StatusSignal: %+v", signal)
}
} }
case <-timeout_channel: case <-timeout_channel:
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)

@ -216,22 +216,20 @@ func (lockable * SimpleLockable) CanUnlock(new_owner Lockable) error {
} }
// Assumed that lockable is already locked for signal // Assumed that lockable is already locked for signal
func (lockable * SimpleLockable) Signal(context *StateContext, signal GraphSignal) error { func (lockable * SimpleLockable) Signal(context *StateContext, princ Node, signal GraphSignal) error {
err := lockable.GraphNode.Signal(context, signal) err := lockable.GraphNode.Signal(context, princ, signal)
if err != nil { if err != nil {
return err return err
} }
switch signal.Direction() { switch signal.Direction() {
case Up: case Up:
err = UseStates(context, lockable, NewLockMap( err = UseStates(context, lockable,
NewLockInfo(lockable, []string{"dependencies", "owner"}), NewLockInfo(lockable, []string{"dependencies", "owner"}), func(context *StateContext) error {
LockList(lockable.requirements, []string{"signal"}),
), func(context *StateContext) error {
owner_sent := false owner_sent := false
for _, dependency := range(lockable.dependencies) { for _, dependency := range(lockable.dependencies) {
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID()) context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID())
dependency.Signal(context, signal) dependency.Signal(context, lockable, signal)
if lockable.owner != nil { if lockable.owner != nil {
if dependency.ID() == lockable.owner.ID() { if dependency.ID() == lockable.owner.ID() {
owner_sent = true owner_sent = true
@ -241,20 +239,15 @@ func (lockable * SimpleLockable) Signal(context *StateContext, signal GraphSigna
if lockable.owner != nil && owner_sent == false { if lockable.owner != nil && owner_sent == false {
if lockable.owner.ID() != lockable.ID() { if lockable.owner.ID() != lockable.ID() {
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID()) context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID())
return UseStates(context, lockable, NewLockMap(NewLockInfo(lockable.owner, []string{"signal"})), func(context *StateContext) error { return lockable.owner.Signal(context, lockable, signal)
return lockable.owner.Signal(context, signal)
})
} }
} }
return nil return nil
}) })
case Down: case Down:
err = UseStates(context, lockable, NewLockMap( err = UseStates(context, lockable, NewLockInfo(lockable, []string{"requirements"}), func(context *StateContext) error {
NewLockInfo(lockable, []string{"requirements"}),
LockList(lockable.requirements, []string{"signal"}),
), func(context *StateContext) error {
for _, requirement := range(lockable.requirements) { for _, requirement := range(lockable.requirements) {
err := requirement.Signal(context, signal) err := requirement.Signal(context, lockable, signal)
if err != nil { if err != nil {
return err return err
} }

@ -69,12 +69,12 @@ type Node interface {
ID() NodeID ID() NodeID
Type() NodeType Type() NodeType
Allowed(action string, resource string, principal Node) error Allowed(context *StateContext, action string, resource string, principal Node) error
AddPolicy(Policy) error AddPolicy(Policy) error
RemovePolicy(Policy) error RemovePolicy(Policy) error
// Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate // Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate
Signal(context *StateContext, signal GraphSignal) error Signal(context *StateContext, princ Node, signal GraphSignal) error
// Register a channel to receive updates sent to the node // Register a channel to receive updates sent to the node
RegisterChannel(id NodeID, listener chan GraphSignal) RegisterChannel(id NodeID, listener chan GraphSignal)
// Unregister a channel from receiving updates sent to the node // Unregister a channel from receiving updates sent to the node
@ -100,19 +100,22 @@ func (node * GraphNode) Serialize() ([]byte, error) {
return json.MarshalIndent(&node_json, "", " ") return json.MarshalIndent(&node_json, "", " ")
} }
func (node *GraphNode) Allowed(action string, resource string, principal Node) error { func (node *GraphNode) Allowed(context *StateContext, resource string, action string, princ Node) error {
if principal == nil { if princ == nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return fmt.Errorf("nil is not allowed to perform any actions") return fmt.Errorf("nil is not allowed to perform any actions")
} }
if node.ID() == principal.ID() { if node.ID() == princ.ID() {
return nil return nil
} }
for _, policy := range(node.policies) { for _, policy := range(node.policies) {
if policy.Allows(resource, action, principal) == true { if policy.Allows(resource, action, princ) == true {
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return nil return nil
} }
} }
return fmt.Errorf("%s is not allowed to perform %s.%s on %s", principal.ID().String(), resource, action, node.ID().String()) context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return fmt.Errorf("%s is not allowed to perform %s.%s on %s", princ.ID(), resource, action, node.ID())
} }
func (node *GraphNode) AddPolicy(policy Policy) error { func (node *GraphNode) AddPolicy(policy Policy) error {
@ -193,8 +196,17 @@ func (node * GraphNode) Type() NodeType {
// Propagate the signal to registered listeners, if a listener isn't ready to receive the update // Propagate the signal to registered listeners, if a listener isn't ready to receive the update
// send it a notification that it was closed and then close it // send it a notification that it was closed and then close it
func (node * GraphNode) Signal(context *StateContext, signal GraphSignal) error { func (node * GraphNode) Signal(context *StateContext, princ Node, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String()) context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String())
err := UseStates(context, princ, NewLockInfo(princ, nil), func(context *StateContext) error {
return node.Allowed(context, "signal", signal.Type(), princ)
})
if err != nil {
return nil
}
node.listeners_lock.Lock() node.listeners_lock.Lock()
defer node.listeners_lock.Unlock() defer node.listeners_lock.Unlock()
closed := []NodeID{} closed := []NodeID{}
@ -536,10 +548,8 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
return err return err
} }
final := false
if context.Started == false { if context.Started == false {
context.Started = true context.Started = true
final = true
} }
new_locks := []Node{} new_locks := []Node{}
@ -588,18 +598,14 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
} }
if already_granted == false { if already_granted == false {
err := node.Allowed("read", resource, princ) err := node.Allowed(context, resource, "read", princ)
if err != nil { if err != nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.read", princ.ID().String(), id.String(), resource)
for _, n := range(new_locks) { for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String()) context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String())
n.RUnlock() n.RUnlock()
} }
return err return err
} }
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.read", princ.ID().String(), id.String(), resource)
} else {
context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.%s.read", princ.ID().String(), id.String(), resource)
} }
} }
new_permissions[id] = node_permissions new_permissions[id] = node_permissions
@ -621,10 +627,6 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
node.RUnlock() node.RUnlock()
} }
if final == true {
context.Finished = true
}
return err return err
} }
@ -692,18 +694,14 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
} }
if already_granted == false { if already_granted == false {
err := node.Allowed("write", resource, princ) err := node.Allowed(context, resource, "write", princ)
if err != nil { if err != nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.write", princ.ID().String(), id.String(), resource)
for _, n := range(new_locks) { for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String()) context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String())
n.Unlock() n.Unlock()
} }
return err return err
} }
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.write", princ.ID().String(), id.String(), resource)
} else {
context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.%s.write", princ.ID().String(), id.String(), resource)
} }
} }
new_permissions[id] = node_permissions new_permissions[id] = node_permissions

@ -76,8 +76,16 @@ func NewIDSignal(_type string, direction SignalDirection, id NodeID) IDSignal {
} }
} }
func NewStatusSignal(_type string, source NodeID) IDSignal { type StatusSignal struct {
return NewIDSignal(_type, Up, source) IDSignal
Status string
}
func NewStatusSignal(status string, source NodeID) StatusSignal {
return StatusSignal{
IDSignal: NewIDSignal("status", Up, source),
Status: status,
}
} }
type StartChildSignal struct { type StartChildSignal struct {

@ -10,8 +10,8 @@ import (
) )
// Assumed that thread is already locked for signal // Assumed that thread is already locked for signal
func (thread *SimpleThread) Signal(context *StateContext, signal GraphSignal) error { func (thread *SimpleThread) Signal(context *StateContext, princ Node, signal GraphSignal) error {
err := thread.SimpleLockable.Signal(context, signal) err := thread.SimpleLockable.Signal(context, princ, signal)
if err != nil { if err != nil {
return err return err
} }
@ -20,20 +20,15 @@ func (thread *SimpleThread) Signal(context *StateContext, signal GraphSignal) er
case Up: case Up:
err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error { err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error {
if thread.parent != nil { if thread.parent != nil {
return UseStates(context, thread, NewLockInfo(thread.parent, []string{"signal"}), func(context *StateContext) error { return thread.parent.Signal(context, thread, signal)
return thread.parent.Signal(context, signal)
})
} else { } else {
return nil return nil
} }
}) })
case Down: case Down:
err = UseStates(context, thread, NewLockMap( err = UseStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error {
NewLockInfo(thread, []string{"children"}),
LockList(thread.children, []string{"signal"}),
), func(context *StateContext) error {
for _, child := range(thread.children) { for _, child := range(thread.children) {
err := child.Signal(context, signal) err := child.Signal(context, thread, signal)
if err != nil { if err != nil {
return err return err
} }
@ -690,9 +685,7 @@ var ThreadAbortedError = errors.New("Thread aborted by signal")
// Default thread action function for "abort", sends a signal and returns a ThreadAbortedError // Default thread action function for "abort", sends a signal and returns a ThreadAbortedError
func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) { func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
context := NewReadContext(ctx) context := NewReadContext(ctx)
err := UseStates(context, thread, NewLockInfo(thread, []string{"signal"}), func(context *StateContext) error { err := thread.Signal(context, thread, NewStatusSignal("aborted", thread.ID()))
return thread.Signal(context, NewStatusSignal("aborted", thread.ID()))
})
if err != nil { if err != nil {
return "", err return "", err
} }
@ -702,9 +695,7 @@ func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, erro
// Default thread action for "stop", sends a signal and returns no error // Default thread action for "stop", sends a signal and returns no error
func ThreadStop(ctx * Context, thread Thread, signal GraphSignal) (string, error) { func ThreadStop(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
context := NewReadContext(ctx) context := NewReadContext(ctx)
err := UseStates(context, thread, NewLockInfo(thread, []string{"signal"}), func(context *StateContext) error { err := thread.Signal(context, thread, NewStatusSignal("stopped", thread.ID()))
return thread.Signal(context, NewSignal("stopped"))
})
return "finish", err return "finish", err
} }