diff --git a/gql.go b/gql.go index 7d00d26..a2ee777 100644 --- a/gql.go +++ b/gql.go @@ -33,21 +33,6 @@ import ( const GQLThreadType = ThreadType("GQL") const GQLNodeType = NodeType("GQL") -// Initializes a new GQL node without an ACLPolicyExt(which needs to be added) -func NewGQLNode(ctx *Context, gql_ext *GQLExt) (*Node, error) { - node := NewNode(ctx, RandID(), GQLNodeType) - node.Extensions[GroupExtType] = NewGroupExt(nil) - var err error - node.Extensions[ThreadExtType], err = NewThreadExt(ctx, GQLThreadType, nil, nil, "init", nil) - if err != nil { - return nil, err - } - node.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil) - node.Extensions[GQLExtType] = gql_ext - - return node, nil -} - type AuthReqJSON struct { Time time.Time `json:"time"` Pubkey []byte `json:"pubkey"` diff --git a/gql_test.go b/gql_test.go index d9b862b..831c8c6 100644 --- a/gql_test.go +++ b/gql_test.go @@ -16,19 +16,12 @@ func TestGQL(t *testing.T) { } func TestGQLDB(t * testing.T) { - ctx := logTestContext(t, []string{"test", "signal", "policy"}) + ctx := logTestContext(t, []string{"test", "signal", "policy", "db"}) TestUserNodeType := NodeType("TEST_USER") - err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{ACLExtType}) + err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) fatalErr(t, err) - u1 := NewNode(ctx, RandID(), TestUserNodeType) - u1_policy := NewPerNodePolicy(NodeActions{ - u1.ID: Actions{"users.write", "children.write", "parent.write", "dependencies.write", "requirements.write"}, - }) - u1.Extensions[ACLExtType] = NewACLExt(map[PolicyType]Policy{ - PerNodePolicyType: &u1_policy, - }) ctx.Log.Logf("test", "U1_ID: %s", u1.ID) @@ -39,35 +32,38 @@ func TestGQLDB(t * testing.T) { err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLExtType, ThreadExtType, LockableExtType}) fatalErr(t, err) - t1 := NewNode(ctx, RandID(), TestThreadNodeType) t1_policy_1 := NewParentOfPolicy(Actions{"signal.abort", "state.write"}) t1_policy_2 := NewPerNodePolicy(NodeActions{ u1.ID: Actions{"parent.write"}, }) - t1.Extensions[ACLExtType] = NewACLExt(map[PolicyType]Policy{ - ParentOfPolicyType: &t1_policy_1, - PerNodePolicyType: &t1_policy_2, - }) - t1.Extensions[ThreadExtType], err = NewThreadExt(ctx, BaseThreadType, nil, nil, "init", nil) + t1_thread, err := NewThreadExt(ctx, BaseThreadType, nil,nil, "init", nil) fatalErr(t, err) - t1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil) - + t1 := NewNode(ctx, + RandID(), + TestThreadNodeType, + NewACLExt(&t1_policy_1, &t1_policy_2), + t1_thread, + NewLockableExt(nil, nil, nil, nil)) ctx.Log.Logf("test", "T1_ID: %s", t1.ID) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) - gql, err := NewGQLNode(ctx, NewGQLExt(":0", ecdh.P256(), key, nil, nil)) - fatalErr(t, err) - gql_policy_1 := NewChildOfPolicy(Actions{"signal.status"}) - gql_policy_2 := NewPerNodePolicy(NodeActions{ + gql_p1 := NewChildOfPolicy(Actions{"signal.status"}) + gql_p2 := NewPerNodePolicy(NodeActions{ u1.ID: Actions{"parent.write", "children.write", "dependencies.write"}, }) - gql.Extensions[ACLExtType] = NewACLExt(map[PolicyType]Policy{ - ChildOfPolicyType: &gql_policy_1, - PerNodePolicyType: &gql_policy_2, - }) + gql_thread, err := NewThreadExt(ctx, GQLThreadType, nil, nil, "init", nil) + fatalErr(t, err) + + gql_ext := NewGQLExt(":0", ecdh.P256(), key, nil, nil) + gql := NewNode(ctx, RandID(), GQLNodeType, + gql_thread, + gql_ext, + NewACLExt(&gql_p1, &gql_p2), + NewGroupExt(nil), + NewLockableExt(nil, nil, nil, nil)) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) info := ParentInfo{true, "start", "restore"} diff --git a/graph_test.go b/graph_test.go index fec1f17..c0ce083 100644 --- a/graph_test.go +++ b/graph_test.go @@ -55,14 +55,14 @@ func (t * GraphTester) CheckForNone(listener chan Signal, str string) { const SimpleListenerNodeType = NodeType("SIMPLE_LISTENER") func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) { - listener := NewNode(ctx, RandID(), SimpleListenerNodeType) policy := NewAllNodesPolicy([]string{"signal.status", "requirements.write", "requirements.read", "dependencies.write", "dependencies.read", "owner.read", "owner.write"}) listener_extension := NewListenerExt(buffer) - listener.Extensions[ListenerExtType] = listener_extension - listener.Extensions[ACLExtType] = NewACLExt(map[PolicyType]Policy{ - AllNodesPolicyType: &policy, - }) - listener.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil) + listener := NewNode(ctx, + RandID(), + SimpleListenerNodeType, + listener_extension, + NewACLExt(&policy), + NewLockableExt(nil, nil, nil, nil)) return listener, listener_extension } diff --git a/node.go b/node.go index b8f9927..97503b8 100644 --- a/node.go +++ b/node.go @@ -141,19 +141,40 @@ func (node *Node) Serialize() ([]byte, error) { return node_db.Serialize(), nil } -func NewNode(ctx *Context, id NodeID, node_type NodeType) *Node { +func NewNode(ctx *Context, id NodeID, node_type NodeType, extensions ...Extension) *Node { _, exists := ctx.Nodes[id] if exists == true { panic("Attempted to create an existing node") } + def, exists := ctx.Types[node_type.Hash()] + if exists == false { + panic("Node type %s not registered in Context") + } + + ext_map := map[ExtType]Extension{} + for _, ext := range(extensions) { + _, exists := ext_map[ext.Type()] + if exists == true { + panic("Cannot add the same extension to a node twice") + } + ext_map[ext.Type()] = ext + } + + for _, required_ext := range(def.Extensions) { + _, exists := ext_map[required_ext] + if exists == false { + panic(fmt.Sprintf("%s requires %s", node_type, required_ext)) + } + } + node := &Node{ ID: id, Type: node_type, - Extensions: map[ExtType]Extension{}, + Extensions: ext_map, } - ctx.Nodes[id] = node + return node } @@ -379,8 +400,12 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return nil, fmt.Errorf("Tried to load node %s of type 0x%x, which is not a known node type", id, node_db.Header.TypeHash) } - // Create the blank node with the ID, and add it to the context - node = NewNode(ctx, id, node_type.Type) + node = &Node{ + ID: id, + Type: node_type.Type, + Extensions: map[ExtType]Extension{}, + } + ctx.Nodes[id] = node found_extensions := []ExtType{} // Parse each of the extensions from the db diff --git a/node_test.go b/node_test.go index 588e047..6baab96 100644 --- a/node_test.go +++ b/node_test.go @@ -10,8 +10,7 @@ func TestNodeDB(t *testing.T) { err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) fatalErr(t, err) - node := NewNode(ctx, RandID(), node_type) - node.Extensions[GroupExtType] = NewGroupExt(nil) + node := NewNode(ctx, RandID(), node_type, NewGroupExt(nil)) context := NewWriteContext(ctx) err = UpdateStates(context, node, NewACLInfo(node, []string{"test"}), func(context *StateContext) error { diff --git a/policy.go b/policy.go index 6db7de1..2b13c8a 100644 --- a/policy.go +++ b/policy.go @@ -311,19 +311,19 @@ func (ext *ACLExt) Process(context *StateContext, node *Node, signal Signal) err return nil } -func NewACLExt(policies map[PolicyType]Policy) *ACLExt { - if policies == nil { - policies = map[PolicyType]Policy{} - } - - for policy_type, policy := range(policies) { - if policy_type != policy.Type() { - panic("POLICY_TYPE_MISMATCH") +func NewACLExt(policies ...Policy) *ACLExt { + policy_map := map[PolicyType]Policy{} + for _, policy := range(policies) { + _, exists := policy_map[policy.Type()] + if exists == true { + panic("Cannot add same policy type twice") } + + policy_map[policy.Type()] = policy } return &ACLExt{ - Policies: policies, + Policies: policy_map, } } @@ -336,7 +336,8 @@ func LoadACLExt(ctx *Context, data []byte) (Extension, error) { return nil, err } - policies := map[PolicyType]Policy{} + policies := make([]Policy, len(j.Policies)) + i := 0 acl_ctx := ctx.ExtByType(ACLExtType).Data.(*ACLExtContext) for name, ser := range(j.Policies) { policy_def, exists := acl_ctx.Types[PolicyType(name)] @@ -348,13 +349,14 @@ func LoadACLExt(ctx *Context, data []byte) (Extension, error) { return nil, err } - policies[PolicyType(name)] = policy + policies[i] = policy + i++ } - return NewACLExt(policies), nil + return NewACLExt(policies...), nil } -const ACLExtType = ExtType("ACL_POLICIES") +const ACLExtType = ExtType("ACL") func (ext *ACLExt) Type() ExtType { return ACLExtType }