diff --git a/gql.go b/gql.go index 5326aea..244802e 100644 --- a/gql.go +++ b/gql.go @@ -906,12 +906,7 @@ var gql_actions ThreadActions = ThreadActions{ } context = NewReadContext(ctx) - err = UseStates(context, server, NewLockMap( - NewLockInfo(server, []string{"signal"}), - ), func(context *StateContext) error { - return server.Signal(context, NewSignal("server_started")) - }) - + err = server.Signal(context, server, NewStatusSignal("server_started", server.ID())) if err != nil { return "", err } diff --git a/gql_mutation.go b/gql_mutation.go index c39cff6..17a1453 100644 --- a/gql_mutation.go +++ b/gql_mutation.go @@ -18,11 +18,6 @@ var GQLMutationAbort = NewField(func()*graphql.Field { return nil, err } - err = ctx.Server.Allowed("signal", "", ctx.User) - if err != nil { - return nil, err - } - id, err := ExtractID(p, "id") if err != nil { return nil, err @@ -37,9 +32,7 @@ var GQLMutationAbort = NewField(func()*graphql.Field { if node == nil { return fmt.Errorf("Failed to find ID: %s as child of server thread", id) } - return UseStates(context, ctx.User, NewLockInfo(node, []string{"signal"}), func(context *StateContext) error { - return node.Signal(context, AbortSignal) - }) + return node.Signal(context, ctx.User, AbortSignal) }) if err != nil { return nil, err @@ -98,15 +91,8 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{ return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id) } - err := node.Allowed("signal", "", ctx.User) - if err != nil { - return err - } - - return UseStates(context, ctx.User, NewLockMap(NewLockInfo(node, []string{"start_child", "signal"})), func(context *StateContext) error { - signal = NewStartChildSignal(child_id, action) - return node.Signal(context, signal) - }) + signal = NewStartChildSignal(child_id, action) + return node.Signal(context, ctx.User, signal) }) if err != nil { return nil, err diff --git a/gql_query.go b/gql_query.go index 4593e10..f856e10 100644 --- a/gql_query.go +++ b/gql_query.go @@ -11,11 +11,6 @@ var GQLQuerySelf = &graphql.Field{ return nil, err } - err = ctx.Server.Allowed("read", "", ctx.User) - if err != nil { - return nil, err - } - return ctx.Server, nil }, } @@ -28,11 +23,6 @@ var GQLQueryUser = &graphql.Field{ return nil, err } - err = ctx.User.Allowed("read", "", ctx.User) - if err != nil { - return nil, err - } - return ctx.User, nil }, } diff --git a/gql_test.go b/gql_test.go index 658dd9d..db91bdf 100644 --- a/gql_test.go +++ b/gql_test.go @@ -18,7 +18,7 @@ import ( ) func TestGQLDBLoad(t * testing.T) { - ctx := logTestContext(t, []string{"test", "signal", "thread"}) + ctx := logTestContext(t, []string{"test", "signal", "policy", "thread"}) l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") l1 := &l1_r ctx.Log.Logf("test", "L1_ID: %s", l1.ID().String()) @@ -91,19 +91,16 @@ func TestGQLDBLoad(t * testing.T) { fatalErr(t, err) context = NewReadContext(ctx) - err = UseStates(context, gql, NewLockInfo(gql, []string{"signal"}), func(context *StateContext) error { - err := gql.Signal(context, NewStatusSignal("child_linked", t1.ID())) - if err != nil { - return nil - } - return gql.Signal(context, StopSignal) - }) + err = gql.Signal(context, gql, NewStatusSignal("child_linked", t1.ID())) + fatalErr(t, err) + context = NewReadContext(ctx) + err = gql.Signal(context, gql, StopSignal) fatalErr(t, err) err = ThreadLoop(ctx, gql, "start") fatalErr(t, err) - (*GraphTester)(t).WaitForValue(ctx, update_channel, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel") + (*GraphTester)(t).WaitForStatus(ctx, update_channel, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel") context = NewReadContext(ctx) err = UseStates(context, gql, LockList([]Node{gql, u1}, nil), func(context *StateContext) error { @@ -132,13 +129,13 @@ func TestGQLDBLoad(t * testing.T) { ctx.Log.Logf("test", "\n%s\n\n", ser) return err }) - gql_loaded.Signal(context, StopSignal) + gql_loaded.Signal(context, gql_loaded, StopSignal) return err }) err = ThreadLoop(ctx, gql_loaded.(Thread), "start") fatalErr(t, err) - (*GraphTester)(t).WaitForValue(ctx, update_channel_2, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2") + (*GraphTester)(t).WaitForStatus(ctx, update_channel_2, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2") } @@ -178,14 +175,12 @@ func TestGQLAuth(t * testing.T) { ctx.Log.Logf("test", "DONE") } context := NewReadContext(ctx) - err := UseStates(context, gql_t, NewLockInfo(gql_t, []string{"signal}"}), func(context *StateContext) error { - return thread.Signal(context, StopSignal) - }) + err := thread.Signal(context, thread, StopSignal) fatalErr(t, err) }(done, gql_t) go func(thread Thread){ - (*GraphTester)(t).WaitForValue(ctx, update_channel, "server_started", 100*time.Millisecond, "Server didn't start") + (*GraphTester)(t).WaitForStatus(ctx, update_channel, "server_started", 100*time.Millisecond, "Server didn't start") port := gql_t.tcp_listener.Addr().(*net.TCPAddr).Port ctx.Log.Logf("test", "GQL_PORT: %d", port) diff --git a/graph_test.go b/graph_test.go index 6f440e0..e222dec 100644 --- a/graph_test.go +++ b/graph_test.go @@ -13,7 +13,7 @@ import ( type GraphTester testing.T const listner_timeout = 50 * time.Millisecond -func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, signal_type string, timeout time.Duration, str string) GraphSignal { +func (t * GraphTester) WaitForStatus(ctx * Context, listener chan GraphSignal, status string, timeout time.Duration, str string) GraphSignal { timeout_channel := time.After(timeout) for true { select { @@ -22,8 +22,16 @@ func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, si ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener) t.Fatal(str) } - if signal.Type() == signal_type { - return signal + if signal.Type() == "status" { + sig, ok := signal.(StatusSignal) + if ok == true { + if sig.Status == status { + return signal + } + ctx.Log.Logf("test", "Different status received: %s", sig.Status) + } else { + ctx.Log.Logf("test", "Failed to cast status to StatusSignal: %+v", signal) + } } case <-timeout_channel: pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) diff --git a/lockable.go b/lockable.go index b5d73de..7d98bda 100644 --- a/lockable.go +++ b/lockable.go @@ -216,22 +216,20 @@ func (lockable * SimpleLockable) CanUnlock(new_owner Lockable) error { } // Assumed that lockable is already locked for signal -func (lockable * SimpleLockable) Signal(context *StateContext, signal GraphSignal) error { - err := lockable.GraphNode.Signal(context, signal) +func (lockable * SimpleLockable) Signal(context *StateContext, princ Node, signal GraphSignal) error { + err := lockable.GraphNode.Signal(context, princ, signal) if err != nil { return err } switch signal.Direction() { case Up: - err = UseStates(context, lockable, NewLockMap( - NewLockInfo(lockable, []string{"dependencies", "owner"}), - LockList(lockable.requirements, []string{"signal"}), - ), func(context *StateContext) error { + err = UseStates(context, lockable, + NewLockInfo(lockable, []string{"dependencies", "owner"}), func(context *StateContext) error { owner_sent := false for _, dependency := range(lockable.dependencies) { context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID()) - dependency.Signal(context, signal) + dependency.Signal(context, lockable, signal) if lockable.owner != nil { if dependency.ID() == lockable.owner.ID() { owner_sent = true @@ -241,20 +239,15 @@ func (lockable * SimpleLockable) Signal(context *StateContext, signal GraphSigna if lockable.owner != nil && owner_sent == false { if lockable.owner.ID() != lockable.ID() { context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID()) - return UseStates(context, lockable, NewLockMap(NewLockInfo(lockable.owner, []string{"signal"})), func(context *StateContext) error { - return lockable.owner.Signal(context, signal) - }) + return lockable.owner.Signal(context, lockable, signal) } } return nil }) case Down: - err = UseStates(context, lockable, NewLockMap( - NewLockInfo(lockable, []string{"requirements"}), - LockList(lockable.requirements, []string{"signal"}), - ), func(context *StateContext) error { + err = UseStates(context, lockable, NewLockInfo(lockable, []string{"requirements"}), func(context *StateContext) error { for _, requirement := range(lockable.requirements) { - err := requirement.Signal(context, signal) + err := requirement.Signal(context, lockable, signal) if err != nil { return err } diff --git a/node.go b/node.go index 35a8977..0a54d14 100644 --- a/node.go +++ b/node.go @@ -69,12 +69,12 @@ type Node interface { ID() NodeID Type() NodeType - Allowed(action string, resource string, principal Node) error + Allowed(context *StateContext, action string, resource string, principal Node) error AddPolicy(Policy) error RemovePolicy(Policy) error // Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate - Signal(context *StateContext, signal GraphSignal) error + Signal(context *StateContext, princ Node, signal GraphSignal) error // Register a channel to receive updates sent to the node RegisterChannel(id NodeID, listener chan GraphSignal) // Unregister a channel from receiving updates sent to the node @@ -100,19 +100,22 @@ func (node * GraphNode) Serialize() ([]byte, error) { return json.MarshalIndent(&node_json, "", " ") } -func (node *GraphNode) Allowed(action string, resource string, principal Node) error { - if principal == nil { +func (node *GraphNode) Allowed(context *StateContext, resource string, action string, princ Node) error { + if princ == nil { + context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) return fmt.Errorf("nil is not allowed to perform any actions") } - if node.ID() == principal.ID() { + if node.ID() == princ.ID() { return nil } for _, policy := range(node.policies) { - if policy.Allows(resource, action, principal) == true { + if policy.Allows(resource, action, princ) == true { + context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) return nil } } - return fmt.Errorf("%s is not allowed to perform %s.%s on %s", principal.ID().String(), resource, action, node.ID().String()) + context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) + return fmt.Errorf("%s is not allowed to perform %s.%s on %s", princ.ID(), resource, action, node.ID()) } func (node *GraphNode) AddPolicy(policy Policy) error { @@ -193,8 +196,17 @@ func (node * GraphNode) Type() NodeType { // Propagate the signal to registered listeners, if a listener isn't ready to receive the update // send it a notification that it was closed and then close it -func (node * GraphNode) Signal(context *StateContext, signal GraphSignal) error { +func (node * GraphNode) Signal(context *StateContext, princ Node, signal GraphSignal) error { context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String()) + + err := UseStates(context, princ, NewLockInfo(princ, nil), func(context *StateContext) error { + return node.Allowed(context, "signal", signal.Type(), princ) + }) + + if err != nil { + return nil + } + node.listeners_lock.Lock() defer node.listeners_lock.Unlock() closed := []NodeID{} @@ -536,10 +548,8 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St return err } - final := false if context.Started == false { context.Started = true - final = true } new_locks := []Node{} @@ -588,18 +598,14 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St } if already_granted == false { - err := node.Allowed("read", resource, princ) + err := node.Allowed(context, resource, "read", princ) if err != nil { - context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.read", princ.ID().String(), id.String(), resource) for _, n := range(new_locks) { context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String()) n.RUnlock() } return err } - context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.read", princ.ID().String(), id.String(), resource) - } else { - context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.%s.read", princ.ID().String(), id.String(), resource) } } new_permissions[id] = node_permissions @@ -621,10 +627,6 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St node.RUnlock() } - if final == true { - context.Finished = true - } - return err } @@ -692,18 +694,14 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn } if already_granted == false { - err := node.Allowed("write", resource, princ) + err := node.Allowed(context, resource, "write", princ) if err != nil { - context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.write", princ.ID().String(), id.String(), resource) for _, n := range(new_locks) { context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String()) n.Unlock() } return err } - context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.write", princ.ID().String(), id.String(), resource) - } else { - context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.%s.write", princ.ID().String(), id.String(), resource) } } new_permissions[id] = node_permissions diff --git a/signal.go b/signal.go index 81446c1..133c3bd 100644 --- a/signal.go +++ b/signal.go @@ -76,8 +76,16 @@ func NewIDSignal(_type string, direction SignalDirection, id NodeID) IDSignal { } } -func NewStatusSignal(_type string, source NodeID) IDSignal { - return NewIDSignal(_type, Up, source) +type StatusSignal struct { + IDSignal + Status string +} + +func NewStatusSignal(status string, source NodeID) StatusSignal { + return StatusSignal{ + IDSignal: NewIDSignal("status", Up, source), + Status: status, + } } type StartChildSignal struct { diff --git a/thread.go b/thread.go index a0688d1..e4d81cf 100644 --- a/thread.go +++ b/thread.go @@ -10,8 +10,8 @@ import ( ) // Assumed that thread is already locked for signal -func (thread *SimpleThread) Signal(context *StateContext, signal GraphSignal) error { - err := thread.SimpleLockable.Signal(context, signal) +func (thread *SimpleThread) Signal(context *StateContext, princ Node, signal GraphSignal) error { + err := thread.SimpleLockable.Signal(context, princ, signal) if err != nil { return err } @@ -20,20 +20,15 @@ func (thread *SimpleThread) Signal(context *StateContext, signal GraphSignal) er case Up: err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error { if thread.parent != nil { - return UseStates(context, thread, NewLockInfo(thread.parent, []string{"signal"}), func(context *StateContext) error { - return thread.parent.Signal(context, signal) - }) + return thread.parent.Signal(context, thread, signal) } else { return nil } }) case Down: - err = UseStates(context, thread, NewLockMap( - NewLockInfo(thread, []string{"children"}), - LockList(thread.children, []string{"signal"}), - ), func(context *StateContext) error { + err = UseStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error { for _, child := range(thread.children) { - err := child.Signal(context, signal) + err := child.Signal(context, thread, signal) if err != nil { return err } @@ -690,9 +685,7 @@ var ThreadAbortedError = errors.New("Thread aborted by signal") // Default thread action function for "abort", sends a signal and returns a ThreadAbortedError func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) { context := NewReadContext(ctx) - err := UseStates(context, thread, NewLockInfo(thread, []string{"signal"}), func(context *StateContext) error { - return thread.Signal(context, NewStatusSignal("aborted", thread.ID())) - }) + err := thread.Signal(context, thread, NewStatusSignal("aborted", thread.ID())) if err != nil { return "", err } @@ -702,9 +695,7 @@ func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, erro // Default thread action for "stop", sends a signal and returns no error func ThreadStop(ctx * Context, thread Thread, signal GraphSignal) (string, error) { context := NewReadContext(ctx) - err := UseStates(context, thread, NewLockInfo(thread, []string{"signal"}), func(context *StateContext) error { - return thread.Signal(context, NewSignal("stopped")) - }) + err := thread.Signal(context, thread, NewStatusSignal("stopped", thread.ID())) return "finish", err }