diff --git a/context.go b/context.go index 835e2ef..65d5b1b 100644 --- a/context.go +++ b/context.go @@ -74,9 +74,9 @@ type PolicyInfo struct { const ( ListenerExtType = ExtType("LISTENER") LockableExtType = ExtType("LOCKABLE") - GQLExtType = ExtType("GQL") - GroupExtType = ExtType("GROUP") - ECDHExtType = ExtType("ECDH") + GQLExtType = ExtType("GQL") + GroupExtType = ExtType("GROUP") + ECDHExtType = ExtType("ECDH") GQLNodeType = NodeType("GQL") ) diff --git a/ecdh.go b/ecdh.go index 837a856..028c99b 100644 --- a/ecdh.go +++ b/ecdh.go @@ -118,15 +118,15 @@ func (ext *ECDHExt) HandleECDHSignal(log Logger, node *Node, signal *ECDHSignal) state.SharedSecret = shared_secret ext.ECDHStates[source] = state log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) - messages = messages.Add(log, node.ID, node.Key, &resp, source) + messages = messages.Add(node.ID, node.Key, &resp, source) } else { log.Logf("ecdh", "ECDH_REQ_ERR: %s", err) - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), err.Error()), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), err.Error()), source) } case "resp": state, exists := ext.ECDHStates[source] if exists == false || state.ECKey == nil { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "no_req"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "no_req"), source) } else { err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW) if err == nil { diff --git a/gql.go b/gql.go index 3617baa..34585be 100644 --- a/gql.go +++ b/gql.go @@ -869,12 +869,10 @@ func NewGQLExtContext() *GQLExtContext { return nil, fmt.Errorf("can't parse requirements %+v as string, %s", val, reflect.TypeOf(val)) } - ids := make([]NodeID, len(id_strs)) - i := 0 + ids := []NodeID{} for id, state := range(id_strs) { if state.Link == "linked" { - ids[i] = id - i++ + ids = append(ids, id) } } return ids, nil @@ -889,12 +887,10 @@ func NewGQLExtContext() *GQLExtContext { return nil, fmt.Errorf("can't parse dependencies %+v as string, %s", val, reflect.TypeOf(val)) } - ids := make([]NodeID, len(id_strs)) - i := 0 + ids := []NodeID{} for id, state := range(id_strs) { if state.Link == "linked" { - ids[i] = id - i++ + ids = append(ids, id) } } return ids, nil @@ -1035,7 +1031,7 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa if signal.Type() == ErrorSignalType { // TODO: Forward to resolver if waiting for it sig := signal.(*ErrorSignal) - response_chan := ext.FreeResponseChannel(sig.ReqID) + response_chan := ext.FreeResponseChannel(sig.ReqID()) if response_chan != nil { select { case response_chan <- sig: @@ -1049,7 +1045,7 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa } } else if signal.Type() == ReadResultSignalType { sig := signal.(*ReadResultSignal) - response_chan := ext.FreeResponseChannel(sig.ReqID) + response_chan := ext.FreeResponseChannel(sig.ReqID()) if response_chan != nil { select { case response_chan <- sig: diff --git a/gql_query.go b/gql_query.go index 736537e..28544c3 100644 --- a/gql_query.go +++ b/gql_query.go @@ -51,7 +51,7 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error read_signal := NewReadSignal(ext_fields) msgs := Messages{} - msgs = msgs.Add(ctx.Context.Log, ctx.Server.ID, ctx.Key, read_signal, id) + msgs = msgs.Add(ctx.Server.ID, ctx.Key, read_signal, id) response_chan := ctx.Ext.GetResponseChannel(read_signal.ID()) resp_channels[read_signal.ID()] = response_chan @@ -69,7 +69,7 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ for sig_id, response_chan := range(resp_channels) { // Wait for the response, returning an error on timeout response, err := WaitForSignal(ctx.Context, response_chan, time.Millisecond*100, ReadResultSignalType, func(sig *ReadResultSignal)bool{ - return sig.ReqID == sig_id + return sig.ReqID() == sig_id }) if err != nil { return nil, err diff --git a/gql_test.go b/gql_test.go index 7c161b0..ac3e1f7 100644 --- a/gql_test.go +++ b/gql_test.go @@ -10,42 +10,80 @@ import ( "net" "crypto/tls" "crypto/x509" + "crypto/rand" + "crypto/ed25519" "bytes" ) func TestGQLServer(t *testing.T) { - ctx := logTestContext(t, []string{"test"}) + ctx := logTestContext(t, []string{"test", "policy", "pending"}) TestNodeType := NodeType("TEST") err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType}) fatalErr(t, err) - policy := NewAllNodesPolicy(Actions{ - MakeAction(LinkSignalType, "+"), - MakeAction(LinkStartSignalType, "+"), - MakeAction(LockSignalType, "+"), - MakeAction(StatusSignalType, "+"), - MakeAction(ErrorSignalType, "+"), - MakeAction(ReadSignalType, "+"), - MakeAction(ReadResultSignalType, "+"), - MakeAction(GQLStateSignalType, "+"), + pub, gql_key, err := ed25519.GenerateKey(rand.Reader) + fatalErr(t, err) + gql_id := KeyID(pub) + + group_policy_1 := NewAllNodesPolicy(Tree{ + ReadSignalType.String(): Tree{ + GroupExtType.String(): Tree{ + "members": Tree{}, + }, + }, + ReadResultSignalType.String(): nil, + ErrorSignalType.String(): nil, + }) + + group_policy_2 := NewMemberOfPolicy(NodeRules{ + gql_id: Tree{ + LinkSignalType.String(): nil, + LinkStartSignalType.String(): nil, + LockSignalType.String(): nil, + StatusSignalType.String(): nil, + ReadSignalType.String(): nil, + GQLStateSignalType.String(): nil, + }, + }) + + user_policy_1 := NewAllNodesPolicy(Tree{ + ReadResultSignalType.String(): nil, + ErrorSignalType.String(): nil, + }) + + user_policy_2 := NewMemberOfPolicy(NodeRules{ + gql_id: Tree{ + LinkSignalType.String(): nil, + ReadSignalType.String(): nil, + }, }) gql_ext, err := NewGQLExt(ctx, ":0", nil, nil, "stopped") fatalErr(t, err) + listener_ext := NewListenerExt(10) - gql := NewNode(ctx, nil, GQLNodeType, 10, map[PolicyType]Policy{ - AllNodesPolicyType: &policy, - }, NewLockableExt(), gql_ext, NewGroupExt(nil), listener_ext) n1 := NewNode(ctx, nil, TestNodeType, 10, map[PolicyType]Policy{ - AllNodesPolicyType: &policy, + MemberOfPolicyType: &user_policy_2, + AllNodesPolicyType: &user_policy_1, }, NewLockableExt()) + gql := NewNode(ctx, gql_key, GQLNodeType, 10, map[PolicyType]Policy{ + MemberOfPolicyType: &group_policy_2, + AllNodesPolicyType: &group_policy_1, + }, NewLockableExt(), gql_ext, NewGroupExt(map[NodeID]string{ + n1.ID: "user", + gql_id: "self", + }), listener_ext) + + ctx.Log.Logf("test", "GQL: %s", gql.ID) + ctx.Log.Logf("test", "NODE: %s", n1.ID) + err = LinkRequirement(ctx, gql, n1.ID) fatalErr(t, err) msgs := Messages{} - msgs = msgs.Add(ctx.Log, gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "start_server"}, gql.ID) + msgs = msgs.Add(gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "start_server"}, gql.ID) err = ctx.Send(msgs) fatalErr(t, err) @@ -102,7 +140,7 @@ func TestGQLServer(t *testing.T) { ctx.Log.Logf("test", "RESP_2: %s", resp_2) msgs = Messages{} - msgs = msgs.Add(ctx.Log, gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "stop_server"}, gql.ID) + msgs = msgs.Add(gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "stop_server"}, gql.ID) err = ctx.Send(msgs) fatalErr(t, err) _, err = WaitForSignal(ctx, listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { @@ -131,7 +169,7 @@ func TestGQLDB(t *testing.T) { ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) msgs := Messages{} - msgs = msgs.Add(ctx.Log, gql.ID, gql.Key, &StopSignal, gql.ID) + msgs = msgs.Add(gql.ID, gql.Key, &StopSignal, gql.ID) err = ctx.Send(msgs) fatalErr(t, err) _, err = WaitForSignal(ctx, listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { @@ -147,13 +185,13 @@ func TestGQLDB(t *testing.T) { ctx.Log.Logf("test", "SER_3: \n%s\n\n", ser3) // Clear all loaded nodes from the context so it loads them from the database - ctx.Nodes = NodeMap{} + ctx.Nodes = map[NodeID]*Node{} gql_loaded, err := LoadNode(ctx, gql.ID) fatalErr(t, err) listener_ext, err = GetExt[*ListenerExt](gql_loaded) fatalErr(t, err) msgs = Messages{} - msgs = msgs.Add(ctx.Log, gql_loaded.ID, gql_loaded.Key, &StopSignal, gql_loaded.ID) + msgs = msgs.Add(gql_loaded.ID, gql_loaded.Key, &StopSignal, gql_loaded.ID) err = ctx.Send(msgs) fatalErr(t, err) _, err = WaitForSignal(ctx, listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { diff --git a/lockable.go b/lockable.go index 2b0c223..b5c4b94 100644 --- a/lockable.go +++ b/lockable.go @@ -139,21 +139,21 @@ func NewLockableExt() *LockableExt { // Send the signal to unlock a node from itself func UnlockLockable(ctx *Context, node *Node) error { msgs := Messages{} - msgs = msgs.Add(ctx.Log, node.ID, node.Key, NewLockSignal("unlock"), node.ID) + msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), node.ID) return ctx.Send(msgs) } // Send the signal to lock a node from itself func LockLockable(ctx *Context, node *Node) error { msgs := Messages{} - msgs = msgs.Add(ctx.Log, node.ID, node.Key, NewLockSignal("lock"), node.ID) + msgs = msgs.Add(node.ID, node.Key, NewLockSignal("lock"), node.ID) return ctx.Send(msgs) } // Setup a node to send the initial requirement link signal, then send the signal func LinkRequirement(ctx *Context, dependency *Node, requirement NodeID) error { msgs := Messages{} - msgs = msgs.Add(ctx.Log, dependency.ID, dependency.Key, NewLinkStartSignal("req", requirement), dependency.ID) + msgs = msgs.Add(dependency.ID, dependency.Key, NewLinkStartSignal("req", requirement), dependency.ID) return ctx.Send(msgs) } @@ -166,16 +166,16 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, switch state { case "unlock": if ext.Owner == nil { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_unlocked"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_unlocked"), source) } else if source != *ext.Owner { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_owner"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_owner"), source) } else if ext.PendingOwner == nil { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_unlocking"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_unlocking"), source) } else { if len(ext.Requirements) == 0 { ext.Owner = nil ext.PendingOwner = nil - messages = messages.Add(log, node.ID, node.Key, NewLockSignal("unlocked"), source) + messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocked"), source) } else { ext.PendingOwner = nil for id, state := range(ext.Requirements) { @@ -185,22 +185,22 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, } state.Lock = "unlocking" ext.Requirements[id] = state - messages = messages.Add(log, node.ID, node.Key, NewLockSignal("unlock"), id) + messages = messages.Add(node.ID, node.Key, NewLockSignal("unlock"), id) } } if source != node.ID { - messages = messages.Add(log, node.ID, node.Key, NewLockSignal("unlocking"), source) + messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocking"), source) } } } case "unlocking": state, exists := ext.Requirements[source] if exists == false { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) } else if state.Link != "linked" { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) } else if state.Lock != "unlocking" { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source) } case "unlocked": @@ -210,11 +210,11 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, state, exists := ext.Requirements[source] if exists == false { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) } else if state.Link != "linked" { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) } else if state.Lock != "unlocking" { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source) } else { state.Lock = "unlocked" ext.Requirements[source] = state @@ -234,7 +234,7 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, if linked == unlocked { previous_owner := *ext.Owner ext.Owner = nil - messages = messages.Add(log, node.ID, node.Key, NewLockSignal("unlocked"), previous_owner) + messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocked"), previous_owner) } } } @@ -245,11 +245,11 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, state, exists := ext.Requirements[source] if exists == false { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) } else if state.Link != "linked" { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) } else if state.Lock != "locking" { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source) } else { state.Lock = "locked" ext.Requirements[source] = state @@ -268,31 +268,31 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, if linked == locked { ext.Owner = ext.PendingOwner - messages = messages.Add(log, node.ID, node.Key, NewLockSignal("locked"), *ext.Owner) + messages = messages.Add(node.ID, node.Key, NewLockSignal("locked"), *ext.Owner) } } } case "locking": state, exists := ext.Requirements[source] if exists == false { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) } else if state.Link != "linked" { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) } else if state.Lock != "locking" { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source) } case "lock": if ext.Owner != nil { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locked"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locked"), source) } else if ext.PendingOwner != nil { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locking"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locking"), source) } else { owner := source if len(ext.Requirements) == 0 { ext.Owner = &owner ext.PendingOwner = ext.Owner - messages = messages.Add(log, node.ID, node.Key, NewLockSignal("locked"), source) + messages = messages.Add(node.ID, node.Key, NewLockSignal("locked"), source) } else { ext.PendingOwner = &owner for id, state := range(ext.Requirements) { @@ -303,11 +303,11 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, } state.Lock = "locking" ext.Requirements[id] = state - messages = messages.Add(log, node.ID, node.Key, NewLockSignal("lock"), id) + messages = messages.Add(node.ID, node.Key, NewLockSignal("lock"), id) } } if source != node.ID { - messages = messages.Add(log, node.ID, node.Key, NewLockSignal("locking"), source) + messages = messages.Add(node.ID, node.Key, NewLockSignal("locking"), source) } } } @@ -329,25 +329,25 @@ func (ext *LockableExt) HandleLinkStartSignal(log Logger, node *Node, source Nod state, exists := ext.Requirements[target] _, dep_exists := ext.Dependencies[target] if ext.Owner != nil { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locked"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locked"), source) } else if ext.Owner != ext.PendingOwner { if ext.PendingOwner == nil { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "unlocking"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "unlocking"), source) } else { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "locking"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "locking"), source) } } else if exists == true { if state.Link == "linking" { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_linking_req"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_linking_req"), source) } else if state.Link == "linked" { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_req"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_req"), source) } } else if dep_exists == true { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_dep"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_dep"), source) } else { ext.Requirements[target] = LinkState{"linking", "unlocked", source} - messages = messages.Add(log, node.ID, node.Key, NewLinkSignal("linked_as_req"), target) - messages = messages.Add(log, node.ID, node.Key, NewLinkStartSignal("linking_req", target), source) + messages = messages.Add(node.ID, node.Key, NewLinkSignal("linked_as_req"), target) + messages = messages.Add(node.ID, node.Key, NewLinkStartSignal("linking_req", target), source) } } return messages @@ -364,7 +364,7 @@ func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID, case "dep_done": state, exists := ext.Requirements[source] if exists == false { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linking"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linking"), source) } else if state.Link == "linking" { state.Link = "linked" ext.Requirements[source] = state @@ -374,16 +374,16 @@ func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID, state, exists := ext.Dependencies[source] if exists == false { ext.Dependencies[source] = LinkState{"linked", "unlocked", source} - messages = messages.Add(log, node.ID, node.Key, NewLinkSignal("dep_done"), source) + messages = messages.Add(node.ID, node.Key, NewLinkSignal("dep_done"), source) } else if state.Link == "linking" { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_linking"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_linking"), source) } else if state.Link == "linked" { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_linked"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_linked"), source) } else if ext.PendingOwner != ext.Owner { if ext.Owner == nil { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "locking"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "locking"), source) } else { - messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "unlocking"), source) + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "unlocking"), source) } } @@ -403,7 +403,7 @@ func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal owner_sent := false for dependency, state := range(ext.Dependencies) { if state.Link == "linked" { - messages = messages.Add(ctx.Log, node.ID, node.Key, signal, dependency) + messages = messages.Add(node.ID, node.Key, signal, dependency) if ext.Owner != nil { if dependency == *ext.Owner { owner_sent = true @@ -414,13 +414,13 @@ func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal if ext.Owner != nil && owner_sent == false { if *ext.Owner != node.ID { - messages = messages.Add(ctx.Log, node.ID, node.Key, signal, *ext.Owner) + messages = messages.Add(node.ID, node.Key, signal, *ext.Owner) } } case Down: for requirement, state := range(ext.Requirements) { if state.Link == "linked" { - messages = messages.Add(ctx.Log, node.ID, node.Key, signal, requirement) + messages = messages.Add(node.ID, node.Key, signal, requirement) } } case Direct: diff --git a/lockable_test.go b/lockable_test.go index 584fb96..bd75958 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -15,11 +15,6 @@ func lockableTestContext(t *testing.T, logs []string) *Context { return ctx } - -//TODO: add new finer grained signals, and probably add wildcards to not have to deal with annoying acl policies -var link_policy = NewAllNodesPolicy(Actions{MakeAction(LinkSignalType, "*"), MakeAction(StatusSignalType, "+")}) -var lock_policy = NewAllNodesPolicy(Actions{MakeAction(LockSignalType, "*")}) - func TestLink(t *testing.T) { ctx := lockableTestContext(t, []string{"lockable"}) @@ -44,7 +39,7 @@ func TestLink(t *testing.T) { fatalErr(t, err) msgs := Messages{} - msgs = msgs.Add(ctx.Log, l2.ID, l2.Key, NewStatusSignal("TEST", l2.ID), l2.ID) + msgs = msgs.Add(l2.ID, l2.Key, NewStatusSignal("TEST", l2.ID), l2.ID) err = ctx.Send(msgs) fatalErr(t, err) diff --git a/node.go b/node.go index 6a66c51..81f45f5 100644 --- a/node.go +++ b/node.go @@ -23,7 +23,7 @@ const ( // Total length of the node database header, has magic to verify and type_hash to map to load function NODE_DB_HEADER_LEN = 32 EXTENSION_DB_HEADER_LEN = 16 - QSIGNAL_DB_HEADER_LEN = 40 + QSIGNAL_DB_HEADER_LEN = 24 POLICY_DB_HEADER_LEN = 16 ) @@ -100,11 +100,27 @@ type Extension interface { // A QueuedSignal is a Signal that has been Queued to trigger at a set time type QueuedSignal struct { - uuid.UUID Signal time.Time } +type PendingACL struct { + Counter int + TimeoutID uuid.UUID + Action Tree + Principal NodeID + Messages Messages + Responses []Signal + Signal Signal + Source NodeID +} + +type PendingSignal struct { + Policy PolicyType + Found bool + ID uuid.UUID +} + // Default message channel size for nodes // Nodes represent a group of extensions that can be collectively addressed type Node struct { @@ -114,6 +130,9 @@ type Node struct { Extensions map[ExtType]Extension Policies map[PolicyType]Policy + PendingACLs map[uuid.UUID]PendingACL + PendingSignals map[uuid.UUID]PendingSignal + // Channel for this node to receive messages from the Context MsgChan chan *Message // Size of MsgChan @@ -127,27 +146,51 @@ type Node struct { NextSignal *QueuedSignal } -func (node *Node) Allows(principal_id NodeID, action Action)(Messages, error) { - errs := []error{} - var pends Messages = nil - for _, policy := range(node.Policies) { - msgs, err := policy.Allows(principal_id, action, node) - if err == nil { - return nil, nil - } - errs = append(errs, err) - if msgs != nil { - pends = append(pends, msgs...) +type RuleResult int +const ( + Allow RuleResult = iota + Deny + Pending +) + +func (node *Node) Allows(principal_id NodeID, action Tree)(map[PolicyType]Messages, RuleResult) { + pends := map[PolicyType]Messages{} + for policy_type, policy := range(node.Policies) { + msgs, resp := policy.Allows(principal_id, action, node) + if resp == Allow { + return nil, Allow + } else if resp == Pending { + pends[policy_type] = msgs } } - return pends, fmt.Errorf("POLICY_CHECK_ERRORS: %s %s.%s - %+v", principal_id, node.ID, action, errs) + if len(pends) != 0 { + return pends, Pending + } + return nil, Deny +} + +func (node *Node) QueueSignal(time time.Time, signal Signal) { + node.SignalQueue = append(node.SignalQueue, QueuedSignal{signal, time}) + node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue) } -func (node *Node) QueueSignal(time time.Time, signal Signal) uuid.UUID { - id := uuid.New() - node.SignalQueue = append(node.SignalQueue, QueuedSignal{id, signal, time}) +func (node *Node) DequeueSignal(id uuid.UUID) error { + idx := -1 + for i, q := range(node.SignalQueue) { + if q.Signal.ID() == id { + idx = i + break + } + } + if idx == -1 { + return fmt.Errorf("%s is not in SignalQueue", id) + } + + node.SignalQueue[idx] = node.SignalQueue[len(node.SignalQueue)-1] + node.SignalQueue = node.SignalQueue[:len(node.SignalQueue)-1] node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue) - return id + + return nil } func (node *Node) ClearSignalQueue() { @@ -229,11 +272,34 @@ func nodeLoop(ctx *Context, node *Node) error { continue } - _, err = node.Allows(KeyID(msg.Principal), msg.Signal.Permission()) - if err != nil { - ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s - %s - %e", node.ID, msg.Signal, err) - // TODO: send the msgs and set the state so that getting a response triggers a potential processing of the original signal - continue + princ_id := KeyID(msg.Principal) + if princ_id != node.ID { + pends, resp := node.Allows(princ_id, msg.Signal.Permission()) + if resp == Deny { + ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %s", princ_id, node.ID, msg.Signal.Permission()) + msgs := Messages{} + msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(msg.Signal.ID(), "acl denied"), source) + ctx.Send(msgs) + continue + } else if resp == Pending { + ctx.Log.Logf("policy", "SIGNAL_POLICY_PENDING: %s->%s - %s - %+v", princ_id, node.ID, msg.Signal.Permission(), pends) + timeout_signal := NewACLTimeoutSignal(msg.Signal.ID()) + node.QueueSignal(time.Now().Add(100*time.Millisecond), timeout_signal) + msgs := Messages{} + for policy_type, sigs := range(pends) { + for _, m := range(sigs) { + msgs = append(msgs, m) + node.PendingSignals[m.Signal.ID()] = PendingSignal{policy_type, false, msg.Signal.ID()} + } + } + node.PendingACLs[msg.Signal.ID()] = PendingACL{len(msgs), timeout_signal.ID(), msg.Signal.Permission(), princ_id, msgs, []Signal{}, msg.Signal, msg.Source} + ctx.Send(msgs) + continue + } else if resp == Allow { + ctx.Log.Logf("policy", "SIGNAL_POLICY_ALLOW: %s->%s - %s", princ_id, node.ID, msg.Signal.Permission()) + } + } else { + ctx.Log.Logf("policy", "SIGNAL_POLICY_SELF: %s - %s", node.ID, msg.Signal.Permission()) } signal = msg.Signal @@ -246,7 +312,7 @@ func nodeLoop(ctx *Context, node *Node) error { t := node.NextSignal.Time i := -1 for j, queued := range(node.SignalQueue) { - if queued.UUID == node.NextSignal.UUID { + if queued.Signal.ID() == node.NextSignal.Signal.ID() { i = j break } @@ -260,29 +326,73 @@ func nodeLoop(ctx *Context, node *Node) error { node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue) if node.NextSignal == nil { - ctx.Log.Logf("node_timeout", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL nil@%+v", node.ID, signal.Type(), t, node.TimeoutChan) + ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL nil@%+v", node.ID, signal.Type(), t, node.TimeoutChan) } else { - ctx.Log.Logf("node_timeout", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL: %s@%s", node.ID, signal.Type(), t, node.NextSignal, node.NextSignal.Time) + ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL: %s@%s", node.ID, signal.Type(), t, node.NextSignal, node.NextSignal.Time) } } - ctx.Log.Logf("node_signal_queue", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) + ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) + + info, waiting := node.PendingSignals[signal.ReqID()] + if waiting == true { + if info.Found == false { + info.Found = true + node.PendingSignals[signal.ReqID()] = info + ctx.Log.Logf("pending", "FOUND_PENDING_SIGNAL: %s - %s", node.ID, signal) + req_info, exists := node.PendingACLs[info.ID] + if exists == true { + req_info.Counter -= 1 + req_info.Responses = append(req_info.Responses, signal) + + // TODO: call the right policy ParseResponse to check if the updated state passes the ACL check + allowed := node.Policies[info.Policy].ContinueAllows(req_info, signal) + if allowed == Allow { + ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal) + signal = req_info.Signal + source = req_info.Source + err := node.DequeueSignal(req_info.TimeoutID) + if err != nil { + panic("dequeued a passed signal") + } + delete(node.PendingACLs, info.ID) + } else if req_info.Counter == 0 { + ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal) + // Send the denied response + msgs := Messages{} + msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(req_info.Signal.ID(), "ACL_DENIED"), req_info.Source) + err := ctx.Send(msgs) + if err != nil { + ctx.Log.Logf("signal", "SEND_ERR: %s", err) + } + err = node.DequeueSignal(req_info.TimeoutID) + if err != nil { + panic("dequeued a passed signal") + } + delete(node.PendingACLs, info.ID) + } else { + node.PendingACLs[info.ID] = req_info + continue + } + } + } + } // Handle special signal types if signal.Type() == StopSignalType { msgs := Messages{} - msgs = msgs.Add(ctx.Log, node.ID, node.Key, NewErrorSignal(signal.ID(), "stopped"), source) + msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "stopped"), source) ctx.Send(msgs) node.Process(ctx, node.ID, NewStatusSignal("stopped", node.ID)) break } else if signal.Type() == ReadSignalType { read_signal, ok := signal.(*ReadSignal) if ok == false { - ctx.Log.Logf("signal_read", "READ_SIGNAL: bad cast %+v", signal) + ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal) } else { result := node.ReadFields(read_signal.Extensions) msgs := Messages{} - msgs = msgs.Add(ctx.Log, node.ID, node.Key, NewReadResultSignal(read_signal.ID(), node.Type, result), source) + msgs = msgs.Add(node.ID, node.Key, NewReadResultSignal(read_signal.ID(), node.ID, node.Type, result), source) ctx.Send(msgs) } } @@ -313,10 +423,10 @@ type Message struct { } type Messages []*Message -func (msgs Messages) Add(log Logger, source NodeID, principal ed25519.PrivateKey, signal Signal, dest NodeID) Messages { +func (msgs Messages) Add(source NodeID, principal ed25519.PrivateKey, signal Signal, dest NodeID) Messages { msg, err := NewMessage(dest, source, principal, signal) if err != nil { - log.Logf("signal", "MESSAGE_CREATE_ERR: %s", err) + panic(err) } else { msgs = append(msgs, msg) } @@ -445,7 +555,6 @@ func (node *Node) Serialize() ([]byte, error) { node_db.QueuedSignals[i] = QSignalDB{ QSignalDBHeader{ - qsignal.Signal.ID(), qsignal.Time, Hash(qsignal.Signal.Type()), uint64(len(ser)), @@ -528,6 +637,9 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si Type: node_type, Extensions: ext_map, Policies: policies, + //TODO serialize/deserialize these + PendingACLs: map[uuid.UUID]PendingACL{}, + PendingSignals: map[uuid.UUID]PendingSignal{}, MsgChan: make(chan *Message, buffer_size), BufferSize: buffer_size, SignalQueue: []QueuedSignal{}, @@ -556,7 +668,6 @@ type PolicyDB struct { } type QSignalDBHeader struct { - SignalID uuid.UUID Time time.Time TypeHash uint64 Length uint64 @@ -671,21 +782,14 @@ func NewNodeDB(data []byte) (NodeDB, error) { cur := data[ptr:] // TODO: load a header for each with the signal type and the signal length, so that it can be deserialized and incremented // Right now causes segfault because any saved signal is loaded as nil - signal_id_bytes := cur[0:16] - unix_milli := binary.BigEndian.Uint64(cur[16:24]) - type_hash := binary.BigEndian.Uint64(cur[24:32]) - signal_size := binary.BigEndian.Uint64(cur[32:40]) - - signal_id, err := uuid.FromBytes(signal_id_bytes) - if err != nil { - return zero, err - } + unix_milli := binary.BigEndian.Uint64(cur[0:8]) + type_hash := binary.BigEndian.Uint64(cur[8:16]) + signal_size := binary.BigEndian.Uint64(cur[16:24]) signal_data := cur[QSIGNAL_DB_HEADER_LEN:(QSIGNAL_DB_HEADER_LEN+signal_size)] queued_signals[i] = QSignalDB{ QSignalDBHeader{ - signal_id, time.UnixMilli(int64(unix_milli)), type_hash, signal_size, @@ -742,11 +846,9 @@ func (node NodeDB) Serialize() []byte { func (header QSignalDBHeader) Serialize() []byte { ret := make([]byte, QSIGNAL_DB_HEADER_LEN) - id_ser, _ := header.SignalID.MarshalBinary() - copy(ret, id_ser) - binary.BigEndian.PutUint64(ret[16:24], uint64(header.Time.UnixMilli())) - binary.BigEndian.PutUint64(ret[24:32], header.TypeHash) - binary.BigEndian.PutUint64(ret[32:40], header.Length) + binary.BigEndian.PutUint64(ret[0:8], uint64(header.Time.UnixMilli())) + binary.BigEndian.PutUint64(ret[8:16], header.TypeHash) + binary.BigEndian.PutUint64(ret[16:24], header.Length) return ret } @@ -862,7 +964,7 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return nil, err } - signal_queue[i] = QueuedSignal{qsignal.Header.SignalID, signal, qsignal.Header.Time} + signal_queue[i] = QueuedSignal{signal, qsignal.Header.Time} } next_signal, timeout_chan := SoonestSignal(signal_queue) @@ -941,129 +1043,6 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return node, nil } -func NewACLInfo(node *Node, resources []string) ACLMap { - return ACLMap{ - node.ID: ACLInfo{ - Node: node, - Resources: resources, - }, - } -} - -func NewACLMap(requests ...ACLMap) ACLMap { - reqs := ACLMap{} - for _, req := range(requests) { - for id, info := range(req) { - reqs[id] = info - } - } - return reqs -} - -func ACLListM(m map[NodeID]*Node, resources[]string) ACLMap { - reqs := ACLMap{} - for _, node := range(m) { - reqs[node.ID] = ACLInfo{ - Node: node, - Resources: resources, - } - } - return reqs -} - -func ACLList(list []*Node, resources []string) ACLMap { - reqs := ACLMap{} - for _, node := range(list) { - reqs[node.ID] = ACLInfo{ - Node: node, - Resources: resources, - } - } - return reqs -} - -type NodeMap map[NodeID]*Node - -type ACLInfo struct { - Node *Node - Resources []string -} - -type ACLMap map[NodeID]ACLInfo -type ExtMap map[uint64]Extension - -// Context of running state usage(read/write) -type StateContext struct { - // Type of the state context - Type string - // The wrapped graph context - Graph *Context - // Granted permissions in the context - Permissions map[NodeID]ACLMap - // Locked extensions in the context - Locked map[NodeID]*Node - - // Context state for validation - Started bool - Finished bool -} - -func ValidateStateContext(context *StateContext, Type string, Finished bool) error { - if context == nil { - return fmt.Errorf("context is nil") - } - if context.Finished != Finished { - return fmt.Errorf("context in wrong Finished state") - } - if context.Type != Type { - return fmt.Errorf("%s is not a %s context", context.Type, Type) - } - if context.Locked == nil || context.Graph == nil || context.Permissions == nil { - return fmt.Errorf("context is not initialized correctly") - } - return nil -} - -func NewReadContext(ctx *Context) *StateContext { - return &StateContext{ - Type: "read", - Graph: ctx, - Permissions: map[NodeID]ACLMap{}, - Locked: map[NodeID]*Node{}, - Started: false, - Finished: false, - } -} - -func NewWriteContext(ctx *Context) *StateContext { - return &StateContext{ - Type: "write", - Graph: ctx, - Permissions: map[NodeID]ACLMap{}, - Locked: map[NodeID]*Node{}, - Started: false, - Finished: false, - } -} - -type StateFn func(*StateContext)(error) - -func del[K comparable](list []K, val K) []K { - idx := -1 - for i, v := range(list) { - if v == val { - idx = i - break - } - } - if idx == -1 { - return nil - } - - list[idx] = list[len(list)-1] - return list[:len(list)-1] -} - func IDMap[S any, T map[NodeID]S](m T)map[string]S { ret := map[string]S{} for id, val := range(m) { diff --git a/node_test.go b/node_test.go index d19558b..0f7bfca 100644 --- a/node_test.go +++ b/node_test.go @@ -15,7 +15,7 @@ func TestNodeDB(t *testing.T) { node := NewNode(ctx, nil, node_type, 10, nil, NewGroupExt(nil)) - ctx.Nodes = NodeMap{} + ctx.Nodes = map[NodeID]*Node{} _, err = ctx.GetNode(node.ID) fatalErr(t, err) } @@ -46,7 +46,7 @@ func TestNodeRead(t *testing.T) { GroupExtType: []string{"members"}, }) msgs := Messages{} - msgs = msgs.Add(ctx.Log, n2.ID, n2.Key, read_sig, n1.ID) + msgs = msgs.Add(n2.ID, n2.Key, read_sig, n1.ID) err = ctx.Send(msgs) fatalErr(t, err) @@ -84,7 +84,7 @@ func TestECDH(t *testing.T) { fatalErr(t, err) ctx.Log.Logf("test", "N1_EC: %+v", n1_ec) msgs := Messages{} - msgs = msgs.Add(ctx.Log, n1.ID, n1.Key, ecdh_req, n2.ID) + msgs = msgs.Add(n1.ID, n1.Key, ecdh_req, n2.ID) err = ctx.Send(msgs) fatalErr(t, err) @@ -98,7 +98,7 @@ func TestECDH(t *testing.T) { fatalErr(t, err) msgs = Messages{} - msgs = msgs.Add(ctx.Log, n1.ID, n1.Key, ecdh_sig, n2.ID) + msgs = msgs.Add(n1.ID, n1.Key, ecdh_sig, n2.ID) err = ctx.Send(msgs) fatalErr(t, err) } diff --git a/policy.go b/policy.go index 045acd3..a5e92c1 100644 --- a/policy.go +++ b/policy.go @@ -2,127 +2,169 @@ package graphvent import ( "encoding/json" - "fmt" ) const ( - UserOfPolicyType = PolicyType("USER_OF") - RequirementOfPolicyType = PolicyType("REQUIREMENT_OF") + MemberOfPolicyType = PolicyType("USER_OF") PerNodePolicyType = PolicyType("PER_NODE") AllNodesPolicyType = PolicyType("ALL_NODES") ) type Policy interface { Serializable[PolicyType] - Allows(principal_id NodeID, action Action, node *Node)(Messages, error) + Allows(principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) + ContinueAllows(current PendingACL, signal Signal)RuleResult // Merge with another policy of the same underlying type Merge(Policy) Policy // Make a copy of this policy Copy() Policy } -func (policy AllNodesPolicy) Allows(principal_id NodeID, action Action, node *Node)(Messages, error) { - return nil, policy.Actions.Allows(action) +func (policy *AllNodesPolicy) Allows(principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { + return nil, policy.Rules.Allows(action) } -func (policy PerNodePolicy) Allows(principal_id NodeID, action Action, node *Node)(Messages, error) { - for id, actions := range(policy.NodeActions) { +func (policy *AllNodesPolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult { + return Deny +} + +func (policy *PerNodePolicy) Allows(principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { + for id, actions := range(policy.NodeRules) { if id != principal_id { continue } return nil, actions.Allows(action) } - return nil, fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID) + return nil, Deny } -func (policy *RequirementOfPolicy) Allows(principal_id NodeID, action Action, node *Node)(Messages, error) { - lockable_ext, err := GetExt[*LockableExt](node) - if err != nil { - return nil, err - } - - for id, _ := range(lockable_ext.Requirements) { - if id == principal_id { - return nil, policy.Actions.Allows(action) - } - } - - return nil, fmt.Errorf("%s is not a requirement of %s", principal_id, node.ID) +func (policy *PerNodePolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult { + return Deny } -type UserOfPolicy struct { +type MemberOfPolicy struct { PerNodePolicy } -func (policy *UserOfPolicy) Type() PolicyType { - return UserOfPolicyType +func (policy *MemberOfPolicy) Type() PolicyType { + return MemberOfPolicyType +} + +func NewMemberOfPolicy(group_rules NodeRules) MemberOfPolicy { + return MemberOfPolicy{ + PerNodePolicy: NewPerNodePolicy(group_rules), + } } -func NewUserOfPolicy(group_actions NodeActions) UserOfPolicy { - return UserOfPolicy{ - PerNodePolicy: NewPerNodePolicy(group_actions), +func (policy *MemberOfPolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult { + sig, ok := signal.(*ReadResultSignal) + if ok == false { + return Deny + } + + group_ext_data, ok := sig.Extensions[GroupExtType] + if ok == false { + return Deny + } + + members, ok := group_ext_data["members"].(map[NodeID]string) + if ok == false { + return Deny + } + + for member, _ := range(members) { + if member == current.Principal { + return policy.NodeRules[sig.NodeID].Allows(current.Action) + } } + + return Deny } // Send a read signal to Group to check if principal_id is a member of it -func (policy *UserOfPolicy) Allows(principal_id NodeID, action Action, node *Node) (Messages, error) { - // Send a read signal to each of the groups in the map - // Check for principal_id in any of the returned member lists(skipping errors) - // Return an error in the default case - return nil, fmt.Errorf("NOT_IMPLEMENTED") +func (policy *MemberOfPolicy) Allows(principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) { + msgs := Messages{} + for id, rule := range(policy.NodeRules) { + if id == node.ID { + ext, err := GetExt[*GroupExt](node) + if err == nil { + for member, _ := range(ext.Members) { + if member == principal_id { + if rule.Allows(action) == Allow { + return nil, Allow + } + } + } + } + } else { + msgs = msgs.Add(node.ID, node.Key, NewReadSignal(map[ExtType][]string{ + GroupExtType: []string{"members"}, + }), id) + } + } + return msgs, Pending } -func (policy *UserOfPolicy) Merge(p Policy) Policy { - other := p.(*UserOfPolicy) - policy.NodeActions = MergeNodeActions(policy.NodeActions, other.NodeActions) +func (policy *MemberOfPolicy) Merge(p Policy) Policy { + other := p.(*MemberOfPolicy) + policy.NodeRules = MergeNodeRules(policy.NodeRules, other.NodeRules) return policy } -func (policy *UserOfPolicy) Copy() Policy { - new_actions := CopyNodeActions(policy.NodeActions) - return &UserOfPolicy{ - PerNodePolicy: NewPerNodePolicy(new_actions), +func (policy *MemberOfPolicy) Copy() Policy { + new_rules := CopyNodeRules(policy.NodeRules) + return &MemberOfPolicy{ + PerNodePolicy: NewPerNodePolicy(new_rules), } } -type RequirementOfPolicy struct { - AllNodesPolicy -} -func (policy *RequirementOfPolicy) Type() PolicyType { - return RequirementOfPolicyType -} +func CopyTree(tree Tree) Tree { + if tree == nil { + return nil + } -func NewRequirementOfPolicy(actions Actions) RequirementOfPolicy { - return RequirementOfPolicy{ - AllNodesPolicy: NewAllNodesPolicy(actions), + ret := Tree{} + for name, sub := range(tree) { + ret[name] = CopyTree(sub) } + + return ret } -func MergeActions(first Actions, second Actions) Actions { - ret := second - for _, action := range(first) { - ret = append(ret, action) +func MergeTrees(first Tree, second Tree) Tree { + if first == nil || second == nil { + return nil + } + + ret := CopyTree(first) + for name, sub := range(second) { + current, exists := ret[name] + if exists == true { + ret[name] = MergeTrees(current, sub) + } else { + ret[name] = CopyTree(sub) + } } return ret } -func CopyNodeActions(actions NodeActions) NodeActions { - ret := NodeActions{} - for id, a := range(actions) { - ret[id] = a +func CopyNodeRules(rules NodeRules) NodeRules { + ret := NodeRules{} + for id, r := range(rules) { + ret[id] = r } return ret } -func MergeNodeActions(first NodeActions, second NodeActions) NodeActions { - merged := NodeActions{} +func MergeNodeRules(first NodeRules, second NodeRules) NodeRules { + merged := NodeRules{} for id, actions := range(first) { merged[id] = actions } for id, actions := range(second) { existing, exists := merged[id] if exists { - merged[id] = MergeActions(existing, actions) + merged[id] = MergeTrees(existing, actions) } else { merged[id] = actions } @@ -132,134 +174,100 @@ func MergeNodeActions(first NodeActions, second NodeActions) NodeActions { func (policy *PerNodePolicy) Merge(p Policy) Policy { other := p.(*PerNodePolicy) - policy.NodeActions = MergeNodeActions(policy.NodeActions, other.NodeActions) + policy.NodeRules = MergeNodeRules(policy.NodeRules, other.NodeRules) return policy } func (policy *PerNodePolicy) Copy() Policy { - new_actions := CopyNodeActions(policy.NodeActions) + new_rules := CopyNodeRules(policy.NodeRules) return &PerNodePolicy{ - NodeActions: new_actions, + NodeRules: new_rules, } } func (policy *AllNodesPolicy) Merge(p Policy) Policy { other := p.(*AllNodesPolicy) - policy.Actions = MergeActions(policy.Actions, other.Actions) + policy.Rules = MergeTrees(policy.Rules, other.Rules) return policy } func (policy *AllNodesPolicy) Copy() Policy { - new_actions := policy.Actions + new_rules := policy.Rules return &AllNodesPolicy { - Actions: new_actions, - } -} - -func (policy *RequirementOfPolicy) Merge(p Policy) Policy { - other := p.(*RequirementOfPolicy) - policy.Actions = MergeActions(policy.Actions, other.Actions) - return policy -} - -func (policy *RequirementOfPolicy) Copy() Policy { - new_actions := policy.Actions - return &RequirementOfPolicy{ - AllNodesPolicy { - Actions: new_actions, - }, + Rules: new_rules, } } -type Action []string - -func MakeAction(parts ...interface{}) Action { - action := make(Action, len(parts)) - for i, part := range(parts) { - stringer, ok := part.(fmt.Stringer) - if ok == false { - switch p := part.(type) { - case string: - action[i] = p - default: - panic("%s can not be part of an action") +type Tree map[string]Tree + +func (rule Tree) Allows(action Tree) RuleResult { + // If the current rule is nil, it's a wildcard and any action being processed is allowed + if rule == nil { + return Allow + // If the rule isn't "allow all" but the action is "request all", deny + } else if action == nil { + return Deny + // If the current action has no children, it's allowed + } else if len(action) == 0 { + return Allow + // If the current rule has no children but the action goes further, it's not allowed + } else if len(rule) == 0 { + return Deny + // If the current rule and action have children, all the children of action must be allowed by rule + } else { + for sub, subtree := range(action) { + subrule, exists := rule[sub] + if exists == false { + return Deny + } else if subrule.Allows(subtree) == Deny { + return Deny } - } else { - action[i] = stringer.String() } + return Allow } - return action } -func (action Action) Allows(test Action) bool { - action_len := len(action) - for i, part := range(test) { - if i >= action_len { - return false - } else if action[i] == part || action[i] == "*" { - continue - } else if action[i] == "+" { - break - } else { - return false - } - } - - return true -} +type NodeRules map[NodeID]Tree -type Actions []Action - -func (actions Actions) Allows(action Action) error { - for _, a := range(actions) { - if a.Allows(action) == true { - return nil - } - } - return fmt.Errorf("%s not in allows list", action) -} - -type NodeActions map[NodeID]Actions - -func (actions NodeActions) MarshalJSON() ([]byte, error) { - tmp := map[string]Actions{} - for id, a := range(actions) { - tmp[id.String()] = a +func (rules NodeRules) MarshalJSON() ([]byte, error) { + tmp := map[string]Tree{} + for id, r := range(rules) { + tmp[id.String()] = r } return json.Marshal(tmp) } -func (actions *NodeActions) UnmarshalJSON(data []byte) error { - tmp := map[string]Actions{} +func (rules *NodeRules) UnmarshalJSON(data []byte) error { + tmp := map[string]Tree{} err := json.Unmarshal(data, &tmp) if err != nil { return err } - for id_str, a := range(tmp) { + for id_str, r := range(tmp) { id, err := ParseID(id_str) if err != nil { return err } - ac := *actions - ac[id] = a + ru := *rules + ru[id] = r } return nil } -func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy { +func NewPerNodePolicy(node_actions NodeRules) PerNodePolicy { if node_actions == nil { - node_actions = NodeActions{} + node_actions = NodeRules{} } return PerNodePolicy{ - NodeActions: node_actions, + NodeRules: node_actions, } } type PerNodePolicy struct { - NodeActions NodeActions `json:"node_actions"` + NodeRules NodeRules `json:"node_actions"` } func (policy *PerNodePolicy) Type() PolicyType { @@ -274,18 +282,14 @@ func (policy *PerNodePolicy) Deserialize(ctx *Context, data []byte) error { return json.Unmarshal(data, policy) } -func NewAllNodesPolicy(actions Actions) AllNodesPolicy { - if actions == nil { - actions = Actions{} - } - +func NewAllNodesPolicy(rules Tree) AllNodesPolicy { return AllNodesPolicy{ - Actions: actions, + Rules: rules, } } type AllNodesPolicy struct { - Actions Actions + Rules Tree } func (policy *AllNodesPolicy) Type() PolicyType { @@ -300,10 +304,7 @@ func (policy *AllNodesPolicy) Deserialize(ctx *Context, data []byte) error { return json.Unmarshal(data, policy) } -var ErrorSignalAction = Action{"ERROR_RESP"} -var ReadResultSignalAction = Action{"READ_RESULT"} -var AuthorizedSignalAction = Action{"AUTHORIZED_READ"} -var defaultPolicy = NewAllNodesPolicy(Actions{ErrorSignalAction, ReadResultSignalAction, AuthorizedSignalAction}) -var DefaultACLPolicies = []Policy{ - &defaultPolicy, -} +var DefaultPolicy = NewAllNodesPolicy(Tree{ + ErrorSignalType.String(): nil, + ReadResultSignalType.String(): nil, +}) diff --git a/signal.go b/signal.go index 71d28ec..1d48fbb 100644 --- a/signal.go +++ b/signal.go @@ -16,19 +16,20 @@ import ( type SignalDirection int const ( - StopSignalType SignalType = "STOP" - NewSignalType = "NEW" - StartSignalType = "START" - ErrorSignalType = "ERROR" - StatusSignalType = "STATUS" - LinkSignalType = "LINK" - LockSignalType = "LOCK" - ReadSignalType = "READ" - ReadResultSignalType = "READ_RESULT" - LinkStartSignalType = "LINK_START" - ECDHSignalType = "ECDH" - ECDHProxySignalType = "ECDH_PROXY" - GQLStateSignalType = "GQL_STATE" + StopSignalType = SignalType("STOP") + NewSignalType = SignalType("NEW") + StartSignalType = SignalType("START") + ErrorSignalType = SignalType("ERROR") + StatusSignalType = SignalType("STATUS") + LinkSignalType = SignalType("LINK") + LockSignalType = SignalType("LOCK") + ReadSignalType = SignalType("READ") + ReadResultSignalType = SignalType("READ_RESULT") + LinkStartSignalType = SignalType("LINK_START") + ECDHSignalType = SignalType("ECDH") + ECDHProxySignalType = SignalType("ECDH_PROXY") + GQLStateSignalType = SignalType("GQL_STATE") + ACLTimeoutSignalType = SignalType("ACL_TIMEOUT") Up SignalDirection = iota Down @@ -44,7 +45,8 @@ type Signal interface { String() string Direction() SignalDirection ID() uuid.UUID - Permission() Action + ReqID() uuid.UUID + Permission() Tree } func WaitForSignal[S Signal](ctx * Context, listener chan Signal, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) { @@ -78,6 +80,11 @@ type BaseSignal struct { SignalDirection SignalDirection `json:"direction"` SignalType SignalType `json:"type"` UUID uuid.UUID `json:"id"` + ReqUUID uuid.UUID `json:"req_uuid"` +} + +func (signal *BaseSignal) ReqID() uuid.UUID { + return signal.ReqUUID } func (signal *BaseSignal) String() string { @@ -97,8 +104,10 @@ func (signal *BaseSignal) Type() SignalType { return signal.SignalType } -func (signal *BaseSignal) Permission() Action { - return MakeAction(signal.Type()) +func (signal *BaseSignal) Permission() Tree { + return Tree{ + string(signal.Type()): Tree{}, + } } func (signal *BaseSignal) Direction() SignalDirection { @@ -110,8 +119,20 @@ func (signal *BaseSignal) Serialize() ([]byte, error) { } func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal { + id := uuid.New() + signal := BaseSignal{ + UUID: id, + ReqUUID: id, + SignalDirection: direction, + SignalType: signal_type, + } + return signal +} + +func NewRespSignal(id uuid.UUID, signal_type SignalType, direction SignalDirection) BaseSignal { signal := BaseSignal{ UUID: uuid.New(), + ReqUUID: id, SignalDirection: direction, SignalType: signal_type, } @@ -145,13 +166,8 @@ func (signal *StringSignal) Serialize() ([]byte, error) { return json.Marshal(&signal) } -type RespSignal struct { - BaseSignal - ReqID uuid.UUID -} - type ErrorSignal struct { - RespSignal + BaseSignal Error string } @@ -160,20 +176,18 @@ func (signal *ErrorSignal) String() string { return string(ser) } -func (signal *ErrorSignal) Permission() Action { - return ErrorSignalAction -} - func NewErrorSignal(req_id uuid.UUID, err string) Signal { return &ErrorSignal{ - RespSignal{ - NewBaseSignal(ErrorSignalType, Direct), - req_id, - }, + NewRespSignal(req_id, ErrorSignalType, Direct), err, } } +func NewACLTimeoutSignal(req_id uuid.UUID) Signal { + sig := NewRespSignal(req_id, ACLTimeoutSignalType, Direct) + return &sig +} + type IDStringSignal struct { BaseSignal NodeID NodeID `json:"node_id"` @@ -219,8 +233,12 @@ func NewLockSignal(state string) Signal { } } -func (signal *StringSignal) Permission() Action { - return MakeAction(signal.Type(), signal.Str) +func (signal *StringSignal) Permission() Tree { + return Tree{ + string(signal.Type()): Tree{ + signal.Str: Tree{}, + }, + } } type ReadSignal struct { @@ -239,22 +257,35 @@ func NewReadSignal(exts map[ExtType][]string) *ReadSignal { } } +func (signal *ReadSignal) Permission() Tree { + ret := Tree{} + for ext, fields := range(signal.Extensions) { + field_tree := Tree{} + for _, field := range(fields) { + field_tree[field] = Tree{} + } + ret[ext.String()] = field_tree + } + return Tree{ReadSignalType.String(): ret} +} + type ReadResultSignal struct { - RespSignal - NodeType + BaseSignal + NodeID NodeID + NodeType NodeType Extensions map[ExtType]map[string]interface{} `json:"extensions"` } -func (signal *ReadResultSignal) Permission() Action { - return ReadResultSignalAction +func (signal *ReadResultSignal) Permission() Tree { + return Tree{ + ReadResultSignalType.String(): Tree{}, + } } -func NewReadResultSignal(req_id uuid.UUID, node_type NodeType, exts map[ExtType]map[string]interface{}) Signal { +func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]interface{}) Signal { return &ReadResultSignal{ - RespSignal{ - NewBaseSignal(ReadResultSignalType, Direct), - req_id, - }, + NewRespSignal(req_id, ReadResultSignalType, Direct), + node_id, node_type, exts, }