diff --git a/gql.go b/gql.go index c8e7c3e..8faa230 100644 --- a/gql.go +++ b/gql.go @@ -895,7 +895,7 @@ var gql_actions ThreadActions = ThreadActions{ } context = NewReadContext(ctx) - err = server.Signal(context, server, NewStatusSignal("server_started", server.ID())) + err = Signal(context, server, server, NewStatusSignal("server_started", server.ID())) if err != nil { return "", err } diff --git a/gql_mutation.go b/gql_mutation.go index 17a1453..ebbcdc8 100644 --- a/gql_mutation.go +++ b/gql_mutation.go @@ -32,7 +32,7 @@ var GQLMutationAbort = NewField(func()*graphql.Field { if node == nil { return fmt.Errorf("Failed to find ID: %s as child of server thread", id) } - return node.Signal(context, ctx.User, AbortSignal) + return Signal(context, node, ctx.User, AbortSignal) }) if err != nil { return nil, err @@ -92,7 +92,7 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{ } signal = NewStartChildSignal(child_id, action) - return node.Signal(context, ctx.User, signal) + return Signal(context, node, ctx.User, signal) }) if err != nil { return nil, err diff --git a/gql_test.go b/gql_test.go index 447a34a..877637a 100644 --- a/gql_test.go +++ b/gql_test.go @@ -92,10 +92,10 @@ func TestGQLDBLoad(t * testing.T) { fatalErr(t, err) context = NewReadContext(ctx) - err = gql.Signal(context, gql, NewStatusSignal("child_linked", t1.ID())) + err = Signal(context, gql, gql, NewStatusSignal("child_linked", t1.ID())) fatalErr(t, err) context = NewReadContext(ctx) - err = gql.Signal(context, gql, AbortSignal) + err = Signal(context, gql, gql, AbortSignal) fatalErr(t, err) err = ThreadLoop(ctx, gql, "start") @@ -132,7 +132,7 @@ func TestGQLDBLoad(t * testing.T) { ctx.Log.Logf("test", "\n%s\n\n", ser) return err }) - gql_loaded.Signal(context, gql_loaded, StopSignal) + Signal(context, gql_loaded, gql_loaded, StopSignal) return err }) @@ -178,7 +178,7 @@ func TestGQLAuth(t * testing.T) { ctx.Log.Logf("test", "DONE") } context := NewReadContext(ctx) - err := thread.Signal(context, thread, StopSignal) + err := Signal(context, thread, thread, StopSignal) fatalErr(t, err) }(done, gql_t) diff --git a/lockable.go b/lockable.go index b192b00..09c3beb 100644 --- a/lockable.go +++ b/lockable.go @@ -216,8 +216,8 @@ func (lockable * SimpleLockable) CanUnlock(new_owner Lockable) error { } // Assumed that lockable is already locked for signal -func (lockable * SimpleLockable) Signal(context *StateContext, princ Node, signal GraphSignal) error { - err := lockable.GraphNode.Signal(context, princ, signal) +func (lockable * SimpleLockable) Process(context *StateContext, princ Node, signal GraphSignal) error { + err := lockable.GraphNode.Process(context, princ, signal) if err != nil { return err } @@ -229,7 +229,7 @@ func (lockable * SimpleLockable) Signal(context *StateContext, princ Node, signa owner_sent := false for _, dependency := range(lockable.dependencies) { context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID()) - dependency.Signal(context, lockable, signal) + Signal(context, dependency, lockable, signal) if lockable.owner != nil { if dependency.ID() == lockable.owner.ID() { owner_sent = true @@ -239,7 +239,7 @@ func (lockable * SimpleLockable) Signal(context *StateContext, princ Node, signa if lockable.owner != nil && owner_sent == false { if lockable.owner.ID() != lockable.ID() { context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID()) - return lockable.owner.Signal(context, lockable, signal) + return Signal(context, lockable.owner, lockable, signal) } } return nil @@ -247,7 +247,7 @@ func (lockable * SimpleLockable) Signal(context *StateContext, princ Node, signa case Down: err = UseStates(context, lockable, NewLockInfo(lockable, []string{"requirements"}), func(context *StateContext) error { for _, requirement := range(lockable.requirements) { - err := requirement.Signal(context, lockable, signal) + err := Signal(context, requirement, lockable, signal) if err != nil { return err } diff --git a/node.go b/node.go index 263999c..d7a4dbc 100644 --- a/node.go +++ b/node.go @@ -69,12 +69,12 @@ type Node interface { ID() NodeID Type() NodeType - Allowed(context *StateContext, action string, resource string, principal Node) error + Policies() map[NodeID]Policy AddPolicy(Policy) error RemovePolicy(Policy) error // Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate - Signal(context *StateContext, princ Node, signal GraphSignal) error + Process(context *StateContext, princ Node, signal GraphSignal) error // Register a channel to receive updates sent to the node RegisterChannel(id NodeID, listener chan GraphSignal) // Unregister a channel from receiving updates sent to the node @@ -95,12 +95,16 @@ type GraphNodeJSON struct { Policies []string `json:"policies"` } +func (node * GraphNode) Policies() map[NodeID]Policy { + return node.policies +} + func (node * GraphNode) Serialize() ([]byte, error) { node_json := NewGraphNodeJSON(node) return json.MarshalIndent(&node_json, "", " ") } -func (node *GraphNode) Allowed(context *StateContext, resource string, action string, princ Node) error { +func Allowed(context *StateContext, policies map[NodeID]Policy, node Node, resource string, action string, princ Node) error { if princ == nil { context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) return fmt.Errorf("nil is not allowed to perform any actions") @@ -108,7 +112,7 @@ func (node *GraphNode) Allowed(context *StateContext, resource string, action st if node.ID() == princ.ID() { return nil } - for _, policy := range(node.policies) { + for _, policy := range(policies) { if policy.Allows(node, resource, action, princ) == true { context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) return nil @@ -196,17 +200,21 @@ func (node * GraphNode) Type() NodeType { // Propagate the signal to registered listeners, if a listener isn't ready to receive the update // send it a notification that it was closed and then close it -func (node * GraphNode) Signal(context *StateContext, princ Node, signal GraphSignal) error { +func Signal(context *StateContext, node Node, princ Node, signal GraphSignal) error { context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String()) - err := UseStates(context, princ, NewLockInfo(princ, nil), func(context *StateContext) error { - return node.Allowed(context, "signal", signal.Type(), princ) + err := UseStates(context, princ, NewLockInfo(node, []string{"policies"}), func(context *StateContext) error { + return Allowed(context, node.Policies(), node, "signal", signal.Type(), princ) }) if err != nil { return nil } + return node.Process(context, princ, signal) +} + +func (node * GraphNode) Process(context *StateContext, princ Node, signal GraphSignal) error { node.listeners_lock.Lock() defer node.listeners_lock.Unlock() closed := []NodeID{} @@ -598,7 +606,7 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St } if already_granted == false { - err := node.Allowed(context, resource, "read", princ) + err := Allowed(context, node.Policies(), node, resource, "read", princ) if err != nil { for _, n := range(new_locks) { context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String()) @@ -694,7 +702,7 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn } if already_granted == false { - err := node.Allowed(context, resource, "write", princ) + err := Allowed(context, node.Policies(), node, resource, "write", princ) if err != nil { for _, n := range(new_locks) { context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String()) diff --git a/policy.go b/policy.go index 42a5360..ba33020 100644 --- a/policy.go +++ b/policy.go @@ -273,8 +273,8 @@ func (policy *DependencyPolicy) Allows(node Node, resource string, action string return false } - for _, req := range(lockable.Dependencies()) { - if req.ID() == principal.ID() { + for _, dep := range(lockable.Dependencies()) { + if dep.ID() == principal.ID() { return policy.Actions.Allows(resource, action) } } diff --git a/thread.go b/thread.go index 59d4608..71ce97d 100644 --- a/thread.go +++ b/thread.go @@ -10,8 +10,8 @@ import ( ) // Assumed that thread is already locked for signal -func (thread *SimpleThread) Signal(context *StateContext, princ Node, signal GraphSignal) error { - err := thread.SimpleLockable.Signal(context, princ, signal) +func (thread *SimpleThread) Process(context *StateContext, princ Node, signal GraphSignal) error { + err := thread.SimpleLockable.Process(context, princ, signal) if err != nil { return err } @@ -20,7 +20,7 @@ func (thread *SimpleThread) Signal(context *StateContext, princ Node, signal Gra case Up: err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error { if thread.parent != nil { - return thread.parent.Signal(context, thread, signal) + return Signal(context, thread.parent, thread, signal) } else { return nil } @@ -28,7 +28,7 @@ func (thread *SimpleThread) Signal(context *StateContext, princ Node, signal Gra case Down: err = UseStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error { for _, child := range(thread.children) { - err := child.Signal(context, thread, signal) + err := Signal(context, child, thread, signal) if err != nil { return err } @@ -741,7 +741,7 @@ var ThreadAbortedError = errors.New("Thread aborted by signal") // Default thread action function for "abort", sends a signal and returns a ThreadAbortedError func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) { context := NewReadContext(ctx) - err := thread.Signal(context, thread, NewStatusSignal("aborted", thread.ID())) + err := Signal(context, thread, thread, NewStatusSignal("aborted", thread.ID())) if err != nil { return "", err } @@ -751,7 +751,7 @@ func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, erro // Default thread action for "stop", sends a signal and returns no error func ThreadStop(ctx * Context, thread Thread, signal GraphSignal) (string, error) { context := NewReadContext(ctx) - err := thread.Signal(context, thread, NewStatusSignal("stopped", thread.ID())) + err := Signal(context, thread, thread, NewStatusSignal("stopped", thread.ID())) return "finish", err }