diff --git a/gql_test.go b/gql_test.go index e775998..bb19016 100644 --- a/gql_test.go +++ b/gql_test.go @@ -10,7 +10,7 @@ import ( ) func TestGQLDB(t * testing.T) { - ctx := logTestContext(t, []string{"loop", "node", "thread", "test", "signal", "policy", "db"}) + ctx := logTestContext(t, []string{}) TestUserNodeType := NodeType("TEST_USER") err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) diff --git a/graph_test.go b/graph_test.go index 7804a0d..b4be029 100644 --- a/graph_test.go +++ b/graph_test.go @@ -13,6 +13,35 @@ import ( type GraphTester testing.T const listner_timeout = 50 * time.Millisecond +func (t * GraphTester) WaitForLinkState(ctx * Context, listener *ListenerExt, state string, timeout time.Duration, str string) Signal { + timeout_channel := time.After(timeout) + for true { + select { + case signal := <- listener.Chan: + if signal == nil { + ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener) + t.Fatal(str) + } + if signal.Type() == LinkSignalType { + sig, ok := signal.(LinkSignal) + if ok == true { + ctx.Log.Logf("test", "Link state received: %s", sig.State) + if sig.State == state { + return signal + } + } else { + ctx.Log.Logf("test", "Failed to cast signal to LinkSignal: %+v", signal) + } + } + case <-timeout_channel: + pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) + t.Fatal(str) + return nil + } + } + return nil +} + func (t * GraphTester) WaitForStatus(ctx * Context, listener *ListenerExt, status string, timeout time.Duration, str string) Signal { timeout_channel := time.After(timeout) for true { @@ -57,7 +86,7 @@ func (t * GraphTester) CheckForNone(listener *ListenerExt, str string) { const SimpleListenerNodeType = NodeType("SIMPLE_LISTENER") func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) { - policy := NewAllNodesPolicy([]string{"signal.status", "requirements.write", "requirements.read", "dependencies.write", "dependencies.read", "owner.read", "owner.write"}) + policy := NewAllNodesPolicy([]SignalType{SignalType("status")}) listener_extension := NewListenerExt(buffer) listener := NewNode(ctx, RandID(), diff --git a/lockable.go b/lockable.go index d9988bd..895eaab 100644 --- a/lockable.go +++ b/lockable.go @@ -2,6 +2,7 @@ package graphvent import ( "encoding/json" + "fmt" ) type ListenerExt struct { @@ -41,15 +42,20 @@ func (ext *ListenerExt) Process(ctx *Context, princ_id NodeID, node *Node, signa return } -func (ext *ListenerExt) Serialize() ([]byte, error) { - return json.MarshalIndent(ext.Buffer, "", " ") +func LoadLockableExt(ctx *Context, data []byte) (Extension, error) { + var ext LockableExt + err := json.Unmarshal(data, &ext) + if err != nil { + return nil, err + } + + ctx.Log.Logf("db", "DB_LOADING_LOCKABLE_EXT_JSON: %+v", ext) + + return &ext, nil } -type LockableExt struct { - Owner *NodeID `json:"owner"` - Requirements []NodeID `json:"requirements"` - Dependencies []NodeID `json:"dependencies"` - LocksHeld map[NodeID]*NodeID `json:"locks_held"` +func (ext *ListenerExt) Serialize() ([]byte, error) { + return json.MarshalIndent(ext.Buffer, "", " ") } const LockableExtType = ExtType("LOCKABLE") @@ -61,7 +67,15 @@ func (ext *LockableExt) Serialize() ([]byte, error) { return json.MarshalIndent(ext, "", " ") } -func NewLockableExt(owner *NodeID, requirements []NodeID, dependencies []NodeID, locks_held map[NodeID]*NodeID) *LockableExt { +func NewLockableExt(owner *NodeID, requirements map[NodeID]string, dependencies map[NodeID]string, locks_held map[NodeID]*NodeID) *LockableExt { + if requirements == nil { + requirements = map[NodeID]string{} + } + + if dependencies == nil { + dependencies = map[NodeID]string{} + } + if locks_held == nil { locks_held = map[NodeID]*NodeID{} } @@ -74,22 +88,79 @@ func NewLockableExt(owner *NodeID, requirements []NodeID, dependencies []NodeID, } } -func LoadLockableExt(ctx *Context, data []byte) (Extension, error) { - var ext LockableExt - err := json.Unmarshal(data, &ext) +type LockableExt struct { + Owner *NodeID `json:"owner"` + Requirements map[NodeID]string `json:"requirements"` + Dependencies map[NodeID]string `json:"dependencies"` + LocksHeld map[NodeID]*NodeID `json:"locks_held"` +} + +func LinkRequirement(ctx *Context, dependency *Node, requirement NodeID) error { + dep_ext, err := GetExt[*LockableExt](dependency) if err != nil { - return nil, err + return err } - ctx.Log.Logf("db", "DB_LOADING_LOCKABLE_EXT_JSON: %+v", ext) - - return &ext, nil -} + _, exists := dep_ext.Requirements[requirement] + if exists == true { + return fmt.Errorf("%s is already a requirement of %s", requirement, dependency.ID) + } + _, exists = dep_ext.Dependencies[requirement] + if exists == true { + return fmt.Errorf("%s is a dependency of %s, cannot link as requirement", requirement, dependency.ID) + } + dep_ext.Requirements[requirement] = "start" + return ctx.Send(dependency.ID, requirement, NewLinkSignal("req_link")) +} func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node, signal LinkSignal) { - ctx.Log.Logf("lockable", "LINK_SIGNAL: %+v", signal) + ctx.Log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal) + state := signal.State + switch state { + // sent by a node to link this node as a requirement + case "req_link": + _, exists := ext.Requirements[source] + if exists == false { + dep_state, exists := ext.Dependencies[source] + if exists == false { + ext.Dependencies[source] = "start" + ctx.Send(node.ID, source, NewLinkSignal("dep_link")) + } else if dep_state == "start" { + ext.Dependencies[source] = "linked" + ctx.Send(node.ID, source, NewLinkSignal("dep_linked")) + } + } else { + delete(ext.Requirements, source) + ctx.Send(node.ID, source, NewLinkSignal("req_reset")) + } + case "dep_link": + _, exists := ext.Dependencies[source] + if exists == false { + req_state, exists := ext.Requirements[source] + if exists == false { + ext.Requirements[source] = "start" + ctx.Send(node.ID, source, NewLinkSignal("req_link")) + } else if req_state == "start" { + ext.Requirements[source] = "linked" + ctx.Send(node.ID, source, NewLinkSignal("req_linked")) + } + } else { + delete(ext.Dependencies, source) + ctx.Send(node.ID, source, NewLinkSignal("dep_reset")) + } + case "dep_reset": + ctx.Log.Logf("lockable", "%s reset %s dependency state", node.ID, source) + case "req_reset": + ctx.Log.Logf("lockable", "%s reset %s requirement state", node.ID, source) + case "dep_linked": + ctx.Log.Logf("lockable", "%s is a dependency of %s", node.ID, source) + case "req_linked": + ctx.Log.Logf("lockable", "%s is a requirement of %s", node.ID, source) + default: + ctx.Log.Logf("lockable", "LINK_ERROR: unknown state %s", state) + } } func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { @@ -98,7 +169,7 @@ func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal switch signal.Direction() { case Up: owner_sent := false - for _, dependency := range(ext.Dependencies) { + for dependency, _ := range(ext.Dependencies) { err := ctx.Send(node.ID, dependency, signal) if err != nil { ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, dependency, err) @@ -120,16 +191,16 @@ func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal } } case Down: - for _, requirement := range(ext.Requirements) { + for requirement, _ := range(ext.Requirements) { err := ctx.Send(node.ID, requirement, signal) if err != nil { ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, requirement, err) } } case Direct: - switch sig := signal.(type) { - case LinkSignal: - ext.HandleLinkSignal(ctx, source, node, sig) + switch signal.Type() { + case LinkSignalType: + ext.HandleLinkSignal(ctx, source, node, signal.(LinkSignal)) default: } default: diff --git a/lockable_test.go b/lockable_test.go index 9177bb2..6437efb 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -2,11 +2,12 @@ package graphvent import ( "testing" + "time" ) const TestLockableType = NodeType("TEST_LOCKABLE") func lockableTestContext(t *testing.T) *Context { - ctx := logTestContext(t, []string{"lockable", "signal"}) + ctx := logTestContext(t, []string{"lockable", "test"}) err := ctx.RegisterNodeType(TestLockableType, []ExtType{ACLExtType, LockableExtType, ListenerExtType}) fatalErr(t, err) @@ -15,9 +16,9 @@ func lockableTestContext(t *testing.T) *Context { } -var link_policy = NewAllNodesPolicy([]string{"link", "status"}) +var link_policy = NewAllNodesPolicy([]SignalType{LinkSignalType}) -func Test(t *testing.T) { +func TestLinkStatus(t *testing.T) { ctx := lockableTestContext(t) l1_listener := NewListenerExt(10) @@ -33,6 +34,10 @@ func Test(t *testing.T) { NewLockableExt(nil, nil, nil, nil), ) - ctx.Send(l1.ID, l2.ID, NewLinkSignal("start", l1.ID)) -} + // Link l2 as a requirement of l1 + err := LinkRequirement(ctx, l1, l2.ID) + fatalErr(t, err) + (*GraphTester)(t).WaitForLinkState(ctx, l1_listener, "dep_link", time.Millisecond*100, "No dep_link") + (*GraphTester)(t).WaitForLinkState(ctx, l2_listener, "req_linked", time.Millisecond*100, "No req_linked") +} diff --git a/node.go b/node.go index d4d0621..9b82474 100644 --- a/node.go +++ b/node.go @@ -156,7 +156,7 @@ func NodeLoop(ctx *Context, node *Node) error { case msg := <- node.MsgChan: signal = msg.Signal source = msg.Source - err := Allowed(ctx, msg.Source, string(signal.Type()), node) + err := Allowed(ctx, msg.Source, signal.Type(), node) if err != nil { ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s", err) continue @@ -307,7 +307,7 @@ func NewNode(ctx *Context, id NodeID, node_type NodeType, queued_signals []Queue return node } -func Allowed(ctx *Context, principal_id NodeID, action string, node *Node) error { +func Allowed(ctx *Context, principal_id NodeID, action SignalType, node *Node) error { ctx.Log.Logf("policy", "POLICY_CHECK: %s %s.%s", principal_id, node.ID, action) // Nodes are allowed to perform all actions on themselves regardless of whether or not they have an ACL extension if principal_id == node.ID { diff --git a/node_test.go b/node_test.go index 384c875..9d77d47 100644 --- a/node_test.go +++ b/node_test.go @@ -5,17 +5,13 @@ import ( ) func TestNodeDB(t *testing.T) { - ctx := logTestContext(t, []string{"test", "db", "node", "policy"}) + ctx := logTestContext(t, []string{}) node_type := NodeType("test") err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) fatalErr(t, err) node := NewNode(ctx, RandID(), node_type, nil, NewGroupExt(nil)) - ser, err := node.Serialize() - ctx.Log.Logf("test", "NODE_SER: %+v", ser) - fatalErr(t, err) - ctx.Nodes = NodeMap{} _, err = LoadNode(ctx, node.ID) fatalErr(t, err) diff --git a/policy.go b/policy.go index f2e098d..b3017f5 100644 --- a/policy.go +++ b/policy.go @@ -7,15 +7,15 @@ import ( type Policy interface { Serializable[PolicyType] - Allows(principal_id NodeID, action string, node *Node) error + Allows(principal_id NodeID, action SignalType, node *Node) error } //TODO: Update with change from principal *Node to principal_id so sane policies can still be made -func (policy *AllNodesPolicy) Allows(principal_id NodeID, action string, node *Node) error { +func (policy *AllNodesPolicy) Allows(principal_id NodeID, action SignalType, node *Node) error { return policy.Actions.Allows(action) } -func (policy *PerNodePolicy) Allows(principal_id NodeID, action string, node *Node) error { +func (policy *PerNodePolicy) Allows(principal_id NodeID, action SignalType, node *Node) error { for id, actions := range(policy.NodeActions) { if id != principal_id { continue @@ -29,13 +29,13 @@ func (policy *PerNodePolicy) Allows(principal_id NodeID, action string, node *No return fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID) } -func (policy *RequirementOfPolicy) Allows(principal_id NodeID, action string, node *Node) error { +func (policy *RequirementOfPolicy) Allows(principal_id NodeID, action SignalType, node *Node) error { lockable_ext, err := GetExt[*LockableExt](node) if err != nil { return err } - for _, id := range(lockable_ext.Requirements) { + for id, _ := range(lockable_ext.Requirements) { if id == principal_id { return policy.Actions.Allows(action) } @@ -58,9 +58,9 @@ func NewRequirementOfPolicy(actions Actions) RequirementOfPolicy { } } -type Actions []string +type Actions []SignalType -func (actions Actions) Allows(action string) error { +func (actions Actions) Allows(action SignalType) error { for _, a := range(actions) { if a == action { return nil @@ -90,28 +90,12 @@ func AllNodesPolicyLoad(init_fn func(Actions)(Policy, error)) func(*Context, []b func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context, []byte)(Policy, error) { return func(ctx *Context, data []byte)(Policy, error){ - var j PerNodePolicyJSON - err := json.Unmarshal(data, &j) + var policy PerNodePolicy + err := json.Unmarshal(data, &policy) if err != nil { return nil, err } - - node_actions := NodeActions{} - for id_str, actions := range(j.NodeActions) { - id, err := ParseID(id_str) - if err != nil { - return nil, err - } - - _, err = LoadNode(ctx, id) - if err != nil { - return nil, err - } - - node_actions[id] = actions - } - - return init_fn(node_actions) + return init_fn(policy.NodeActions) } } @@ -126,11 +110,7 @@ func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy { } type PerNodePolicy struct { - NodeActions NodeActions -} - -type PerNodePolicyJSON struct { - NodeActions map[string][]string `json:"node_actions"` + NodeActions NodeActions `json:"node_actions"` } const PerNodePolicyType = PolicyType("PER_NODE") @@ -139,14 +119,7 @@ func (policy *PerNodePolicy) Type() PolicyType { } func (policy *PerNodePolicy) Serialize() ([]byte, error) { - node_actions := map[string][]string{} - for id, actions := range(policy.NodeActions) { - node_actions[id.String()] = actions - } - - return json.MarshalIndent(&PerNodePolicyJSON{ - NodeActions: node_actions, - }, "", " ") + return json.MarshalIndent(policy, "", " ") } func NewAllNodesPolicy(actions Actions) AllNodesPolicy { @@ -294,7 +267,7 @@ func (ext *ACLExt) Type() ExtType { } // Check if the extension allows the principal to perform action on node -func (ext *ACLExt) Allows(ctx *Context, principal_id NodeID, action string, node *Node) error { +func (ext *ACLExt) Allows(ctx *Context, principal_id NodeID, action SignalType, node *Node) error { ctx.Log.Logf("policy", "POLICY_EXT_ALLOWED: %+v", ext) errs := []error{} for _, policy := range(ext.Policies) { diff --git a/signal.go b/signal.go index ec25568..c63c95b 100644 --- a/signal.go +++ b/signal.go @@ -98,14 +98,24 @@ func NewStatusSignal(status string, source NodeID) StatusSignal { } } +const LinkSignalType = SignalType("LINK") type LinkSignal struct { - IDSignal + BaseSignal State string `json:"state"` } -func NewLinkSignal(state string, source NodeID) LinkSignal { +func (signal LinkSignal) Serialize() ([]byte, error) { + return json.MarshalIndent(signal, "", " ") +} + +func (signal LinkSignal) String() string { + ser, _ := signal.Serialize() + return string(ser) +} + +func NewLinkSignal(state string) LinkSignal { return LinkSignal{ - IDSignal: NewIDSignal("link", Direct, source), + BaseSignal: NewDirectSignal(LinkSignalType), State: state, } }