Moved status signals to StatusSignal type with status string as type, so they can be ACLd

graph-rework-2
noah metz 2023-07-23 21:14:15 -06:00
parent d56245c5cf
commit dfa420757b
9 changed files with 72 additions and 108 deletions

@ -906,12 +906,7 @@ var gql_actions ThreadActions = ThreadActions{
}
context = NewReadContext(ctx)
err = UseStates(context, server, NewLockMap(
NewLockInfo(server, []string{"signal"}),
), func(context *StateContext) error {
return server.Signal(context, NewSignal("server_started"))
})
err = server.Signal(context, server, NewStatusSignal("server_started", server.ID()))
if err != nil {
return "", err
}

@ -18,11 +18,6 @@ var GQLMutationAbort = NewField(func()*graphql.Field {
return nil, err
}
err = ctx.Server.Allowed("signal", "", ctx.User)
if err != nil {
return nil, err
}
id, err := ExtractID(p, "id")
if err != nil {
return nil, err
@ -37,9 +32,7 @@ var GQLMutationAbort = NewField(func()*graphql.Field {
if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
}
return UseStates(context, ctx.User, NewLockInfo(node, []string{"signal"}), func(context *StateContext) error {
return node.Signal(context, AbortSignal)
})
return node.Signal(context, ctx.User, AbortSignal)
})
if err != nil {
return nil, err
@ -98,15 +91,8 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{
return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id)
}
err := node.Allowed("signal", "", ctx.User)
if err != nil {
return err
}
return UseStates(context, ctx.User, NewLockMap(NewLockInfo(node, []string{"start_child", "signal"})), func(context *StateContext) error {
signal = NewStartChildSignal(child_id, action)
return node.Signal(context, signal)
})
signal = NewStartChildSignal(child_id, action)
return node.Signal(context, ctx.User, signal)
})
if err != nil {
return nil, err

@ -11,11 +11,6 @@ var GQLQuerySelf = &graphql.Field{
return nil, err
}
err = ctx.Server.Allowed("read", "", ctx.User)
if err != nil {
return nil, err
}
return ctx.Server, nil
},
}
@ -28,11 +23,6 @@ var GQLQueryUser = &graphql.Field{
return nil, err
}
err = ctx.User.Allowed("read", "", ctx.User)
if err != nil {
return nil, err
}
return ctx.User, nil
},
}

@ -18,7 +18,7 @@ import (
)
func TestGQLDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{"test", "signal", "thread"})
ctx := logTestContext(t, []string{"test", "signal", "policy", "thread"})
l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
l1 := &l1_r
ctx.Log.Logf("test", "L1_ID: %s", l1.ID().String())
@ -91,19 +91,16 @@ func TestGQLDBLoad(t * testing.T) {
fatalErr(t, err)
context = NewReadContext(ctx)
err = UseStates(context, gql, NewLockInfo(gql, []string{"signal"}), func(context *StateContext) error {
err := gql.Signal(context, NewStatusSignal("child_linked", t1.ID()))
if err != nil {
return nil
}
return gql.Signal(context, StopSignal)
})
err = gql.Signal(context, gql, NewStatusSignal("child_linked", t1.ID()))
fatalErr(t, err)
context = NewReadContext(ctx)
err = gql.Signal(context, gql, StopSignal)
fatalErr(t, err)
err = ThreadLoop(ctx, gql, "start")
fatalErr(t, err)
(*GraphTester)(t).WaitForValue(ctx, update_channel, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel")
(*GraphTester)(t).WaitForStatus(ctx, update_channel, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel")
context = NewReadContext(ctx)
err = UseStates(context, gql, LockList([]Node{gql, u1}, nil), func(context *StateContext) error {
@ -132,13 +129,13 @@ func TestGQLDBLoad(t * testing.T) {
ctx.Log.Logf("test", "\n%s\n\n", ser)
return err
})
gql_loaded.Signal(context, StopSignal)
gql_loaded.Signal(context, gql_loaded, StopSignal)
return err
})
err = ThreadLoop(ctx, gql_loaded.(Thread), "start")
fatalErr(t, err)
(*GraphTester)(t).WaitForValue(ctx, update_channel_2, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2")
(*GraphTester)(t).WaitForStatus(ctx, update_channel_2, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2")
}
@ -178,14 +175,12 @@ func TestGQLAuth(t * testing.T) {
ctx.Log.Logf("test", "DONE")
}
context := NewReadContext(ctx)
err := UseStates(context, gql_t, NewLockInfo(gql_t, []string{"signal}"}), func(context *StateContext) error {
return thread.Signal(context, StopSignal)
})
err := thread.Signal(context, thread, StopSignal)
fatalErr(t, err)
}(done, gql_t)
go func(thread Thread){
(*GraphTester)(t).WaitForValue(ctx, update_channel, "server_started", 100*time.Millisecond, "Server didn't start")
(*GraphTester)(t).WaitForStatus(ctx, update_channel, "server_started", 100*time.Millisecond, "Server didn't start")
port := gql_t.tcp_listener.Addr().(*net.TCPAddr).Port
ctx.Log.Logf("test", "GQL_PORT: %d", port)

@ -13,7 +13,7 @@ import (
type GraphTester testing.T
const listner_timeout = 50 * time.Millisecond
func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, signal_type string, timeout time.Duration, str string) GraphSignal {
func (t * GraphTester) WaitForStatus(ctx * Context, listener chan GraphSignal, status string, timeout time.Duration, str string) GraphSignal {
timeout_channel := time.After(timeout)
for true {
select {
@ -22,8 +22,16 @@ func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, si
ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener)
t.Fatal(str)
}
if signal.Type() == signal_type {
return signal
if signal.Type() == "status" {
sig, ok := signal.(StatusSignal)
if ok == true {
if sig.Status == status {
return signal
}
ctx.Log.Logf("test", "Different status received: %s", sig.Status)
} else {
ctx.Log.Logf("test", "Failed to cast status to StatusSignal: %+v", signal)
}
}
case <-timeout_channel:
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)

@ -216,22 +216,20 @@ func (lockable * SimpleLockable) CanUnlock(new_owner Lockable) error {
}
// Assumed that lockable is already locked for signal
func (lockable * SimpleLockable) Signal(context *StateContext, signal GraphSignal) error {
err := lockable.GraphNode.Signal(context, signal)
func (lockable * SimpleLockable) Signal(context *StateContext, princ Node, signal GraphSignal) error {
err := lockable.GraphNode.Signal(context, princ, signal)
if err != nil {
return err
}
switch signal.Direction() {
case Up:
err = UseStates(context, lockable, NewLockMap(
NewLockInfo(lockable, []string{"dependencies", "owner"}),
LockList(lockable.requirements, []string{"signal"}),
), func(context *StateContext) error {
err = UseStates(context, lockable,
NewLockInfo(lockable, []string{"dependencies", "owner"}), func(context *StateContext) error {
owner_sent := false
for _, dependency := range(lockable.dependencies) {
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID())
dependency.Signal(context, signal)
dependency.Signal(context, lockable, signal)
if lockable.owner != nil {
if dependency.ID() == lockable.owner.ID() {
owner_sent = true
@ -241,20 +239,15 @@ func (lockable * SimpleLockable) Signal(context *StateContext, signal GraphSigna
if lockable.owner != nil && owner_sent == false {
if lockable.owner.ID() != lockable.ID() {
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.owner.ID())
return UseStates(context, lockable, NewLockMap(NewLockInfo(lockable.owner, []string{"signal"})), func(context *StateContext) error {
return lockable.owner.Signal(context, signal)
})
return lockable.owner.Signal(context, lockable, signal)
}
}
return nil
})
case Down:
err = UseStates(context, lockable, NewLockMap(
NewLockInfo(lockable, []string{"requirements"}),
LockList(lockable.requirements, []string{"signal"}),
), func(context *StateContext) error {
err = UseStates(context, lockable, NewLockInfo(lockable, []string{"requirements"}), func(context *StateContext) error {
for _, requirement := range(lockable.requirements) {
err := requirement.Signal(context, signal)
err := requirement.Signal(context, lockable, signal)
if err != nil {
return err
}

@ -69,12 +69,12 @@ type Node interface {
ID() NodeID
Type() NodeType
Allowed(action string, resource string, principal Node) error
Allowed(context *StateContext, action string, resource string, principal Node) error
AddPolicy(Policy) error
RemovePolicy(Policy) error
// Send a GraphSignal to the node, requires that the node is locked for read so that it can propagate
Signal(context *StateContext, signal GraphSignal) error
Signal(context *StateContext, princ Node, signal GraphSignal) error
// Register a channel to receive updates sent to the node
RegisterChannel(id NodeID, listener chan GraphSignal)
// Unregister a channel from receiving updates sent to the node
@ -100,19 +100,22 @@ func (node * GraphNode) Serialize() ([]byte, error) {
return json.MarshalIndent(&node_json, "", " ")
}
func (node *GraphNode) Allowed(action string, resource string, principal Node) error {
if principal == nil {
func (node *GraphNode) Allowed(context *StateContext, resource string, action string, princ Node) error {
if princ == nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return fmt.Errorf("nil is not allowed to perform any actions")
}
if node.ID() == principal.ID() {
if node.ID() == princ.ID() {
return nil
}
for _, policy := range(node.policies) {
if policy.Allows(resource, action, principal) == true {
if policy.Allows(resource, action, princ) == true {
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return nil
}
}
return fmt.Errorf("%s is not allowed to perform %s.%s on %s", principal.ID().String(), resource, action, node.ID().String())
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return fmt.Errorf("%s is not allowed to perform %s.%s on %s", princ.ID(), resource, action, node.ID())
}
func (node *GraphNode) AddPolicy(policy Policy) error {
@ -193,8 +196,17 @@ func (node * GraphNode) Type() NodeType {
// Propagate the signal to registered listeners, if a listener isn't ready to receive the update
// send it a notification that it was closed and then close it
func (node * GraphNode) Signal(context *StateContext, signal GraphSignal) error {
func (node * GraphNode) Signal(context *StateContext, princ Node, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String())
err := UseStates(context, princ, NewLockInfo(princ, nil), func(context *StateContext) error {
return node.Allowed(context, "signal", signal.Type(), princ)
})
if err != nil {
return nil
}
node.listeners_lock.Lock()
defer node.listeners_lock.Unlock()
closed := []NodeID{}
@ -536,10 +548,8 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
return err
}
final := false
if context.Started == false {
context.Started = true
final = true
}
new_locks := []Node{}
@ -588,18 +598,14 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
}
if already_granted == false {
err := node.Allowed("read", resource, princ)
err := node.Allowed(context, resource, "read", princ)
if err != nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.read", princ.ID().String(), id.String(), resource)
for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String())
n.RUnlock()
}
return err
}
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.read", princ.ID().String(), id.String(), resource)
} else {
context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.%s.read", princ.ID().String(), id.String(), resource)
}
}
new_permissions[id] = node_permissions
@ -621,10 +627,6 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
node.RUnlock()
}
if final == true {
context.Finished = true
}
return err
}
@ -692,18 +694,14 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
}
if already_granted == false {
err := node.Allowed("write", resource, princ)
err := node.Allowed(context, resource, "write", princ)
if err != nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.write", princ.ID().String(), id.String(), resource)
for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String())
n.Unlock()
}
return err
}
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.write", princ.ID().String(), id.String(), resource)
} else {
context.Graph.Log.Logf("policy", "POLICY_ALREADY_GRANTED: %s %s.%s.write", princ.ID().String(), id.String(), resource)
}
}
new_permissions[id] = node_permissions

@ -76,8 +76,16 @@ func NewIDSignal(_type string, direction SignalDirection, id NodeID) IDSignal {
}
}
func NewStatusSignal(_type string, source NodeID) IDSignal {
return NewIDSignal(_type, Up, source)
type StatusSignal struct {
IDSignal
Status string
}
func NewStatusSignal(status string, source NodeID) StatusSignal {
return StatusSignal{
IDSignal: NewIDSignal("status", Up, source),
Status: status,
}
}
type StartChildSignal struct {

@ -10,8 +10,8 @@ import (
)
// Assumed that thread is already locked for signal
func (thread *SimpleThread) Signal(context *StateContext, signal GraphSignal) error {
err := thread.SimpleLockable.Signal(context, signal)
func (thread *SimpleThread) Signal(context *StateContext, princ Node, signal GraphSignal) error {
err := thread.SimpleLockable.Signal(context, princ, signal)
if err != nil {
return err
}
@ -20,20 +20,15 @@ func (thread *SimpleThread) Signal(context *StateContext, signal GraphSignal) er
case Up:
err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error {
if thread.parent != nil {
return UseStates(context, thread, NewLockInfo(thread.parent, []string{"signal"}), func(context *StateContext) error {
return thread.parent.Signal(context, signal)
})
return thread.parent.Signal(context, thread, signal)
} else {
return nil
}
})
case Down:
err = UseStates(context, thread, NewLockMap(
NewLockInfo(thread, []string{"children"}),
LockList(thread.children, []string{"signal"}),
), func(context *StateContext) error {
err = UseStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error {
for _, child := range(thread.children) {
err := child.Signal(context, signal)
err := child.Signal(context, thread, signal)
if err != nil {
return err
}
@ -690,9 +685,7 @@ var ThreadAbortedError = errors.New("Thread aborted by signal")
// Default thread action function for "abort", sends a signal and returns a ThreadAbortedError
func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
context := NewReadContext(ctx)
err := UseStates(context, thread, NewLockInfo(thread, []string{"signal"}), func(context *StateContext) error {
return thread.Signal(context, NewStatusSignal("aborted", thread.ID()))
})
err := thread.Signal(context, thread, NewStatusSignal("aborted", thread.ID()))
if err != nil {
return "", err
}
@ -702,9 +695,7 @@ func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, erro
// Default thread action for "stop", sends a signal and returns no error
func ThreadStop(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
context := NewReadContext(ctx)
err := UseStates(context, thread, NewLockInfo(thread, []string{"signal"}), func(context *StateContext) error {
return thread.Signal(context, NewSignal("stopped"))
})
err := thread.Signal(context, thread, NewStatusSignal("stopped", thread.ID()))
return "finish", err
}