Attempt to fix DependencyPolicy

graph-rework-2
noah metz 2023-07-24 01:41:47 -06:00
parent c64dd728ed
commit fa10ccd743
7 changed files with 37 additions and 29 deletions

@ -895,7 +895,7 @@ var gql_actions ThreadActions = ThreadActions{
} }
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = server.Signal(context, server, NewStatusSignal("server_started", server.ID())) err = Signal(context, server, server, NewStatusSignal("server_started", server.ID()))
if err != nil { if err != nil {
return "", err return "", err
} }

@ -32,7 +32,7 @@ var GQLMutationAbort = NewField(func()*graphql.Field {
if node == nil { if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", id) return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
} }
return node.Signal(context, ctx.User, AbortSignal) return Signal(context, node, ctx.User, AbortSignal)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -92,7 +92,7 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{
} }
signal = NewStartChildSignal(child_id, action) signal = NewStartChildSignal(child_id, action)
return node.Signal(context, ctx.User, signal) return Signal(context, node, ctx.User, signal)
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -92,10 +92,10 @@ func TestGQLDBLoad(t * testing.T) {
fatalErr(t, err) fatalErr(t, err)
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = gql.Signal(context, gql, NewStatusSignal("child_linked", t1.ID())) err = Signal(context, gql, gql, NewStatusSignal("child_linked", t1.ID()))
fatalErr(t, err) fatalErr(t, err)
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = gql.Signal(context, gql, AbortSignal) err = Signal(context, gql, gql, AbortSignal)
fatalErr(t, err) fatalErr(t, err)
err = ThreadLoop(ctx, gql, "start") err = ThreadLoop(ctx, gql, "start")
@ -132,7 +132,7 @@ func TestGQLDBLoad(t * testing.T) {
ctx.Log.Logf("test", "\n%s\n\n", ser) ctx.Log.Logf("test", "\n%s\n\n", ser)
return err return err
}) })
gql_loaded.Signal(context, gql_loaded, StopSignal) Signal(context, gql_loaded, gql_loaded, StopSignal)
return err return err
}) })
@ -178,7 +178,7 @@ func TestGQLAuth(t * testing.T) {
ctx.Log.Logf("test", "DONE") ctx.Log.Logf("test", "DONE")
} }
context := NewReadContext(ctx) context := NewReadContext(ctx)
err := thread.Signal(context, thread, StopSignal) err := Signal(context, thread, thread, StopSignal)
fatalErr(t, err) fatalErr(t, err)
}(done, gql_t) }(done, gql_t)

@ -216,8 +216,8 @@ func (lockable * SimpleLockable) CanUnlock(new_owner Lockable) error {
} }
// Assumed that lockable is already locked for signal // Assumed that lockable is already locked for signal
func (lockable * SimpleLockable) Signal(context *StateContext, princ Node, signal GraphSignal) error { func (lockable * SimpleLockable) Process(context *StateContext, princ Node, signal GraphSignal) error {
err := lockable.GraphNode.Signal(context, princ, signal) err := lockable.GraphNode.Process(context, princ, signal)
if err != nil { if err != nil {
return err return err
} }
@ -229,7 +229,7 @@ func (lockable * SimpleLockable) Signal(context *StateContext, princ Node, signa
owner_sent := false owner_sent := false
for _, dependency := range(lockable.dependencies) { for _, dependency := range(lockable.dependencies) {
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID()) context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID())
dependency.Signal(context, lockable, signal) Signal(context, dependency, lockable, signal)
if lockable.owner != nil { if lockable.owner != nil {
if dependency.ID() == lockable.owner.ID() { if dependency.ID() == lockable.owner.ID() {
owner_sent = true owner_sent = true
@ -239,7 +239,7 @@ func (lockable * SimpleLockable) Signal(context *StateContext, princ Node, signa
if lockable.owner != nil && owner_sent == false { if lockable.owner != nil && owner_sent == false {
if lockable.owner.ID() != lockable.ID() { if lockable.owner.ID() != lockable.ID() {
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID()) context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID())
return lockable.owner.Signal(context, lockable, signal) return Signal(context, lockable.owner, lockable, signal)
} }
} }
return nil return nil
@ -247,7 +247,7 @@ func (lockable * SimpleLockable) Signal(context *StateContext, princ Node, signa
case Down: case Down:
err = UseStates(context, lockable, NewLockInfo(lockable, []string{"requirements"}), func(context *StateContext) error { err = UseStates(context, lockable, NewLockInfo(lockable, []string{"requirements"}), func(context *StateContext) error {
for _, requirement := range(lockable.requirements) { for _, requirement := range(lockable.requirements) {
err := requirement.Signal(context, lockable, signal) err := Signal(context, requirement, lockable, signal)
if err != nil { if err != nil {
return err return err
} }

@ -69,12 +69,12 @@ type Node interface {
ID() NodeID ID() NodeID
Type() NodeType Type() NodeType
Allowed(context *StateContext, action string, resource string, principal Node) error Policies() map[NodeID]Policy
AddPolicy(Policy) error AddPolicy(Policy) error
RemovePolicy(Policy) error RemovePolicy(Policy) error
// Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate // Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate
Signal(context *StateContext, princ Node, signal GraphSignal) error Process(context *StateContext, princ Node, signal GraphSignal) error
// Register a channel to receive updates sent to the node // Register a channel to receive updates sent to the node
RegisterChannel(id NodeID, listener chan GraphSignal) RegisterChannel(id NodeID, listener chan GraphSignal)
// Unregister a channel from receiving updates sent to the node // Unregister a channel from receiving updates sent to the node
@ -95,12 +95,16 @@ type GraphNodeJSON struct {
Policies []string `json:"policies"` Policies []string `json:"policies"`
} }
func (node * GraphNode) Policies() map[NodeID]Policy {
return node.policies
}
func (node * GraphNode) Serialize() ([]byte, error) { func (node * GraphNode) Serialize() ([]byte, error) {
node_json := NewGraphNodeJSON(node) node_json := NewGraphNodeJSON(node)
return json.MarshalIndent(&node_json, "", " ") return json.MarshalIndent(&node_json, "", " ")
} }
func (node *GraphNode) Allowed(context *StateContext, resource string, action string, princ Node) error { func Allowed(context *StateContext, policies map[NodeID]Policy, node Node, resource string, action string, princ Node) error {
if princ == nil { if princ == nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return fmt.Errorf("nil is not allowed to perform any actions") return fmt.Errorf("nil is not allowed to perform any actions")
@ -108,7 +112,7 @@ func (node *GraphNode) Allowed(context *StateContext, resource string, action st
if node.ID() == princ.ID() { if node.ID() == princ.ID() {
return nil return nil
} }
for _, policy := range(node.policies) { for _, policy := range(policies) {
if policy.Allows(node, resource, action, princ) == true { if policy.Allows(node, resource, action, princ) == true {
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return nil return nil
@ -196,17 +200,21 @@ func (node * GraphNode) Type() NodeType {
// Propagate the signal to registered listeners, if a listener isn't ready to receive the update // Propagate the signal to registered listeners, if a listener isn't ready to receive the update
// send it a notification that it was closed and then close it // send it a notification that it was closed and then close it
func (node * GraphNode) Signal(context *StateContext, princ Node, signal GraphSignal) error { func Signal(context *StateContext, node Node, princ Node, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String()) context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String())
err := UseStates(context, princ, NewLockInfo(princ, nil), func(context *StateContext) error { err := UseStates(context, princ, NewLockInfo(node, []string{"policies"}), func(context *StateContext) error {
return node.Allowed(context, "signal", signal.Type(), princ) return Allowed(context, node.Policies(), node, "signal", signal.Type(), princ)
}) })
if err != nil { if err != nil {
return nil return nil
} }
return node.Process(context, princ, signal)
}
func (node * GraphNode) Process(context *StateContext, princ Node, signal GraphSignal) error {
node.listeners_lock.Lock() node.listeners_lock.Lock()
defer node.listeners_lock.Unlock() defer node.listeners_lock.Unlock()
closed := []NodeID{} closed := []NodeID{}
@ -598,7 +606,7 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
} }
if already_granted == false { if already_granted == false {
err := node.Allowed(context, resource, "read", princ) err := Allowed(context, node.Policies(), node, resource, "read", princ)
if err != nil { if err != nil {
for _, n := range(new_locks) { for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String()) context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String())
@ -694,7 +702,7 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
} }
if already_granted == false { if already_granted == false {
err := node.Allowed(context, resource, "write", princ) err := Allowed(context, node.Policies(), node, resource, "write", princ)
if err != nil { if err != nil {
for _, n := range(new_locks) { for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String()) context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String())

@ -273,8 +273,8 @@ func (policy *DependencyPolicy) Allows(node Node, resource string, action string
return false return false
} }
for _, req := range(lockable.Dependencies()) { for _, dep := range(lockable.Dependencies()) {
if req.ID() == principal.ID() { if dep.ID() == principal.ID() {
return policy.Actions.Allows(resource, action) return policy.Actions.Allows(resource, action)
} }
} }

@ -10,8 +10,8 @@ import (
) )
// Assumed that thread is already locked for signal // Assumed that thread is already locked for signal
func (thread *SimpleThread) Signal(context *StateContext, princ Node, signal GraphSignal) error { func (thread *SimpleThread) Process(context *StateContext, princ Node, signal GraphSignal) error {
err := thread.SimpleLockable.Signal(context, princ, signal) err := thread.SimpleLockable.Process(context, princ, signal)
if err != nil { if err != nil {
return err return err
} }
@ -20,7 +20,7 @@ func (thread *SimpleThread) Signal(context *StateContext, princ Node, signal Gra
case Up: case Up:
err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error { err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error {
if thread.parent != nil { if thread.parent != nil {
return thread.parent.Signal(context, thread, signal) return Signal(context, thread.parent, thread, signal)
} else { } else {
return nil return nil
} }
@ -28,7 +28,7 @@ func (thread *SimpleThread) Signal(context *StateContext, princ Node, signal Gra
case Down: case Down:
err = UseStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error { err = UseStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error {
for _, child := range(thread.children) { for _, child := range(thread.children) {
err := child.Signal(context, thread, signal) err := Signal(context, child, thread, signal)
if err != nil { if err != nil {
return err return err
} }
@ -741,7 +741,7 @@ var ThreadAbortedError = errors.New("Thread aborted by signal")
// Default thread action function for "abort", sends a signal and returns a ThreadAbortedError // Default thread action function for "abort", sends a signal and returns a ThreadAbortedError
func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) { func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
context := NewReadContext(ctx) context := NewReadContext(ctx)
err := thread.Signal(context, thread, NewStatusSignal("aborted", thread.ID())) err := Signal(context, thread, thread, NewStatusSignal("aborted", thread.ID()))
if err != nil { if err != nil {
return "", err return "", err
} }
@ -751,7 +751,7 @@ func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, erro
// Default thread action for "stop", sends a signal and returns no error // Default thread action for "stop", sends a signal and returns no error
func ThreadStop(ctx * Context, thread Thread, signal GraphSignal) (string, error) { func ThreadStop(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
context := NewReadContext(ctx) context := NewReadContext(ctx)
err := thread.Signal(context, thread, NewStatusSignal("stopped", thread.ID())) err := Signal(context, thread, thread, NewStatusSignal("stopped", thread.ID()))
return "finish", err return "finish", err
} }