Added naive locking sequence with short test

gql_cataclysm
noah metz 2023-07-27 18:08:43 -06:00
parent 78c29d2f74
commit 98893de442
8 changed files with 163 additions and 79 deletions

@ -10,7 +10,7 @@ import (
) )
func TestGQLDB(t * testing.T) { func TestGQLDB(t * testing.T) {
ctx := logTestContext(t, []string{"loop", "node", "thread", "test", "signal", "policy", "db"}) ctx := logTestContext(t, []string{})
TestUserNodeType := NodeType("TEST_USER") TestUserNodeType := NodeType("TEST_USER")
err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{})

@ -13,6 +13,35 @@ import (
type GraphTester testing.T type GraphTester testing.T
const listner_timeout = 50 * time.Millisecond const listner_timeout = 50 * time.Millisecond
func (t * GraphTester) WaitForLinkState(ctx * Context, listener *ListenerExt, state string, timeout time.Duration, str string) Signal {
timeout_channel := time.After(timeout)
for true {
select {
case signal := <- listener.Chan:
if signal == nil {
ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener)
t.Fatal(str)
}
if signal.Type() == LinkSignalType {
sig, ok := signal.(LinkSignal)
if ok == true {
ctx.Log.Logf("test", "Link state received: %s", sig.State)
if sig.State == state {
return signal
}
} else {
ctx.Log.Logf("test", "Failed to cast signal to LinkSignal: %+v", signal)
}
}
case <-timeout_channel:
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
t.Fatal(str)
return nil
}
}
return nil
}
func (t * GraphTester) WaitForStatus(ctx * Context, listener *ListenerExt, status string, timeout time.Duration, str string) Signal { func (t * GraphTester) WaitForStatus(ctx * Context, listener *ListenerExt, status string, timeout time.Duration, str string) Signal {
timeout_channel := time.After(timeout) timeout_channel := time.After(timeout)
for true { for true {
@ -57,7 +86,7 @@ func (t * GraphTester) CheckForNone(listener *ListenerExt, str string) {
const SimpleListenerNodeType = NodeType("SIMPLE_LISTENER") const SimpleListenerNodeType = NodeType("SIMPLE_LISTENER")
func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) { func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) {
policy := NewAllNodesPolicy([]string{"signal.status", "requirements.write", "requirements.read", "dependencies.write", "dependencies.read", "owner.read", "owner.write"}) policy := NewAllNodesPolicy([]SignalType{SignalType("status")})
listener_extension := NewListenerExt(buffer) listener_extension := NewListenerExt(buffer)
listener := NewNode(ctx, listener := NewNode(ctx,
RandID(), RandID(),

@ -2,6 +2,7 @@ package graphvent
import ( import (
"encoding/json" "encoding/json"
"fmt"
) )
type ListenerExt struct { type ListenerExt struct {
@ -41,15 +42,20 @@ func (ext *ListenerExt) Process(ctx *Context, princ_id NodeID, node *Node, signa
return return
} }
func (ext *ListenerExt) Serialize() ([]byte, error) { func LoadLockableExt(ctx *Context, data []byte) (Extension, error) {
return json.MarshalIndent(ext.Buffer, "", " ") var ext LockableExt
err := json.Unmarshal(data, &ext)
if err != nil {
return nil, err
} }
type LockableExt struct { ctx.Log.Logf("db", "DB_LOADING_LOCKABLE_EXT_JSON: %+v", ext)
Owner *NodeID `json:"owner"`
Requirements []NodeID `json:"requirements"` return &ext, nil
Dependencies []NodeID `json:"dependencies"` }
LocksHeld map[NodeID]*NodeID `json:"locks_held"`
func (ext *ListenerExt) Serialize() ([]byte, error) {
return json.MarshalIndent(ext.Buffer, "", " ")
} }
const LockableExtType = ExtType("LOCKABLE") const LockableExtType = ExtType("LOCKABLE")
@ -61,7 +67,15 @@ func (ext *LockableExt) Serialize() ([]byte, error) {
return json.MarshalIndent(ext, "", " ") return json.MarshalIndent(ext, "", " ")
} }
func NewLockableExt(owner *NodeID, requirements []NodeID, dependencies []NodeID, locks_held map[NodeID]*NodeID) *LockableExt { func NewLockableExt(owner *NodeID, requirements map[NodeID]string, dependencies map[NodeID]string, locks_held map[NodeID]*NodeID) *LockableExt {
if requirements == nil {
requirements = map[NodeID]string{}
}
if dependencies == nil {
dependencies = map[NodeID]string{}
}
if locks_held == nil { if locks_held == nil {
locks_held = map[NodeID]*NodeID{} locks_held = map[NodeID]*NodeID{}
} }
@ -74,22 +88,79 @@ func NewLockableExt(owner *NodeID, requirements []NodeID, dependencies []NodeID,
} }
} }
func LoadLockableExt(ctx *Context, data []byte) (Extension, error) { type LockableExt struct {
var ext LockableExt Owner *NodeID `json:"owner"`
err := json.Unmarshal(data, &ext) Requirements map[NodeID]string `json:"requirements"`
if err != nil { Dependencies map[NodeID]string `json:"dependencies"`
return nil, err LocksHeld map[NodeID]*NodeID `json:"locks_held"`
} }
ctx.Log.Logf("db", "DB_LOADING_LOCKABLE_EXT_JSON: %+v", ext) func LinkRequirement(ctx *Context, dependency *Node, requirement NodeID) error {
dep_ext, err := GetExt[*LockableExt](dependency)
if err != nil {
return err
}
return &ext, nil _, exists := dep_ext.Requirements[requirement]
if exists == true {
return fmt.Errorf("%s is already a requirement of %s", requirement, dependency.ID)
} }
_, exists = dep_ext.Dependencies[requirement]
if exists == true {
return fmt.Errorf("%s is a dependency of %s, cannot link as requirement", requirement, dependency.ID)
}
dep_ext.Requirements[requirement] = "start"
return ctx.Send(dependency.ID, requirement, NewLinkSignal("req_link"))
}
func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node, signal LinkSignal) { func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node, signal LinkSignal) {
ctx.Log.Logf("lockable", "LINK_SIGNAL: %+v", signal) ctx.Log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal)
state := signal.State
switch state {
// sent by a node to link this node as a requirement
case "req_link":
_, exists := ext.Requirements[source]
if exists == false {
dep_state, exists := ext.Dependencies[source]
if exists == false {
ext.Dependencies[source] = "start"
ctx.Send(node.ID, source, NewLinkSignal("dep_link"))
} else if dep_state == "start" {
ext.Dependencies[source] = "linked"
ctx.Send(node.ID, source, NewLinkSignal("dep_linked"))
}
} else {
delete(ext.Requirements, source)
ctx.Send(node.ID, source, NewLinkSignal("req_reset"))
}
case "dep_link":
_, exists := ext.Dependencies[source]
if exists == false {
req_state, exists := ext.Requirements[source]
if exists == false {
ext.Requirements[source] = "start"
ctx.Send(node.ID, source, NewLinkSignal("req_link"))
} else if req_state == "start" {
ext.Requirements[source] = "linked"
ctx.Send(node.ID, source, NewLinkSignal("req_linked"))
}
} else {
delete(ext.Dependencies, source)
ctx.Send(node.ID, source, NewLinkSignal("dep_reset"))
}
case "dep_reset":
ctx.Log.Logf("lockable", "%s reset %s dependency state", node.ID, source)
case "req_reset":
ctx.Log.Logf("lockable", "%s reset %s requirement state", node.ID, source)
case "dep_linked":
ctx.Log.Logf("lockable", "%s is a dependency of %s", node.ID, source)
case "req_linked":
ctx.Log.Logf("lockable", "%s is a requirement of %s", node.ID, source)
default:
ctx.Log.Logf("lockable", "LINK_ERROR: unknown state %s", state)
}
} }
func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) {
@ -98,7 +169,7 @@ func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal
switch signal.Direction() { switch signal.Direction() {
case Up: case Up:
owner_sent := false owner_sent := false
for _, dependency := range(ext.Dependencies) { for dependency, _ := range(ext.Dependencies) {
err := ctx.Send(node.ID, dependency, signal) err := ctx.Send(node.ID, dependency, signal)
if err != nil { if err != nil {
ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, dependency, err) ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, dependency, err)
@ -120,16 +191,16 @@ func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal
} }
} }
case Down: case Down:
for _, requirement := range(ext.Requirements) { for requirement, _ := range(ext.Requirements) {
err := ctx.Send(node.ID, requirement, signal) err := ctx.Send(node.ID, requirement, signal)
if err != nil { if err != nil {
ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, requirement, err) ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, requirement, err)
} }
} }
case Direct: case Direct:
switch sig := signal.(type) { switch signal.Type() {
case LinkSignal: case LinkSignalType:
ext.HandleLinkSignal(ctx, source, node, sig) ext.HandleLinkSignal(ctx, source, node, signal.(LinkSignal))
default: default:
} }
default: default:

@ -2,11 +2,12 @@ package graphvent
import ( import (
"testing" "testing"
"time"
) )
const TestLockableType = NodeType("TEST_LOCKABLE") const TestLockableType = NodeType("TEST_LOCKABLE")
func lockableTestContext(t *testing.T) *Context { func lockableTestContext(t *testing.T) *Context {
ctx := logTestContext(t, []string{"lockable", "signal"}) ctx := logTestContext(t, []string{"lockable", "test"})
err := ctx.RegisterNodeType(TestLockableType, []ExtType{ACLExtType, LockableExtType, ListenerExtType}) err := ctx.RegisterNodeType(TestLockableType, []ExtType{ACLExtType, LockableExtType, ListenerExtType})
fatalErr(t, err) fatalErr(t, err)
@ -15,9 +16,9 @@ func lockableTestContext(t *testing.T) *Context {
} }
var link_policy = NewAllNodesPolicy([]string{"link", "status"}) var link_policy = NewAllNodesPolicy([]SignalType{LinkSignalType})
func Test(t *testing.T) { func TestLinkStatus(t *testing.T) {
ctx := lockableTestContext(t) ctx := lockableTestContext(t)
l1_listener := NewListenerExt(10) l1_listener := NewListenerExt(10)
@ -33,6 +34,10 @@ func Test(t *testing.T) {
NewLockableExt(nil, nil, nil, nil), NewLockableExt(nil, nil, nil, nil),
) )
ctx.Send(l1.ID, l2.ID, NewLinkSignal("start", l1.ID)) // Link l2 as a requirement of l1
} err := LinkRequirement(ctx, l1, l2.ID)
fatalErr(t, err)
(*GraphTester)(t).WaitForLinkState(ctx, l1_listener, "dep_link", time.Millisecond*100, "No dep_link")
(*GraphTester)(t).WaitForLinkState(ctx, l2_listener, "req_linked", time.Millisecond*100, "No req_linked")
}

@ -156,7 +156,7 @@ func NodeLoop(ctx *Context, node *Node) error {
case msg := <- node.MsgChan: case msg := <- node.MsgChan:
signal = msg.Signal signal = msg.Signal
source = msg.Source source = msg.Source
err := Allowed(ctx, msg.Source, string(signal.Type()), node) err := Allowed(ctx, msg.Source, signal.Type(), node)
if err != nil { if err != nil {
ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s", err) ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s", err)
continue continue
@ -307,7 +307,7 @@ func NewNode(ctx *Context, id NodeID, node_type NodeType, queued_signals []Queue
return node return node
} }
func Allowed(ctx *Context, principal_id NodeID, action string, node *Node) error { func Allowed(ctx *Context, principal_id NodeID, action SignalType, node *Node) error {
ctx.Log.Logf("policy", "POLICY_CHECK: %s %s.%s", principal_id, node.ID, action) ctx.Log.Logf("policy", "POLICY_CHECK: %s %s.%s", principal_id, node.ID, action)
// Nodes are allowed to perform all actions on themselves regardless of whether or not they have an ACL extension // Nodes are allowed to perform all actions on themselves regardless of whether or not they have an ACL extension
if principal_id == node.ID { if principal_id == node.ID {

@ -5,17 +5,13 @@ import (
) )
func TestNodeDB(t *testing.T) { func TestNodeDB(t *testing.T) {
ctx := logTestContext(t, []string{"test", "db", "node", "policy"}) ctx := logTestContext(t, []string{})
node_type := NodeType("test") node_type := NodeType("test")
err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType})
fatalErr(t, err) fatalErr(t, err)
node := NewNode(ctx, RandID(), node_type, nil, NewGroupExt(nil)) node := NewNode(ctx, RandID(), node_type, nil, NewGroupExt(nil))
ser, err := node.Serialize()
ctx.Log.Logf("test", "NODE_SER: %+v", ser)
fatalErr(t, err)
ctx.Nodes = NodeMap{} ctx.Nodes = NodeMap{}
_, err = LoadNode(ctx, node.ID) _, err = LoadNode(ctx, node.ID)
fatalErr(t, err) fatalErr(t, err)

@ -7,15 +7,15 @@ import (
type Policy interface { type Policy interface {
Serializable[PolicyType] Serializable[PolicyType]
Allows(principal_id NodeID, action string, node *Node) error Allows(principal_id NodeID, action SignalType, node *Node) error
} }
//TODO: Update with change from principal *Node to principal_id so sane policies can still be made //TODO: Update with change from principal *Node to principal_id so sane policies can still be made
func (policy *AllNodesPolicy) Allows(principal_id NodeID, action string, node *Node) error { func (policy *AllNodesPolicy) Allows(principal_id NodeID, action SignalType, node *Node) error {
return policy.Actions.Allows(action) return policy.Actions.Allows(action)
} }
func (policy *PerNodePolicy) Allows(principal_id NodeID, action string, node *Node) error { func (policy *PerNodePolicy) Allows(principal_id NodeID, action SignalType, node *Node) error {
for id, actions := range(policy.NodeActions) { for id, actions := range(policy.NodeActions) {
if id != principal_id { if id != principal_id {
continue continue
@ -29,13 +29,13 @@ func (policy *PerNodePolicy) Allows(principal_id NodeID, action string, node *No
return fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID) return fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID)
} }
func (policy *RequirementOfPolicy) Allows(principal_id NodeID, action string, node *Node) error { func (policy *RequirementOfPolicy) Allows(principal_id NodeID, action SignalType, node *Node) error {
lockable_ext, err := GetExt[*LockableExt](node) lockable_ext, err := GetExt[*LockableExt](node)
if err != nil { if err != nil {
return err return err
} }
for _, id := range(lockable_ext.Requirements) { for id, _ := range(lockable_ext.Requirements) {
if id == principal_id { if id == principal_id {
return policy.Actions.Allows(action) return policy.Actions.Allows(action)
} }
@ -58,9 +58,9 @@ func NewRequirementOfPolicy(actions Actions) RequirementOfPolicy {
} }
} }
type Actions []string type Actions []SignalType
func (actions Actions) Allows(action string) error { func (actions Actions) Allows(action SignalType) error {
for _, a := range(actions) { for _, a := range(actions) {
if a == action { if a == action {
return nil return nil
@ -90,28 +90,12 @@ func AllNodesPolicyLoad(init_fn func(Actions)(Policy, error)) func(*Context, []b
func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context, []byte)(Policy, error) { func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context, []byte)(Policy, error) {
return func(ctx *Context, data []byte)(Policy, error){ return func(ctx *Context, data []byte)(Policy, error){
var j PerNodePolicyJSON var policy PerNodePolicy
err := json.Unmarshal(data, &j) err := json.Unmarshal(data, &policy)
if err != nil {
return nil, err
}
node_actions := NodeActions{}
for id_str, actions := range(j.NodeActions) {
id, err := ParseID(id_str)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return init_fn(policy.NodeActions)
_, err = LoadNode(ctx, id)
if err != nil {
return nil, err
}
node_actions[id] = actions
}
return init_fn(node_actions)
} }
} }
@ -126,11 +110,7 @@ func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy {
} }
type PerNodePolicy struct { type PerNodePolicy struct {
NodeActions NodeActions NodeActions NodeActions `json:"node_actions"`
}
type PerNodePolicyJSON struct {
NodeActions map[string][]string `json:"node_actions"`
} }
const PerNodePolicyType = PolicyType("PER_NODE") const PerNodePolicyType = PolicyType("PER_NODE")
@ -139,14 +119,7 @@ func (policy *PerNodePolicy) Type() PolicyType {
} }
func (policy *PerNodePolicy) Serialize() ([]byte, error) { func (policy *PerNodePolicy) Serialize() ([]byte, error) {
node_actions := map[string][]string{} return json.MarshalIndent(policy, "", " ")
for id, actions := range(policy.NodeActions) {
node_actions[id.String()] = actions
}
return json.MarshalIndent(&PerNodePolicyJSON{
NodeActions: node_actions,
}, "", " ")
} }
func NewAllNodesPolicy(actions Actions) AllNodesPolicy { func NewAllNodesPolicy(actions Actions) AllNodesPolicy {
@ -294,7 +267,7 @@ func (ext *ACLExt) Type() ExtType {
} }
// Check if the extension allows the principal to perform action on node // Check if the extension allows the principal to perform action on node
func (ext *ACLExt) Allows(ctx *Context, principal_id NodeID, action string, node *Node) error { func (ext *ACLExt) Allows(ctx *Context, principal_id NodeID, action SignalType, node *Node) error {
ctx.Log.Logf("policy", "POLICY_EXT_ALLOWED: %+v", ext) ctx.Log.Logf("policy", "POLICY_EXT_ALLOWED: %+v", ext)
errs := []error{} errs := []error{}
for _, policy := range(ext.Policies) { for _, policy := range(ext.Policies) {

@ -98,14 +98,24 @@ func NewStatusSignal(status string, source NodeID) StatusSignal {
} }
} }
const LinkSignalType = SignalType("LINK")
type LinkSignal struct { type LinkSignal struct {
IDSignal BaseSignal
State string `json:"state"` State string `json:"state"`
} }
func NewLinkSignal(state string, source NodeID) LinkSignal { func (signal LinkSignal) Serialize() ([]byte, error) {
return json.MarshalIndent(signal, "", " ")
}
func (signal LinkSignal) String() string {
ser, _ := signal.Serialize()
return string(ser)
}
func NewLinkSignal(state string) LinkSignal {
return LinkSignal{ return LinkSignal{
IDSignal: NewIDSignal("link", Direct, source), BaseSignal: NewDirectSignal(LinkSignalType),
State: state, State: state,
} }
} }