Added basic test that loads node with only acl extension

graph-rework-2
noah metz 2023-07-26 00:42:12 -06:00
parent f1c0f1e7de
commit c4156ee146
4 changed files with 68 additions and 17 deletions

@ -17,6 +17,7 @@ type ExtensionInfo struct {
// Information about a loaded node type // Information about a loaded node type
type NodeInfo struct { type NodeInfo struct {
Type NodeType Type NodeType
Extensions []ExtType
} }
// A Context is all the data needed to run a graphvent // A Context is all the data needed to run a graphvent
@ -39,15 +40,31 @@ func (ctx *Context) ExtByType(ext_type ExtType) ExtensionInfo {
return ext return ext
} }
func (ctx *Context) RegisterNodeType(node_type NodeType) error { func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) error {
type_hash := node_type.Hash() type_hash := node_type.Hash()
_, exists := ctx.Types[type_hash] _, exists := ctx.Types[type_hash]
if exists == true { if exists == true {
return fmt.Errorf("Cannot register node type %s, type already exists in context", node_type) return fmt.Errorf("Cannot register node type %s, type already exists in context", node_type)
} }
ext_found := map[ExtType]bool{}
for _, extension := range(extensions) {
_, in_ctx := ctx.Extensions[extension.Hash()]
if in_ctx == false {
return fmt.Errorf("Cannot register node type %s, required extension %s not in context", node_type, extension)
}
_, duplicate := ext_found[extension]
if duplicate == true {
return fmt.Errorf("Duplicate extension %s found in extension list", extension)
}
ext_found[extension] = true
}
ctx.Types[type_hash] = NodeInfo{ ctx.Types[type_hash] = NodeInfo{
Type: node_type, Type: node_type,
Extensions: extensions,
} }
return nil return nil
} }

@ -75,13 +75,13 @@ type Node struct {
ID NodeID ID NodeID
Type NodeType Type NodeType
Lock sync.RWMutex Lock sync.RWMutex
ExtensionMap map[ExtType]Extension Extensions map[ExtType]Extension
} }
func GetExt[T Extension](node *Node) (T, error) { func GetExt[T Extension](node *Node) (T, error) {
var zero T var zero T
ext_type := zero.Type() ext_type := zero.Type()
ext, exists := node.ExtensionMap[ext_type] ext, exists := node.Extensions[ext_type]
if exists == false { if exists == false {
return zero, fmt.Errorf("%s does not have %s extension", node.ID, ext_type) return zero, fmt.Errorf("%s does not have %s extension", node.ID, ext_type)
} }
@ -95,17 +95,18 @@ func GetExt[T Extension](node *Node) (T, error) {
} }
func (node *Node) Serialize() ([]byte, error) { func (node *Node) Serialize() ([]byte, error) {
extensions := make([]ExtensionDB, len(node.ExtensionMap)) extensions := make([]ExtensionDB, len(node.Extensions))
node_db := NodeDB{ node_db := NodeDB{
Header: NodeDBHeader{ Header: NodeDBHeader{
Magic: NODE_DB_MAGIC, Magic: NODE_DB_MAGIC,
TypeHash: node.Type.Hash(),
NumExtensions: uint32(len(extensions)), NumExtensions: uint32(len(extensions)),
}, },
Extensions: extensions, Extensions: extensions,
} }
i := 0 i := 0
for ext_type, info := range(node.ExtensionMap) { for ext_type, info := range(node.Extensions) {
ser, err := info.Serialize() ser, err := info.Serialize()
if err != nil { if err != nil {
return nil, err return nil, err
@ -127,7 +128,7 @@ func NewNode(id NodeID, node_type NodeType) Node {
return Node{ return Node{
ID: id, ID: id,
Type: node_type, Type: node_type,
ExtensionMap: map[ExtType]Extension{}, Extensions: map[ExtType]Extension{},
} }
} }
@ -137,14 +138,18 @@ func Allowed(context *StateContext, principal *Node, action string, node *Node)
return fmt.Errorf("nil is not allowed to perform any actions") return fmt.Errorf("nil is not allowed to perform any actions")
} }
ext, exists := node.ExtensionMap[ACLExtType] // Nodes are allowed to perform all actions on themselves regardless of whether or not they have an ACL extension
if exists == false { if principal.ID == node.ID {
return fmt.Errorf("%s does not have ACL extension, other nodes cannot perform actions on it", node.ID) return nil
}
acl_ext, err := GetExt[*ACLExt](node)
if err != nil {
return err
} }
acl_ext := ext.(ACLExt)
for _, policy_node := range(acl_ext.Delegations) { for _, policy_node := range(acl_ext.Delegations) {
ext, exists := policy_node.ExtensionMap[ACLPolicyExtType] ext, exists := policy_node.Extensions[ACLPolicyExtType]
if exists == false { if exists == false {
context.Graph.Log.Logf("policy", "WARNING: %s has dependency %s which doesn't have ACLPolicyExt") context.Graph.Log.Logf("policy", "WARNING: %s has dependency %s which doesn't have ACLPolicyExt")
continue continue
@ -168,7 +173,7 @@ func Signal(context *StateContext, node *Node, princ *Node, signal GraphSignal)
return Allowed(context, princ, fmt.Sprintf("signal.%s", signal.Type()), node) return Allowed(context, princ, fmt.Sprintf("signal.%s", signal.Type()), node)
}) })
for _, ext := range(node.ExtensionMap) { for _, ext := range(node.Extensions) {
err = ext.Process(context, node, signal) err = ext.Process(context, node, signal)
if err != nil { if err != nil {
return nil return nil
@ -379,7 +384,7 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
node.ExtensionMap[def.Type] = extension node.Extensions[def.Type] = extension
ctx.Log.Logf("db", "DB_EXTENSION_LOADED: %s - 0x%x", id, type_hash) ctx.Log.Logf("db", "DB_EXTENSION_LOADED: %s - 0x%x", id, type_hash)
} }

@ -0,0 +1,29 @@
package graphvent
import (
"testing"
)
func TestNodeDB(t *testing.T) {
ctx := logTestContext(t, []string{"test", "db", "node", "policy"})
node_type := NodeType("test")
err := ctx.RegisterNodeType(node_type, []ExtType{})
fatalErr(t, err)
node := NewNode(RandID(), node_type)
node.Extensions[ACLExtType] = &ACLExt{
Delegations: NodeMap{},
}
ctx.Nodes[node.ID] = &node
context := NewWriteContext(ctx)
err = UpdateStates(context, &node, NewACLInfo(&node, []string{"test"}), func(context *StateContext) error {
ser, err := node.Serialize()
ctx.Log.Logf("test", "NODE_SER: %s", ser)
return err
})
fatalErr(t, err)
delete(ctx.Nodes, node.ID)
_, err = LoadNode(ctx, node.ID)
fatalErr(t, err)
}

@ -41,7 +41,7 @@ type ACLExt struct {
Delegations NodeMap Delegations NodeMap
} }
func (ext ACLExt) Process(context *StateContext, node *Node, signal GraphSignal) error { func (ext *ACLExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
return nil return nil
} }
@ -60,12 +60,12 @@ func LoadACLExt(ctx *Context, data []byte) (Extension, error) {
return nil, err return nil, err
} }
return ACLExt{ return &ACLExt{
Delegations: delegations, Delegations: delegations,
}, nil }, nil
} }
func (ext ACLExt) Serialize() ([]byte, error) { func (ext *ACLExt) Serialize() ([]byte, error) {
delegations := make([]string, len(ext.Delegations)) delegations := make([]string, len(ext.Delegations))
i := 0 i := 0
for id, _ := range(ext.Delegations) { for id, _ := range(ext.Delegations) {
@ -81,7 +81,7 @@ func (ext ACLExt) Serialize() ([]byte, error) {
} }
const ACLExtType = ExtType("ACL") const ACLExtType = ExtType("ACL")
func (extension ACLExt) Type() ExtType { func (ext *ACLExt) Type() ExtType {
return ACLExtType return ACLExtType
} }