diff --git a/context.go b/context.go index 21ad0c2..82aecc2 100644 --- a/context.go +++ b/context.go @@ -17,6 +17,7 @@ type ExtensionInfo struct { // Information about a loaded node type type NodeInfo struct { Type NodeType + Extensions []ExtType } // A Context is all the data needed to run a graphvent @@ -39,15 +40,31 @@ func (ctx *Context) ExtByType(ext_type ExtType) ExtensionInfo { return ext } -func (ctx *Context) RegisterNodeType(node_type NodeType) error { +func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) error { type_hash := node_type.Hash() _, exists := ctx.Types[type_hash] if exists == true { return fmt.Errorf("Cannot register node type %s, type already exists in context", node_type) } + ext_found := map[ExtType]bool{} + for _, extension := range(extensions) { + _, in_ctx := ctx.Extensions[extension.Hash()] + if in_ctx == false { + return fmt.Errorf("Cannot register node type %s, required extension %s not in context", node_type, extension) + } + + _, duplicate := ext_found[extension] + if duplicate == true { + return fmt.Errorf("Duplicate extension %s found in extension list", extension) + } + + ext_found[extension] = true + } + ctx.Types[type_hash] = NodeInfo{ Type: node_type, + Extensions: extensions, } return nil } diff --git a/node.go b/node.go index 21a58bc..c1e1a58 100644 --- a/node.go +++ b/node.go @@ -75,13 +75,13 @@ type Node struct { ID NodeID Type NodeType Lock sync.RWMutex - ExtensionMap map[ExtType]Extension + Extensions map[ExtType]Extension } func GetExt[T Extension](node *Node) (T, error) { var zero T ext_type := zero.Type() - ext, exists := node.ExtensionMap[ext_type] + ext, exists := node.Extensions[ext_type] if exists == false { return zero, fmt.Errorf("%s does not have %s extension", node.ID, ext_type) } @@ -95,17 +95,18 @@ func GetExt[T Extension](node *Node) (T, error) { } func (node *Node) Serialize() ([]byte, error) { - extensions := make([]ExtensionDB, len(node.ExtensionMap)) + extensions := make([]ExtensionDB, len(node.Extensions)) node_db := NodeDB{ Header: NodeDBHeader{ Magic: NODE_DB_MAGIC, + TypeHash: node.Type.Hash(), NumExtensions: uint32(len(extensions)), }, Extensions: extensions, } i := 0 - for ext_type, info := range(node.ExtensionMap) { + for ext_type, info := range(node.Extensions) { ser, err := info.Serialize() if err != nil { return nil, err @@ -127,7 +128,7 @@ func NewNode(id NodeID, node_type NodeType) Node { return Node{ ID: id, Type: node_type, - ExtensionMap: map[ExtType]Extension{}, + Extensions: map[ExtType]Extension{}, } } @@ -137,14 +138,18 @@ func Allowed(context *StateContext, principal *Node, action string, node *Node) return fmt.Errorf("nil is not allowed to perform any actions") } - ext, exists := node.ExtensionMap[ACLExtType] - if exists == false { - return fmt.Errorf("%s does not have ACL extension, other nodes cannot perform actions on it", node.ID) + // Nodes are allowed to perform all actions on themselves regardless of whether or not they have an ACL extension + if principal.ID == node.ID { + return nil + } + + acl_ext, err := GetExt[*ACLExt](node) + if err != nil { + return err } - acl_ext := ext.(ACLExt) for _, policy_node := range(acl_ext.Delegations) { - ext, exists := policy_node.ExtensionMap[ACLPolicyExtType] + ext, exists := policy_node.Extensions[ACLPolicyExtType] if exists == false { context.Graph.Log.Logf("policy", "WARNING: %s has dependency %s which doesn't have ACLPolicyExt") continue @@ -168,7 +173,7 @@ func Signal(context *StateContext, node *Node, princ *Node, signal GraphSignal) return Allowed(context, princ, fmt.Sprintf("signal.%s", signal.Type()), node) }) - for _, ext := range(node.ExtensionMap) { + for _, ext := range(node.Extensions) { err = ext.Process(context, node, signal) if err != nil { return nil @@ -379,7 +384,7 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { if err != nil { return nil, err } - node.ExtensionMap[def.Type] = extension + node.Extensions[def.Type] = extension ctx.Log.Logf("db", "DB_EXTENSION_LOADED: %s - 0x%x", id, type_hash) } diff --git a/node_test.go b/node_test.go new file mode 100644 index 0000000..71e7e8e --- /dev/null +++ b/node_test.go @@ -0,0 +1,29 @@ +package graphvent + +import ( + "testing" +) + +func TestNodeDB(t *testing.T) { + ctx := logTestContext(t, []string{"test", "db", "node", "policy"}) + node_type := NodeType("test") + err := ctx.RegisterNodeType(node_type, []ExtType{}) + fatalErr(t, err) + node := NewNode(RandID(), node_type) + node.Extensions[ACLExtType] = &ACLExt{ + Delegations: NodeMap{}, + } + ctx.Nodes[node.ID] = &node + + context := NewWriteContext(ctx) + err = UpdateStates(context, &node, NewACLInfo(&node, []string{"test"}), func(context *StateContext) error { + ser, err := node.Serialize() + ctx.Log.Logf("test", "NODE_SER: %s", ser) + return err + }) + fatalErr(t, err) + + delete(ctx.Nodes, node.ID) + _, err = LoadNode(ctx, node.ID) + fatalErr(t, err) +} diff --git a/policy.go b/policy.go index 54aa110..9de62c0 100644 --- a/policy.go +++ b/policy.go @@ -41,7 +41,7 @@ type ACLExt struct { Delegations NodeMap } -func (ext ACLExt) Process(context *StateContext, node *Node, signal GraphSignal) error { +func (ext *ACLExt) Process(context *StateContext, node *Node, signal GraphSignal) error { return nil } @@ -60,12 +60,12 @@ func LoadACLExt(ctx *Context, data []byte) (Extension, error) { return nil, err } - return ACLExt{ + return &ACLExt{ Delegations: delegations, }, nil } -func (ext ACLExt) Serialize() ([]byte, error) { +func (ext *ACLExt) Serialize() ([]byte, error) { delegations := make([]string, len(ext.Delegations)) i := 0 for id, _ := range(ext.Delegations) { @@ -81,7 +81,7 @@ func (ext ACLExt) Serialize() ([]byte, error) { } const ACLExtType = ExtType("ACL") -func (extension ACLExt) Type() ExtType { +func (ext *ACLExt) Type() ExtType { return ACLExtType }