From ff813d6c2b20810ce85c5dccefd76c60b9d9ad6e Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Tue, 25 Jul 2023 21:43:15 -0600 Subject: [PATCH] Moved from inheritance to extensions --- context.go | 232 +++----------------- graph_test.go | 10 +- lockable.go | 574 +++++++++++++++++++++++------------------------- node.go | 597 +++++++++++++++++++++++++++----------------------- policy.go | 371 ++++++------------------------- thread.go | 596 ++++++++++++++++++++++--------------------------- 6 files changed, 966 insertions(+), 1414 deletions(-) diff --git a/context.go b/context.go index 071b31f..b4d6d8a 100644 --- a/context.go +++ b/context.go @@ -1,31 +1,15 @@ package graphvent import ( - "github.com/graphql-go/graphql" badger "github.com/dgraph-io/badger/v3" - "reflect" "fmt" ) -// NodeLoadFunc is the footprint of the function used to create a new node in memory from persisted bytes -type NodeLoadFunc func(*Context, NodeID, []byte, NodeMap)(Node, error) - -// A NodeDef is a description of a node that can be added to a Context -type NodeDef struct { - Load NodeLoadFunc - Type NodeType - GQLType *graphql.Object - Reflect reflect.Type -} - -// Create a new Node def, extracting the Type and Reflect from example -func NewNodeDef(example Node, load_func NodeLoadFunc, gql_type *graphql.Object) NodeDef { - return NodeDef{ - Type: example.Type(), - Load: load_func, - GQLType: gql_type, - Reflect: reflect.TypeOf(example), - } +type ExtensionLoadFunc func(*Context, []byte) (Extension, error) +type ExtensionInfo struct { + Load ExtensionLoadFunc + Type ExtType + Data interface{} } // A Context is all the data needed to run a graphvent @@ -34,211 +18,55 @@ type Context struct { DB * badger.DB // Log is an interface used to record events happening Log Logger - // A mapping between type hashes and their corresponding node definitions - Types map[uint64]NodeDef - // GQL substructure - GQL GQLContext -} - -// Recreate the GQL schema after making changes -func (ctx * Context) RebuildSchema() error { - schemaConfig := graphql.SchemaConfig{ - Types: ctx.GQL.TypeList, - Query: ctx.GQL.Query, - Mutation: ctx.GQL.Mutation, - Subscription: ctx.GQL.Subscription, - } - - schema, err := graphql.NewSchema(schemaConfig) - if err != nil { - return err - } - - ctx.GQL.Schema = schema - return nil + // A mapping between type hashes and their corresponding extension definitions + Extensions map[uint64]ExtensionInfo + // All loaded Nodes + Nodes map[NodeID]*Node } -// Add a non-node type to the gql context -func (ctx * Context) AddGQLType(gql_type graphql.Type) { - ctx.GQL.TypeList = append(ctx.GQL.TypeList, gql_type) +func (ctx *Context) ExtByType(ext_type ExtType) ExtensionInfo { + type_hash := ext_type.Hash() + ext, _ := ctx.Extensions[type_hash] + return ext } // Add a node to a context, returns an error if the def is invalid or already exists in the context -func (ctx * Context) RegisterNodeType(def NodeDef) error { - if def.Load == nil { - return fmt.Errorf("Cannot register a node without a load function: %s", def.Type) - } - - if def.Reflect == nil { - return fmt.Errorf("Cannot register a node without a reflect type: %s", def.Type) - } - - if def.GQLType == nil { - return fmt.Errorf("Cannot register a node without a gql type: %s", def.Type) +func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFunc) error { + if load_fn == nil { + return fmt.Errorf("def has no load function") } - type_hash := def.Type.Hash() - _, exists := ctx.Types[type_hash] + type_hash := ext_type.Hash() + _, exists := ctx.Extensions[type_hash] if exists == true { - return fmt.Errorf("Cannot register node of type %s, type already exists in context", def.Type) + return fmt.Errorf("Cannot register extension of type %s, type already exists in context", ext_type) } - ctx.Types[type_hash] = def - - node_type := reflect.TypeOf((*Node)(nil)).Elem() - lockable_type := reflect.TypeOf((*LockableNode)(nil)).Elem() - thread_type := reflect.TypeOf((*ThreadNode)(nil)).Elem() - - if def.Reflect.Implements(node_type) { - ctx.GQL.ValidNodes[def.Reflect] = def.GQLType - } - if def.Reflect.Implements(lockable_type) { - ctx.GQL.ValidLockables[def.Reflect] = def.GQLType - } - if def.Reflect.Implements(thread_type) { - ctx.GQL.ValidThreads[def.Reflect] = def.GQLType + ctx.Extensions[type_hash] = ExtensionInfo{ + Load: load_fn, + Type: ext_type, } - ctx.GQL.TypeList = append(ctx.GQL.TypeList, def.GQLType) - return nil } -// Map of go types to graphql types -type ObjTypeMap map[reflect.Type]*graphql.Object - -// GQL Specific Context information -type GQLContext struct { - // Generated GQL schema - Schema graphql.Schema - - // List of GQL types - TypeList []graphql.Type - - // Interface type maps to map go types of specific interfaces to gql types - ValidNodes ObjTypeMap - ValidLockables ObjTypeMap - ValidThreads ObjTypeMap - - BaseNodeType *graphql.Object - BaseLockableType *graphql.Object - BaseThreadType *graphql.Object - - Query *graphql.Object - Mutation *graphql.Object - Subscription *graphql.Object -} - -// Create a new GQL context without any content -func NewGQLContext() GQLContext { - query := graphql.NewObject(graphql.ObjectConfig{ - Name: "Query", - Fields: graphql.Fields{}, - }) - - mutation := graphql.NewObject(graphql.ObjectConfig{ - Name: "Mutation", - Fields: graphql.Fields{}, - }) - - subscription := graphql.NewObject(graphql.ObjectConfig{ - Name: "Subscription", - Fields: graphql.Fields{}, - }) - - ctx := GQLContext{ - Schema: graphql.Schema{}, - TypeList: []graphql.Type{}, - ValidNodes: ObjTypeMap{}, - ValidThreads: ObjTypeMap{}, - ValidLockables: ObjTypeMap{}, - Query: query, - Mutation: mutation, - Subscription: subscription, - BaseNodeType: GQLTypeSimpleNode.Type, - BaseLockableType: GQLTypeSimpleLockable.Type, - BaseThreadType: GQLTypeSimpleThread.Type, - } - - return ctx -} - // Create a new Context with all the library content added -func NewContext(db * badger.DB, log Logger) * Context { +func NewContext(db * badger.DB, log Logger) (*Context, error) { ctx := &Context{ - GQL: NewGQLContext(), DB: db, Log: log, - Types: map[uint64]NodeDef{}, + Extensions: map[uint64]ExtensionInfo{}, + Nodes: map[NodeID]*Node{}, } - err := ctx.RegisterNodeType(NewNodeDef((*SimpleNode)(nil), LoadSimpleNode, GQLTypeSimpleNode.Type)) - if err != nil { - panic(err) - } - err = ctx.RegisterNodeType(NewNodeDef((*Lockable)(nil), LoadLockable, GQLTypeSimpleLockable.Type)) - if err != nil { - panic(err) - } - err = ctx.RegisterNodeType(NewNodeDef((*Listener)(nil), LoadListener, GQLTypeSimpleLockable.Type)) - if err != nil { - panic(err) - } - err = ctx.RegisterNodeType(NewNodeDef((*Thread)(nil), LoadThread, GQLTypeSimpleThread.Type)) - if err != nil { - panic(err) - } - err = ctx.RegisterNodeType(NewNodeDef((*GQLThread)(nil), LoadGQLThread, GQLTypeGQLThread.Type)) - if err != nil { - panic(err) - } - err = ctx.RegisterNodeType(NewNodeDef((*User)(nil), LoadUser, GQLTypeUser.Type)) - if err != nil { - panic(err) - } - err = ctx.RegisterNodeType(NewNodeDef((*Group)(nil), LoadGroup, GQLTypeSimpleLockable.Type)) + err := ctx.RegisterExtension(ACLExtType, LoadACLExtension) if err != nil { - panic(err) + return nil, err } - err = ctx.RegisterNodeType(NewNodeDef((*PerNodePolicy)(nil), LoadPerNodePolicy, GQLTypeSimpleNode.Type)) - if err != nil { - panic(err) - } - err = ctx.RegisterNodeType(NewNodeDef((*SimplePolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type)) - if err != nil { - panic(err) - } - err = ctx.RegisterNodeType(NewNodeDef((*DependencyPolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type)) - if err != nil { - panic(err) - } - err = ctx.RegisterNodeType(NewNodeDef((*ParentPolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type)) - if err != nil { - panic(err) - } - err = ctx.RegisterNodeType(NewNodeDef((*ChildrenPolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type)) - if err != nil { - panic(err) - } - err = ctx.RegisterNodeType(NewNodeDef((*UserOfPolicy)(nil), LoadUserOfPolicy, GQLTypeSimpleNode.Type)) - if err != nil { - panic(err) - } - - ctx.AddGQLType(GQLTypeSignal.Type) - - ctx.GQL.Query.AddFieldConfig("Self", GQLQuerySelf) - ctx.GQL.Query.AddFieldConfig("User", GQLQueryUser) - - ctx.GQL.Subscription.AddFieldConfig("Update", GQLSubscriptionUpdate) - ctx.GQL.Subscription.AddFieldConfig("Self", GQLSubscriptionSelf) - - ctx.GQL.Mutation.AddFieldConfig("abort", GQLMutationAbort) - ctx.GQL.Mutation.AddFieldConfig("startChild", GQLMutationStartChild) - err = ctx.RebuildSchema() + err = ctx.RegisterExtension(ACLPolicyExtType, LoadACLPolicyExtension) if err != nil { - panic(err) + return nil, err } - return ctx + return ctx, nil } diff --git a/graph_test.go b/graph_test.go index e222dec..258e64f 100644 --- a/graph_test.go +++ b/graph_test.go @@ -52,13 +52,15 @@ func (t * GraphTester) CheckForNone(listener chan GraphSignal, str string) { } } -func logTestContext(t * testing.T, components []string) * Context { +func logTestContext(t * testing.T, components []string) *Context { db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true)) if err != nil { t.Fatal(err) } - return NewContext(db, NewConsoleLogger(components)) + ctx, err := NewContext(db, NewConsoleLogger(components)) + fatalErr(t, err) + return ctx } func testContext(t * testing.T) * Context { @@ -67,7 +69,9 @@ func testContext(t * testing.T) * Context { t.Fatal(err) } - return NewContext(db, NewConsoleLogger([]string{})) + ctx, err := NewContext(db, NewConsoleLogger([]string{})) + fatalErr(t, err) + return ctx } func fatalErr(t * testing.T, err error) { diff --git a/lockable.go b/lockable.go index 9f54216..6a412c8 100644 --- a/lockable.go +++ b/lockable.go @@ -2,171 +2,121 @@ package graphvent import ( "fmt" - "reflect" "encoding/json" ) -type Listener struct { - Lockable +type ListenerExt struct { Chan chan GraphSignal } -func (node *Listener) Type() NodeType { - return NodeType("listener") -} - -func (node *Listener) Process(context *StateContext, signal GraphSignal) error { - context.Graph.Log.Logf("signal", "LISTENER_PROCESS: %s", node.ID()) - select { - case node.Chan <- signal: - default: - return fmt.Errorf("LISTENER_OVERFLOW: %s - %s", node.ID(), signal) +func NewListenerExt(buffer int) ListenerExt { + return ListenerExt{ + Chan: make(chan GraphSignal, buffer), } - return node.Lockable.Process(context, signal) -} - -const LISTENER_CHANNEL_BUFFER = 1024 -func NewListener(id NodeID, name string) Listener { - return Listener{ - Lockable: NewLockable(id, name), - Chan: make(chan GraphSignal, LISTENER_CHANNEL_BUFFER), - } -} - -var LoadListener = LoadJSONNode(func(id NodeID, j LockableJSON) (Node, error) { - listener := NewListener(id, j.Name) - return &listener, nil -}, RestoreLockable) - -type LockableNode interface { - Node - LockableHandle() *Lockable } -// Lockable is a simple Lockable implementation that can be embedded into more complex structures -type Lockable struct { - SimpleNode - Name string - Owner LockableNode - Requirements map[NodeID]LockableNode - Dependencies map[NodeID]LockableNode - LocksHeld map[NodeID]LockableNode +const ListenerExtType = ExtType("LISTENER") +func (listener ListenerExt) Type() ExtType { + return ListenerExtType } -func (lockable *Lockable) LockableHandle() *Lockable { - return lockable +func (ext ListenerExt) Process(context *StateContext, signal GraphSignal) error { + select { + case ext.Chan <- signal: + default: + return fmt.Errorf("LISTENER_OVERFLOW - %+v", signal) + } + return nil } -func (lockable *Lockable) Type() NodeType { - return NodeType("lockable") +func (node ListenerExt) Serialize() ([]byte, error) { + return []byte{}, nil } -type LockableJSON struct { - SimpleNodeJSON - Name string `json:"name"` - Owner string `json:"owner"` - Dependencies []string `json:"dependencies"` - Requirements []string `json:"requirements"` - LocksHeld map[string]string `json:"locks_held"` +type LockableExt struct { + Owner *Node + Requirements map[NodeID]*Node + Dependencies map[NodeID]*Node + LocksHeld map[NodeID]*Node } -func (lockable *Lockable) Serialize() ([]byte, error) { - lockable_json := NewLockableJSON(lockable) - return json.MarshalIndent(&lockable_json, "", " ") +const LockableExtType = ExtType("LOCKABLE") +func (ext *LockableExt) Type() ExtType { + return LockableExtType } -func NewLockableJSON(lockable *Lockable) LockableJSON { - requirement_ids := make([]string, len(lockable.Requirements)) +func (ext *LockableExt) Serialize() ([]byte, error) { + requirements := make([]string, len(ext.Requirements)) req_n := 0 - for id, _ := range(lockable.Requirements) { - requirement_ids[req_n] = id.String() + for id, _ := range(ext.Requirements) { + requirements[req_n] = id.String() req_n++ } - dependency_ids := make([]string, len(lockable.Dependencies)) + dependencies := make([]string, len(ext.Dependencies)) dep_n := 0 - for id, _ := range(lockable.Dependencies) { - dependency_ids[dep_n] = id.String() + for id, _ := range(ext.Dependencies) { + dependencies[dep_n] = id.String() dep_n++ } - owner_id := "" - if lockable.Owner != nil { - owner_id = lockable.Owner.ID().String() + owner := "" + if ext.Owner != nil { + owner = ext.Owner.ID.String() } locks_held := map[string]string{} - for lockable_id, node := range(lockable.LocksHeld) { + for lockable_id, node := range(ext.LocksHeld) { if node == nil { locks_held[lockable_id.String()] = "" } else { - locks_held[lockable_id.String()] = node.ID().String() + locks_held[lockable_id.String()] = node.ID.String() } } - node_json := NewSimpleNodeJSON(&lockable.SimpleNode) - - return LockableJSON{ - SimpleNodeJSON: node_json, - Name: lockable.Name, - Owner: owner_id, - Dependencies: dependency_ids, - Requirements: requirement_ids, + return json.MarshalIndent(&struct{ + Owner string `json:"owner"` + Requirements []string `json:"requirements"` + Dependencies []string `json:"dependencies"` + LocksHeld map[string]string `json:"locks_held"` + }{ + Owner: owner, + Requirements: requirements, + Dependencies: dependencies, LocksHeld: locks_held, - } + }, "", " ") } -func (lockable *Lockable) RecordUnlock(l LockableNode) LockableNode { - lockable_id := l.ID() - last_owner, exists := lockable.LocksHeld[lockable_id] - if exists == false { - panic("Attempted to take a get the original lock holder of a lockable we don't own") - } - delete(lockable.LocksHeld, lockable_id) - return last_owner -} - -func (lockable *Lockable) RecordLock(l LockableNode, last_owner LockableNode) { - lockable_id := l.ID() - _, exists := lockable.LocksHeld[lockable_id] - if exists == true { - panic("Attempted to lock a lockable we're already holding(lock cycle)") - } - - lockable.LocksHeld[lockable_id] = last_owner -} - -// Assumed that lockable is already locked for signal -func (lockable *Lockable) Process(context *StateContext, signal GraphSignal) error { - context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", lockable.ID()) +func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphSignal) error { + context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID) var err error switch signal.Direction() { case Up: - err = UseStates(context, lockable, - NewLockInfo(lockable, []string{"dependencies", "owner"}), func(context *StateContext) error { + err = UseStates(context, node, + NewACLInfo(node, []string{"dependencies", "owner"}), func(context *StateContext) error { owner_sent := false - for _, dependency := range(lockable.Dependencies) { - context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID()) - Signal(context, dependency, lockable, signal) - if lockable.Owner != nil { - if dependency.ID() == lockable.Owner.ID() { + for _, dependency := range(ext.Dependencies) { + context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", node.ID, dependency.ID) + Signal(context, dependency, node, signal) + if ext.Owner != nil { + if dependency.ID == ext.Owner.ID { owner_sent = true } } } - if lockable.Owner != nil && owner_sent == false { - if lockable.Owner.ID() != lockable.ID() { - context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.Owner.ID()) - return Signal(context, lockable.Owner, lockable, signal) + if ext.Owner != nil && owner_sent == false { + if ext.Owner.ID != node.ID { + context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", node.ID, ext.Owner.ID) + return Signal(context, ext.Owner, node, signal) } } return nil }) case Down: - err = UseStates(context, lockable, NewLockInfo(lockable, []string{"requirements"}), func(context *StateContext) error { - for _, requirement := range(lockable.Requirements) { - err := Signal(context, requirement, lockable, signal) + err = UseStates(context, node, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error { + for _, requirement := range(ext.Requirements) { + err := Signal(context, requirement, node, signal) if err != nil { return err } @@ -176,112 +126,154 @@ func (lockable *Lockable) Process(context *StateContext, signal GraphSignal) err case Direct: err = nil default: - return fmt.Errorf("invalid signal direction %d", signal.Direction()) + err = fmt.Errorf("invalid signal direction %d", signal.Direction()) } if err != nil { return err } - return lockable.SimpleNode.Process(context, signal) + return nil +} + +func (ext *LockableExt) RecordUnlock(node *Node) *Node { + last_owner, exists := ext.LocksHeld[node.ID] + if exists == false { + panic("Attempted to take a get the original lock holder of a lockable we don't own") + } + delete(ext.LocksHeld, node.ID) + return last_owner +} + +func (ext *LockableExt) RecordLock(node *Node, last_owner *Node) { + _, exists := ext.LocksHeld[node.ID] + if exists == true { + panic("Attempted to lock a lockable we're already holding(lock cycle)") + } + ext.LocksHeld[node.ID] = last_owner } // Removes requirement as a requirement from lockable -// Continues the write context with princ, getting requirents for lockable and dependencies for requirement -// Assumes that an active write context exists with princ locked so that princ's state can be used in checks -func UnlinkLockables(context *StateContext, princ Node, lockable LockableNode, requirement LockableNode) error { - return UpdateStates(context, princ, LockMap{ - lockable.ID(): LockInfo{Node: lockable, Resources: []string{"requirements"}}, - requirement.ID(): LockInfo{Node: requirement, Resources: []string{"dependencies"}}, +func UnlinkLockables(context *StateContext, princ *Node, lockable *Node, requirement *Node) error { + lockable_ext, err := GetExt[*LockableExt](lockable) + if err != nil { + return err + } + requirement_ext, err := GetExt[*LockableExt](requirement) + if err != nil { + return err + } + return UpdateStates(context, princ, ACLMap{ + lockable.ID: ACLInfo{Node: lockable, Resources: []string{"requirements"}}, + requirement.ID: ACLInfo{Node: requirement, Resources: []string{"dependencies"}}, }, func(context *StateContext) error { - var found Node = nil - for _, req := range(lockable.LockableHandle().Requirements) { - if requirement.ID() == req.ID() { + var found *Node = nil + for _, req := range(lockable_ext.Requirements) { + if requirement.ID == req.ID { found = req break } } if found == nil { - return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID(), lockable.ID()) + return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID, lockable.ID) } - delete(requirement.LockableHandle().Dependencies, lockable.ID()) - delete(lockable.LockableHandle().Requirements, requirement.ID()) + delete(requirement_ext.Dependencies, lockable.ID) + delete(lockable_ext.Requirements, requirement.ID) return nil }) } // Link requirements as requirements to lockable -// Continues the wrtie context with princ, getting requirements for lockable and dependencies for requirements -func LinkLockables(context *StateContext, princ Node, lockable_node LockableNode, requirements []LockableNode) error { - if lockable_node == nil { +func LinkLockables(context *StateContext, princ *Node, lockable *Node, requirements []*Node) error { + if lockable == nil { return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements") } - lockable := lockable_node.LockableHandle() if len(requirements) == 0 { - return nil + return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link no lockables in call") + } + + lockable_ext, err := GetExt[*LockableExt](lockable) + if err != nil { + return err } - found := map[NodeID]bool{} + req_exts := map[NodeID]*LockableExt{} for _, requirement := range(requirements) { if requirement == nil { return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link nil to a Lockable as a requirement") } - if lockable.ID() == requirement.ID() { - return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s to itself", lockable.ID()) + if lockable.ID == requirement.ID { + return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s to itself", lockable.ID) } - _, exists := found[requirement.ID()] + _, exists := req_exts[requirement.ID] if exists == true { - return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s twice", requirement.ID()) + return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s twice", requirement.ID) + } + ext, err := GetExt[*LockableExt](requirement) + if err != nil { + return err } - found[requirement.ID()] = true + req_exts[requirement.ID] = ext } - return UpdateStates(context, princ, NewLockMap( - NewLockInfo(lockable_node, []string{"requirements"}), - LockList(requirements, []string{"dependencies"}), + return UpdateStates(context, princ, NewACLMap( + NewACLInfo(lockable, []string{"requirements"}), + ACLList(requirements, []string{"dependencies"}), ), func(context *StateContext) error { // Check that all the requirements can be added // If the lockable is already locked, need to lock this resource as well before we can add it - for _, requirement_node := range(requirements) { - requirement := requirement_node.LockableHandle() - for _, req_node := range(requirements) { - req := req_node.LockableHandle() - if req.ID() == requirement.ID() { + for _, requirement := range(requirements) { + requirement_ext := req_exts[requirement.ID] + for _, req := range(requirements) { + if req.ID == requirement.ID { continue } - if checkIfRequirement(context, req, requirement) == true { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependenyc of %s so cannot add the same dependency", req.ID(), requirement.ID()) + + is_req, err := checkIfRequirement(context, req.ID, requirement_ext) + if err != nil { + return err + } else if is_req { + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot add the same dependency", req.ID, requirement.ID) + } } - if checkIfRequirement(context, lockable, requirement) == true { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID(), lockable.ID()) + + is_req, err := checkIfRequirement(context, lockable.ID, requirement_ext) + if err != nil { + return err + } else if is_req { + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID, lockable.ID) } - if checkIfRequirement(context, requirement, lockable) == true { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID()) + is_req, err = checkIfRequirement(context, requirement.ID, lockable_ext) + if err != nil { + return err + } else if is_req { + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID, requirement.ID) } - if lockable.Owner == nil { + + if lockable_ext.Owner == nil { // If the new owner isn't locked, we can add the requirement - } else if requirement.Owner == nil { + } else if requirement_ext.Owner == nil { // if the new requirement isn't already locked but the owner is, the requirement needs to be locked first - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID(), requirement.ID()) + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID, requirement.ID) } else { // If the new requirement is already locked and the owner is already locked, their owners need to match - if requirement.Owner.ID() != lockable.Owner.ID() { - return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID(), lockable.ID()) + if requirement_ext.Owner.ID != lockable_ext.Owner.ID { + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID, lockable.ID) } } } // Update the states of the requirements - for _, requirement_node := range(requirements) { - requirement := requirement_node.LockableHandle() - requirement.Dependencies[lockable.ID()] = lockable_node - lockable.Requirements[lockable.ID()] = requirement_node - context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID(), lockable.ID()) + for _, requirement := range(requirements) { + requirement_ext := req_exts[requirement.ID] + requirement_ext.Dependencies[lockable.ID] = lockable + lockable_ext.Requirements[lockable.ID] = requirement + context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID, lockable.ID) } // Return no error @@ -289,74 +281,91 @@ func LinkLockables(context *StateContext, princ Node, lockable_node LockableNode }) } -// Must be called withing update context -func checkIfRequirement(context *StateContext, r LockableNode, cur LockableNode) bool { - for _, c := range(cur.LockableHandle().Requirements) { - if c.ID() == r.ID() { - return true +func checkIfRequirement(context *StateContext, id NodeID, cur *LockableExt) (bool, error) { + for _, req := range(cur.Requirements) { + if req.ID == id { + return true, nil } - is_requirement := false - UpdateStates(context, cur, NewLockMap(NewLockInfo(c, []string{"requirements"})), func(context *StateContext) error { - is_requirement = checkIfRequirement(context, cur, c) - return nil - }) - if is_requirement { - return true + req_ext, err := GetExt[*LockableExt](req) + if err != nil { + return false, err + } + + var is_req bool + err = UpdateStates(context, req, NewACLInfo(req, []string{"requirements"}), func(context *StateContext) error { + is_req, err = checkIfRequirement(context, id, req_ext) + return err + }) + if err != nil { + return false, err + } + if is_req == true { + return true, nil } } - return false + return false, nil } // Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error // Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call -func LockLockables(context *StateContext, to_lock map[NodeID]LockableNode, new_owner_node LockableNode) error { +func LockLockables(context *StateContext, to_lock NodeMap, new_owner *Node) error { if to_lock == nil { - return fmt.Errorf("LOCKABLE_LOCK_ERR: no list provided") + return fmt.Errorf("LOCKABLE_LOCK_ERR: no map provided") } + req_exts := map[NodeID]*LockableExt{} for _, l := range(to_lock) { + var err error if l == nil { return fmt.Errorf("LOCKABLE_LOCK_ERR: Can not lock nil") } + + req_exts[l.ID], err = GetExt[*LockableExt](l) + if err != nil { + return err + } } - if new_owner_node == nil { + if new_owner == nil { return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks") } - new_owner := new_owner_node.LockableHandle() + new_owner_ext, err := GetExt[*LockableExt](new_owner) + if err != nil { + return err + } // Called with no requirements to lock, success if len(to_lock) == 0 { return nil } - return UpdateStates(context, new_owner, NewLockMap( - LockListM(to_lock, []string{"lock"}), - NewLockInfo(new_owner, nil), + return UpdateStates(context, new_owner, NewACLMap( + ACLListM(to_lock, []string{"lock"}), + NewACLInfo(new_owner, nil), ), func(context *StateContext) error { // First loop is to check that the states can be locked, and locks all requirements - for _, req_node := range(to_lock) { - req := req_node.LockableHandle() - context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID()) + for _, req := range(to_lock) { + req_ext := req_exts[req.ID] + context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID, new_owner.ID) // If req is alreay locked, check that we can pass the lock - if req.Owner != nil { - owner := req.Owner - if owner.ID() == new_owner.ID() { + if req_ext.Owner != nil { + owner := req_ext.Owner + if owner.ID == new_owner.ID { continue } else { - err := UpdateStates(context, new_owner, NewLockInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){ - return LockLockables(context, req.Requirements, req) + err := UpdateStates(context, new_owner, NewACLInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){ + return LockLockables(context, req_ext.Requirements, req) }) if err != nil { return err } } } else { - err := LockLockables(context, req.Requirements, req) + err := LockLockables(context, req_ext.Requirements, req) if err != nil { return err } @@ -364,22 +373,22 @@ func LockLockables(context *StateContext, to_lock map[NodeID]LockableNode, new_o } // At this point state modification will be started, so no errors can be returned - for _, req_node := range(to_lock) { - req := req_node.LockableHandle() - old_owner := req.Owner + for _, req := range(to_lock) { + req_ext := req_exts[req.ID] + old_owner := req_ext.Owner // If the lockable was previously unowned, update the state if old_owner == nil { - context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID(), req.ID()) - req.Owner = new_owner_node - new_owner.RecordLock(req, old_owner) + context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID, req.ID) + req_ext.Owner = new_owner + new_owner_ext.RecordLock(req, old_owner) // Otherwise if the new owner already owns it, no need to update state - } else if old_owner.ID() == new_owner.ID() { - context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID(), req.ID()) + } else if old_owner.ID == new_owner.ID { + context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID, req.ID) // Otherwise update the state } else { - req.Owner = new_owner - new_owner.RecordLock(req, old_owner) - context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID(), req.ID(), old_owner.ID()) + req_ext.Owner = new_owner + new_owner_ext.RecordLock(req, old_owner) + context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID, req.ID, old_owner.ID) } } return nil @@ -387,61 +396,72 @@ func LockLockables(context *StateContext, to_lock map[NodeID]LockableNode, new_o } -func UnlockLockables(context *StateContext, to_unlock map[NodeID]LockableNode, old_owner_node LockableNode) error { +func UnlockLockables(context *StateContext, to_unlock NodeMap, old_owner *Node) error { if to_unlock == nil { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided") } + req_exts := map[NodeID]*LockableExt{} for _, l := range(to_unlock) { if l == nil { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: Can not unlock nil") } + + var err error + req_exts[l.ID], err = GetExt[*LockableExt](l) + if err != nil { + return err + } } - if old_owner_node == nil { + if old_owner == nil { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks") } - old_owner := old_owner_node.LockableHandle() + old_owner_ext, err := GetExt[*LockableExt](old_owner) + if err != nil { + return err + } + // Called with no requirements to unlock, success if len(to_unlock) == 0 { return nil } - return UpdateStates(context, old_owner, NewLockMap( - LockListM(to_unlock, []string{"lock"}), - NewLockInfo(old_owner, nil), + return UpdateStates(context, old_owner, NewACLMap( + ACLListM(to_unlock, []string{"lock"}), + NewACLInfo(old_owner, nil), ), func(context *StateContext) error { // First loop is to check that the states can be locked, and locks all requirements - for _, req_node := range(to_unlock) { - req := req_node.LockableHandle() - context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID()) + for _, req := range(to_unlock) { + req_ext := req_exts[req.ID] + context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID, old_owner.ID) // Check if the owner is correct - if req.Owner != nil { - if req.Owner.ID() != old_owner.ID() { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID(), old_owner.ID()) + if req_ext.Owner != nil { + if req_ext.Owner.ID != old_owner.ID { + return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID, old_owner.ID) } } else { - return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID()) + return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID) } - err := UnlockLockables(context, req.Requirements, req) + err := UnlockLockables(context, req_ext.Requirements, req) if err != nil { return err } } // At this point state modification will be started, so no errors can be returned - for _, req_node := range(to_unlock) { - req := req_node.LockableHandle() - new_owner := old_owner.RecordUnlock(req) - req.Owner = new_owner + for _, req := range(to_unlock) { + req_ext := req_exts[req.ID] + new_owner := old_owner_ext.RecordUnlock(req) + req_ext.Owner = new_owner if new_owner == nil { - context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID(), req.ID()) + context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID, req.ID) } else { - context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID(), req.ID(), new_owner.ID()) + context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID, req.ID, new_owner.ID) } } @@ -449,103 +469,55 @@ func UnlockLockables(context *StateContext, to_unlock map[NodeID]LockableNode, o }) } -var LoadLockable = LoadJSONNode(func(id NodeID, j LockableJSON) (Node, error) { - lockable := NewLockable(id, j.Name) - return &lockable, nil -}, RestoreLockable) - -func NewLockable(id NodeID, name string) Lockable { - return Lockable{ - SimpleNode: NewSimpleNode(id), - Name: name, - Owner: nil, - Requirements: map[NodeID]LockableNode{}, - Dependencies: map[NodeID]LockableNode{}, - LocksHeld: map[NodeID]LockableNode{}, +func RestoreNode(ctx *Context, id_str string) (*Node, error) { + id, err := ParseID(id_str) + if err != nil { + return nil, err } + + return LoadNode(ctx, id) } -// Helper function to load links when loading a struct that embeds Lockable -func RestoreLockable(ctx * Context, lockable LockableNode, j LockableJSON, nodes NodeMap) error { - lockable_ptr := lockable.LockableHandle() - if j.Owner != "" { - owner_id, err := ParseID(j.Owner) - if err != nil { - return err - } - owner_node, err := LoadNodeRecurse(ctx, owner_id, nodes) +func RestoreNodeMap(ctx *Context, ids map[string]string) (NodeMap, error) { + nodes := NodeMap{} + for id_str_1, id_str_2 := range(ids) { + id_1, err := ParseID(id_str_1) if err != nil { - return err - } - owner, ok := owner_node.(LockableNode) - if ok == false { - return fmt.Errorf("%s is not a Lockable", j.Owner) + return nil, err } - lockable_ptr.Owner = owner - } - for _, dep_str := range(j.Dependencies) { - dep_id, err := ParseID(dep_str) - if err != nil { - return err - } - dep_node, err := LoadNodeRecurse(ctx, dep_id, nodes) + id_2, err := ParseID(id_str_2) if err != nil { - return err - } - dep, ok := dep_node.(LockableNode) - if ok == false { - return fmt.Errorf("%+v is not a Lockable as expected", dep_node) + return nil, err } - ctx.Log.Logf("db", "LOCKABLE_LOAD_DEPENDENCY: %s - %s - %+v", lockable.ID(), dep_id, reflect.TypeOf(dep)) - lockable_ptr.Dependencies[dep_id] = dep - } - for _, req_str := range(j.Requirements) { - req_id, err := ParseID(req_str) + node_1, err := LoadNode(ctx, id_1) if err != nil { - return err + return nil, err } - req_node, err := LoadNodeRecurse(ctx, req_id, nodes) + + node_2, err := LoadNode(ctx, id_2) if err != nil { - return err + return nil, err } - req, ok := req_node.(LockableNode) - if ok == false { - return fmt.Errorf("%+v is not a Lockable as expected", req_node) - } - lockable_ptr.Requirements[req_id] = req + + nodes[node_1.ID] = node_2 } - for l_id_str, h_str := range(j.LocksHeld) { - l_id, err := ParseID(l_id_str) - l, err := LoadNodeRecurse(ctx, l_id, nodes) - if err != nil { - return err - } - l_l, ok := l.(LockableNode) - if ok == false { - return fmt.Errorf("%s is not a Lockable", l.ID()) - } + return nodes, nil +} - var h_l LockableNode - if h_str != "" { - h_id, err := ParseID(h_str) - if err != nil { - return err - } - h_node, err := LoadNodeRecurse(ctx, h_id, nodes) - if err != nil { - return err - } - h, ok := h_node.(LockableNode) - if ok == false { - return err - } - h_l = h +func RestoreNodeList(ctx *Context, ids []string) (NodeMap, error) { + nodes := NodeMap{} + + for _, id_str := range(ids) { + node, err := RestoreNode(ctx, id_str) + if err != nil { + return nil, err } - lockable_ptr.RecordLock(l_l, h_l) + nodes[node.ID] = node } - return RestoreSimpleNode(ctx, lockable, j.SimpleNodeJSON, nodes) + return nodes, nil } + diff --git a/node.go b/node.go index 3671ab6..9ed503e 100644 --- a/node.go +++ b/node.go @@ -2,6 +2,7 @@ package graphvent import ( "sync" + "reflect" "github.com/google/uuid" badger "github.com/dgraph-io/badger/v3" "fmt" @@ -31,6 +32,12 @@ func (id NodeID) String() string { return (uuid.UUID)(id).String() } +// Ignore the error since we're enforcing 16 byte length at compile time +func IDFromBytes(bytes [16]byte) NodeID { + id, _ := uuid.FromBytes(bytes[:]) + return NodeID(id) +} + func ParseID(str string) (NodeID, error) { id_uuid, err := uuid.Parse(str) if err != nil { @@ -45,230 +52,291 @@ func KeyID(pub *ecdsa.PublicKey) NodeID { return NodeID(str) } -// Types are how nodes are associated with structs at runtime(and from the DB) -type NodeType string -func (node_type NodeType) Hash() uint64 { - hash := sha512.Sum512([]byte(fmt.Sprintf("NODE: %s", node_type))) - - return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)]) -} - // Generate a random NodeID func RandID() NodeID { return NodeID(uuid.New()) } -type Node interface { - ID() NodeID - Type() NodeType +type Serializable[I comparable] interface { + Type() I Serialize() ([]byte, error) - LockState(write bool) - UnlockState(write bool) - Process(context *StateContext, signal GraphSignal) error - Policies() []Policy - NodeHandle() *SimpleNode } -type SimpleNode struct { - Id NodeID - state_mutex sync.RWMutex - PolicyMap map[NodeID]Policy +// NodeExtensions are additional data that can be attached to nodes, and used in node functions +type Extension interface { + Serializable[ExtType] + // Send a signal to this extension to process, + // this typically triggers signals to be sent to nodes linked in the extension + Process(context *StateContext, node *Node, signal GraphSignal) error } -func (node *SimpleNode) NodeHandle() *SimpleNode { - return node +// Nodes represent an addressible group of extensions +type Node struct { + ID NodeID + Lock sync.RWMutex + ExtensionMap map[ExtType]Extension } -func NewSimpleNode(id NodeID) SimpleNode { - return SimpleNode{ - Id: id, - PolicyMap: map[NodeID]Policy{}, +func GetExt[T Extension](node *Node) (T, error) { + var zero T + ext_type := zero.Type() + ext, exists := node.ExtensionMap[ext_type] + if exists == false { + return zero, fmt.Errorf("%s does not have %s extension", node.ID, ext_type) } -} - -type SimpleNodeJSON struct { - Policies []string `json:"policies"` -} -func (node *SimpleNode) Process(context *StateContext, signal GraphSignal) error { - context.Graph.Log.Logf("signal", "SIMPLE_NODE_SIGNAL: %s - %s", node.Id, signal) - return nil -} - -func (node *SimpleNode) ID() NodeID { - return node.Id -} - -func (node *SimpleNode) Type() NodeType { - return NodeType("simple_node") -} + ret, ok := ext.(T) + if ok == false { + return zero, fmt.Errorf("%s in %s is wrong type(%+v), expecting %+v", ext_type, node.ID, reflect.TypeOf(ext), reflect.TypeOf(zero)) + } -func (node *SimpleNode) Serialize() ([]byte, error) { - j := NewSimpleNodeJSON(node) - return json.MarshalIndent(&j, "", " ") + return ret, nil } -func (node *SimpleNode) LockState(write bool) { - if write == true { - node.state_mutex.Lock() - } else { - node.state_mutex.RLock() - } +// The ACL extension stores a map of nodes to delegate ACL to, and a list of policies +type ACLExtension struct { + Delegations NodeMap } -func (node *SimpleNode) UnlockState(write bool) { - if write == true { - node.state_mutex.Unlock() - } else { - node.state_mutex.RUnlock() - } +func (ext ACLExtension) Process(context *StateContext, node *Node, signal GraphSignal) error { + return nil } -func NewSimpleNodeJSON(node *SimpleNode) SimpleNodeJSON { - policy_ids := make([]string, len(node.PolicyMap)) - i := 0 - for id, _ := range(node.PolicyMap) { - policy_ids[i] = id.String() - i += 1 +func LoadACLExtension(ctx *Context, data []byte) (Extension, error) { + var j struct { + Delegations []string `json:"delegation"` } - return SimpleNodeJSON{ - Policies: policy_ids, + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err } -} -func RestoreSimpleNode(ctx *Context, node Node, j SimpleNodeJSON, nodes NodeMap) error { - node_ptr := node.NodeHandle() - for _, policy_str := range(j.Policies) { - policy_id, err := ParseID(policy_str) + delegations := NodeMap{} + for _, str := range(j.Delegations) { + id, err := ParseID(str) if err != nil { - return err + return nil, err } - policy_ptr, err := LoadNodeRecurse(ctx, policy_id, nodes) + node, err := LoadNode(ctx, id) if err != nil { - return err + return nil, err } - policy, ok := policy_ptr.(Policy) - if ok == false { - return fmt.Errorf("%s is not a Policy", policy_id) - } - node_ptr.PolicyMap[policy_id] = policy + delegations[id] = node } - return nil + return ACLExtension{ + Delegations: delegations, + }, nil } -func LoadJSONNode[J any, N Node](init_func func(NodeID, J)(Node, error), restore_func func(*Context, N, J, NodeMap)error)func(*Context, NodeID, []byte, NodeMap)(Node, error) { - return func(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { - var j J - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } +func (ext ACLExtension) Serialize() ([]byte, error) { + delegations := make([]string, len(ext.Delegations)) + i := 0 + for id, _ := range(ext.Delegations) { + delegations[i] = id.String() + i += 1 + } - node, err := init_func(id, j) - if err != nil { - return nil, err - } - nodes[id] = node - err = restore_func(ctx, node.(N), j, nodes) - if err != nil { - return nil, err - } + return json.MarshalIndent(&struct{ + Delegations []string `json:"delegations"` + }{ + Delegations: delegations, + }, "", " ") +} - return node, nil - } +const ACLExtType = ExtType("ACL") +func (extension ACLExtension) Type() ExtType { + return ACLExtType } -var LoadSimpleNode = LoadJSONNode(func(id NodeID, j SimpleNodeJSON) (Node, error) { - node := NewSimpleNode(id) - return &node, nil -}, RestoreSimpleNode) +func (node *Node) Serialize() ([]byte, error) { + extensions := make([]ExtensionDB, len(node.ExtensionMap)) + node_db := NodeDB{ + Header: NodeDBHeader{ + Magic: NODE_DB_MAGIC, + NumExtensions: uint32(len(extensions)), + }, + Extensions: extensions, + } -func (node *SimpleNode) Policies() []Policy { - ret := make([]Policy, len(node.PolicyMap)) i := 0 - for _, policy := range(node.PolicyMap) { - ret[i] = policy + for ext_type, info := range(node.ExtensionMap) { + ser, err := info.Serialize() + if err != nil { + return nil, err + } + node_db.Extensions[i] = ExtensionDB{ + Header: ExtensionDBHeader{ + TypeHash: ext_type.Hash(), + Length: uint64(len(ser)), + }, + Data: ser, + } i += 1 } - return ret + return node_db.Serialize(), nil +} + +func NewNode(id NodeID) Node { + return Node{ + ID: id, + ExtensionMap: map[ExtType]Extension{}, + } } -func Allowed(context *StateContext, policies []Policy, node Node, resource string, action string, princ Node) error { - if princ == nil { - context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) +func Allowed(context *StateContext, principal *Node, action string, node *Node) error { + if principal == nil { + context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s", principal.ID, node.ID, action) return fmt.Errorf("nil is not allowed to perform any actions") } - if node.ID() == princ.ID() { - return nil + + ext, exists := node.ExtensionMap[ACLExtType] + if exists == false { + return fmt.Errorf("%s does not have ACL extension, other nodes cannot perform actions on it", node.ID) } - for _, policy := range(policies) { - if policy.Allows(context, node, resource, action, princ) == true { - context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) + acl_ext := ext.(ACLExtension) + + for _, policy_node := range(acl_ext.Delegations) { + ext, exists := policy_node.ExtensionMap[ACLPolicyExtType] + if exists == false { + context.Graph.Log.Logf("policy", "WARNING: %s has dependency %s which doesn't have ACLPolicyExtension") + continue + } + policy_ext := ext.(ACLPolicyExtension) + if policy_ext.Allows(context, principal, action, node) == true { + context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s", principal.ID, node.ID, action) return nil } } - context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) - return fmt.Errorf("%s is not allowed to perform %s.%s on %s", princ.ID(), resource, action, node.ID()) + context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", principal.ID, node.ID, action) + return fmt.Errorf("%s is not allowed to perform %s on %s", principal.ID, action, node.ID) } +// Check that princ is allowed to signal this action, +// then send the signal to all the extensions of the node +func Signal(context *StateContext, node *Node, princ *Node, signal GraphSignal) error { + context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID, signal.String()) -// Propagate the signal to registered listeners, if a listener isn't ready to receive the update -// send it a notification that it was closed and then close it -func Signal(context *StateContext, node Node, princ Node, signal GraphSignal) error { - context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String()) - - err := UseStates(context, princ, NewLockInfo(node, []string{}), func(context *StateContext) error { - return Allowed(context, node.Policies(), node, "signal", signal.Type(), princ) + err := UseStates(context, princ, NewACLInfo(node, []string{}), func(context *StateContext) error { + return Allowed(context, princ, fmt.Sprintf("signal.%s", signal.Type()), node) }) - if err != nil { - return nil + for _, ext := range(node.ExtensionMap) { + err = ext.Process(context, node, signal) + if err != nil { + return nil + } } - return node.Process(context, signal) -} - -func AttachPolicies(ctx *Context, node Node, policies ...Policy) error { - context := NewWriteContext(ctx) - return UpdateStates(context, node, NewLockMap(NewLockInfo(node, []string{"policies"}), LockList(policies, nil)), func(context *StateContext) error { - for _, policy := range(policies) { - node.NodeHandle().PolicyMap[policy.ID()] = policy - } - return nil - }) + return nil } // Magic first four bytes of serialized DB content, stored big endian const NODE_DB_MAGIC = 0x2491df14 // Total length of the node database header, has magic to verify and type_hash to map to load function -const NODE_DB_HEADER_LEN = 12 +const NODE_DB_HEADER_LEN = 8 // A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node -type DBHeader struct { +type NodeDBHeader struct { Magic uint32 - TypeHash uint64 + NumExtensions uint32 +} + +type NodeDB struct { + Header NodeDBHeader + Extensions []ExtensionDB +} + +//TODO: add size safety checks +func NewNodeDB(data []byte) (NodeDB, error) { + var zero NodeDB + + ptr := 0 + + magic := binary.BigEndian.Uint32(data[0:4]) + num_extensions := binary.BigEndian.Uint32(data[4:8]) + + ptr += NODE_DB_HEADER_LEN + + if magic != NODE_DB_MAGIC { + return zero, fmt.Errorf("header has incorrect magic 0x%x", magic) + } + + extensions := make([]ExtensionDB, num_extensions) + for i, _ := range(extensions) { + cur := data[ptr:] + + type_hash := binary.BigEndian.Uint64(cur[0:8]) + length := binary.BigEndian.Uint64(cur[8:16]) + + data_start := uint64(EXTENSION_DB_HEADER_LEN) + data_end := data_start + length + ext_data := cur[data_start:data_end] + + extensions[i] = ExtensionDB{ + Header: ExtensionDBHeader{ + TypeHash: type_hash, + Length: length, + }, + Data: ext_data, + } + + ptr += int(EXTENSION_DB_HEADER_LEN + length) + } + + return NodeDB{ + Header: NodeDBHeader{ + Magic: magic, + NumExtensions: num_extensions, + }, + Extensions: extensions, + }, nil } -func (header DBHeader) Serialize() []byte { +func (header NodeDBHeader) Serialize() []byte { if header.Magic != NODE_DB_MAGIC { panic(fmt.Sprintf("Serializing header with invalid magic %0x", header.Magic)) } ret := make([]byte, NODE_DB_HEADER_LEN) binary.BigEndian.PutUint32(ret[0:4], header.Magic) - binary.BigEndian.PutUint64(ret[4:12], header.TypeHash) + binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions) return ret } -func NewDBHeader(node_type NodeType) DBHeader { - return DBHeader{ - Magic: NODE_DB_MAGIC, - TypeHash: node_type.Hash(), +func (node NodeDB) Serialize() []byte { + ser := node.Header.Serialize() + for _, extension := range(node.Extensions) { + ser = append(ser, extension.Serialize()...) } + + return ser +} + +func (header ExtensionDBHeader) Serialize() []byte { + ret := make([]byte, EXTENSION_DB_HEADER_LEN) + binary.BigEndian.PutUint64(ret[0:8], header.TypeHash) + binary.BigEndian.PutUint64(ret[8:16], header.Length) + return ret +} + +func (extension ExtensionDB) Serialize() []byte { + header_bytes := extension.Header.Serialize() + return append(header_bytes, extension.Data...) +} + +const EXTENSION_DB_HEADER_LEN = 16 +type ExtensionDBHeader struct { + TypeHash uint64 + Length uint64 +} + +type ExtensionDB struct { + Header ExtensionDBHeader + Data []byte } // Write multiple nodes to the database in a single transaction @@ -283,27 +351,21 @@ func WriteNodes(context *StateContext) error { serialized_bytes := make([][]byte, len(context.Locked)) serialized_ids := make([][]byte, len(context.Locked)) i := 0 - for _, node := range(context.Locked) { + // TODO, just write states from the context, and store the current states in the context + for id, _ := range(context.Locked) { + node, _ := context.Graph.Nodes[id] if node == nil { - return fmt.Errorf("DB_SERIALIZE_ERROR: cannot serialize nil node") + return fmt.Errorf("DB_SERIALIZE_ERROR: cannot serialize nil node, maybe node isn't in the context") } + ser, err := node.Serialize() if err != nil { return fmt.Errorf("DB_SERIALIZE_ERROR: %s", err) } - header := NewDBHeader(node.Type()) + id_ser := node.ID.Serialize() - db_data := append(header.Serialize(), ser...) - - context.Graph.Log.Logf("db", "DB_WRITING_TYPE: %s - %+v %+v: %+v", node.ID(), node.Type(), header, node) - if err != nil { - return err - } - - id_ser := node.ID().Serialize() - - serialized_bytes[i] = db_data + serialized_bytes[i] = ser serialized_ids[i] = id_ser i++ @@ -320,8 +382,13 @@ func WriteNodes(context *StateContext) error { }) } -// Get the bytes associates with `id` from the database after unwrapping the header, or error -func readNodeBytes(ctx * Context, id NodeID) (uint64, []byte, error) { +// Recursively load a node from the database. +func LoadNode(ctx * Context, id NodeID) (*Node, error) { + node, exists := ctx.Nodes[id] + if exists == true { + return node,nil + } + var bytes []byte err := ctx.DB.View(func(txn *badger.Txn) error { item, err := txn.Get(id.Serialize()) @@ -334,80 +401,51 @@ func readNodeBytes(ctx * Context, id NodeID) (uint64, []byte, error) { return nil }) }) - if err != nil { - ctx.Log.Logf("db", "DB_READ_ERR: %s - %e", id, err) - return 0, nil, err - } - - if len(bytes) < NODE_DB_HEADER_LEN { - return 0, nil, fmt.Errorf("header for %s is %d/%d bytes", id, len(bytes), NODE_DB_HEADER_LEN) + return nil, err } - header := DBHeader{} - header.Magic = binary.BigEndian.Uint32(bytes[0:4]) - header.TypeHash = binary.BigEndian.Uint64(bytes[4:12]) - - if header.Magic != NODE_DB_MAGIC { - return 0, nil, fmt.Errorf("header for %s, invalid magic 0x%x", id, header.Magic) + // Parse the bytes from the DB + node_db, err := NewNodeDB(bytes) + if err != nil { + return nil, err } - node_bytes := make([]byte, len(bytes) - NODE_DB_HEADER_LEN) - copy(node_bytes, bytes[NODE_DB_HEADER_LEN:]) - - ctx.Log.Logf("db", "DB_READ: %s %+v - %s", id, header, string(bytes)) - - return header.TypeHash, node_bytes, nil -} + // Create the blank node with the ID, and add it to the context + new_node := NewNode(id) + node = &new_node + ctx.Nodes[id] = node -// Load a Node from the database by ID -func LoadNode(ctx * Context, id NodeID) (Node, error) { - nodes := NodeMap{} - return LoadNodeRecurse(ctx, id, nodes) -} - - -// Recursively load a node from the database. -// It's expected that node_type.Load adds the newly loaded node to nodes before calling LoadNodeRecurse again. -func LoadNodeRecurse(ctx * Context, id NodeID, nodes NodeMap) (Node, error) { - node, exists := nodes[id] - if exists == false { - type_hash, bytes, err := readNodeBytes(ctx, id) - if err != nil { - return nil, err + // Parse each of the extensions from the db + for _, ext_db := range(node_db.Extensions) { + type_hash := ext_db.Header.TypeHash + def, known := ctx.Extensions[type_hash] + if known == false { + return nil, fmt.Errorf("%s tried to load extension 0x%x, which is not a known extension type", id, type_hash) } - - node_type, exists := ctx.Types[type_hash] - ctx.Log.Logf("db", "DB_LOADING_TYPE: %s - %+v", id, node_type) - if exists == false { - return nil, fmt.Errorf("0x%x is not a known node type: %+s", type_hash, bytes) - } - - if node_type.Load == nil { - return nil, fmt.Errorf("0x%x is an invalid node type, nil Load", type_hash) - } - - node, err = node_type.Load(ctx, id, bytes, nodes) + extension, err := def.Load(ctx, ext_db.Data) if err != nil { return nil, err } - - ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id) + node.ExtensionMap[def.Type] = extension + ctx.Log.Logf("db", "DB_EXTENSION_LOADED: %s - 0x%x", id, type_hash) } + + ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id) return node, nil } -func NewLockInfo(node Node, resources []string) LockMap { - return LockMap{ - node.ID(): LockInfo{ +func NewACLInfo(node *Node, resources []string) ACLMap { + return ACLMap{ + node.ID: ACLInfo{ Node: node, Resources: resources, }, } } -func NewLockMap(requests ...LockMap) LockMap { - reqs := LockMap{} +func NewACLMap(requests ...ACLMap) ACLMap { + reqs := ACLMap{} for _, req := range(requests) { for id, info := range(req) { reqs[id] = info @@ -416,10 +454,10 @@ func NewLockMap(requests ...LockMap) LockMap { return reqs } -func LockListM[K Node](m map[NodeID]K, resources[]string) LockMap { - reqs := LockMap{} +func ACLListM(m map[NodeID]*Node, resources[]string) ACLMap { + reqs := ACLMap{} for _, node := range(m) { - reqs[node.ID()] = LockInfo{ + reqs[node.ID] = ACLInfo{ Node: node, Resources: resources, } @@ -427,10 +465,10 @@ func LockListM[K Node](m map[NodeID]K, resources[]string) LockMap { return reqs } -func LockList[K Node](list []K, resources []string) LockMap { - reqs := LockMap{} +func ACLList(list []*Node, resources []string) ACLMap { + reqs := ACLMap{} for _, node := range(list) { - reqs[node.ID()] = LockInfo{ + reqs[node.ID] = ACLInfo{ Node: node, Resources: resources, } @@ -438,21 +476,40 @@ func LockList[K Node](list []K, resources []string) LockMap { return reqs } +type PolicyType string +func (policy PolicyType) Hash() uint64 { + hash := sha512.Sum512([]byte(fmt.Sprintf("POLICY: %s", string(policy)))) + return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)]) +} + +type ExtType string +func (ext ExtType) Hash() uint64 { + hash := sha512.Sum512([]byte(fmt.Sprintf("EXTENSION: %s", string(ext)))) + return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)]) +} -type NodeMap map[NodeID]Node +type NodeMap map[NodeID]*Node -type LockInfo struct { - Node Node +type ACLInfo struct { + Node *Node Resources []string } -type LockMap map[NodeID]LockInfo +type ACLMap map[NodeID]ACLInfo +type ExtMap map[uint64]Extension +// Context of running state usage(read/write) type StateContext struct { + // Type of the state context Type string + // The wrapped graph context Graph *Context - Permissions map[NodeID]LockMap - Locked NodeMap + // Granted permissions in the context + Permissions map[NodeID]ACLMap + // Locked extensions in the context + Locked map[NodeID]*Node + + // Context state for validation Started bool Finished bool } @@ -477,8 +534,8 @@ func NewReadContext(ctx *Context) *StateContext { return &StateContext{ Type: "read", Graph: ctx, - Permissions: map[NodeID]LockMap{}, - Locked: NodeMap{}, + Permissions: map[NodeID]ACLMap{}, + Locked: map[NodeID]*Node{}, Started: false, Finished: false, } @@ -488,8 +545,8 @@ func NewWriteContext(ctx *Context) *StateContext { return &StateContext{ Type: "write", Graph: ctx, - Permissions: map[NodeID]LockMap{}, - Locked: NodeMap{}, + Permissions: map[NodeID]ACLMap{}, + Locked: map[NodeID]*Node{}, Started: false, Finished: false, } @@ -515,8 +572,8 @@ func del[K comparable](list []K, val K) []K { // Add nodes to an existing read context and call nodes_fn with new_nodes locked for read // Check that the node has read permissions for the nodes, then add them to the read context and call nodes_fn with the nodes locked for read -func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn StateFn) error { - if princ == nil || new_nodes == nil || state_fn == nil { +func UseStates(context *StateContext, principal *Node, new_nodes ACLMap, state_fn StateFn) error { + if principal == nil || new_nodes == nil || state_fn == nil { return fmt.Errorf("nil passed to UseStates") } @@ -529,16 +586,16 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St context.Started = true } - new_locks := []Node{} - _, princ_locked := context.Locked[princ.ID()] + new_locks := []*Node{} + _, princ_locked := context.Locked[principal.ID] if princ_locked == false { - new_locks = append(new_locks, princ) - context.Graph.Log.Logf("mutex", "RLOCKING_PRINC %s", princ.ID().String()) - princ.LockState(false) + new_locks = append(new_locks, principal) + context.Graph.Log.Logf("mutex", "RLOCKING_PRINC %s", principal.ID.String()) + principal.Lock.RLock() } - princ_permissions, princ_exists := context.Permissions[princ.ID()] - new_permissions := LockMap{} + princ_permissions, princ_exists := context.Permissions[principal.ID] + new_permissions := ACLMap{} if princ_exists == true { for id, info := range(princ_permissions) { new_permissions[id] = info @@ -550,20 +607,20 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St if node == nil { return fmt.Errorf("node in request list is nil") } - id := node.ID() + id := node.ID - if id != princ.ID() { + if id != principal.ID { _, locked := context.Locked[id] if locked == false { new_locks = append(new_locks, node) context.Graph.Log.Logf("mutex", "RLOCKING %s", id.String()) - node.LockState(false) + node.Lock.RLock() } } node_permissions, node_exists := new_permissions[id] if node_exists == false { - node_permissions = LockInfo{Node: node, Resources: []string{}} + node_permissions = ACLInfo{Node: node, Resources: []string{}} } for _, resource := range(request.Resources) { @@ -575,11 +632,11 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St } if already_granted == false { - err := Allowed(context, node.Policies(), node, resource, "read", princ) + err := Allowed(context, principal, fmt.Sprintf("%s.read", resource), node) if err != nil { for _, n := range(new_locks) { context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String()) - n.UnlockState(false) + n.Lock.RUnlock() } return err } @@ -589,19 +646,19 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St } for _, node := range(new_locks) { - context.Locked[node.ID()] = node + context.Locked[node.ID] = node } - context.Permissions[princ.ID()] = new_permissions + context.Permissions[principal.ID] = new_permissions err = state_fn(context) - context.Permissions[princ.ID()] = princ_permissions + context.Permissions[principal.ID] = princ_permissions for _, node := range(new_locks) { - context.Graph.Log.Logf("mutex", "RUNLOCKING %s", node.ID().String()) - delete(context.Locked, node.ID()) - node.UnlockState(false) + context.Graph.Log.Logf("mutex", "RUNLOCKING %s", node.ID.String()) + delete(context.Locked, node.ID) + node.Lock.RUnlock() } return err @@ -609,8 +666,8 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St // Add nodes to an existing write context and call nodes_fn with nodes locked for read // If context is nil -func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn StateFn) error { - if princ == nil || new_nodes == nil || state_fn == nil { +func UpdateStates(context *StateContext, principal *Node, new_nodes ACLMap, state_fn StateFn) error { + if principal == nil || new_nodes == nil || state_fn == nil { return fmt.Errorf("nil passed to UpdateStates") } @@ -625,16 +682,16 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn final = true } - new_locks := []Node{} - _, princ_locked := context.Locked[princ.ID()] + new_locks := []*Node{} + _, princ_locked := context.Locked[principal.ID] if princ_locked == false { - new_locks = append(new_locks, princ) - context.Graph.Log.Logf("mutex", "LOCKING_PRINC %s", princ.ID().String()) - princ.LockState(true) + new_locks = append(new_locks, principal) + context.Graph.Log.Logf("mutex", "LOCKING_PRINC %s", principal.ID.String()) + principal.Lock.Lock() } - princ_permissions, princ_exists := context.Permissions[princ.ID()] - new_permissions := LockMap{} + princ_permissions, princ_exists := context.Permissions[principal.ID] + new_permissions := ACLMap{} if princ_exists == true { for id, info := range(princ_permissions) { new_permissions[id] = info @@ -646,20 +703,20 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn if node == nil { return fmt.Errorf("node in request list is nil") } - id := node.ID() + id := node.ID - if id != princ.ID() { + if id != principal.ID { _, locked := context.Locked[id] if locked == false { new_locks = append(new_locks, node) context.Graph.Log.Logf("mutex", "LOCKING %s", id.String()) - node.LockState(true) + node.Lock.Lock() } } node_permissions, node_exists := new_permissions[id] if node_exists == false { - node_permissions = LockInfo{Node: node, Resources: []string{}} + node_permissions = ACLInfo{Node: node, Resources: []string{}} } for _, resource := range(request.Resources) { @@ -671,11 +728,11 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn } if already_granted == false { - err := Allowed(context, node.Policies(), node, resource, "write", princ) + err := Allowed(context, principal, fmt.Sprintf("%s.write", resource), node) if err != nil { for _, n := range(new_locks) { context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String()) - n.UnlockState(true) + n.Lock.Unlock() } return err } @@ -685,10 +742,10 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn } for _, node := range(new_locks) { - context.Locked[node.ID()] = node + context.Locked[node.ID] = node } - context.Permissions[princ.ID()] = new_permissions + context.Permissions[principal.ID] = new_permissions err = state_fn(context) @@ -699,7 +756,7 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn } for id, node := range(context.Locked) { context.Graph.Log.Logf("mutex", "UNLOCKING %s", id.String()) - node.UnlockState(true) + node.Lock.Unlock() } } diff --git a/policy.go b/policy.go index 03f6be4..8360e79 100644 --- a/policy.go +++ b/policy.go @@ -5,356 +5,119 @@ import ( "fmt" ) -// A policy represents a set of rules attached to a Node that allow principals to perform actions on it type Policy interface { - Node - // Returns true if the principal is allowed to perform the action on the resource - Allows(context *StateContext, node Node, resource string, action string, principal Node) bool + Serialize() ([]byte, error) + Allows(context *StateContext, principal *Node, action string, node *Node) bool } -type NodeActions map[string][]string -func (actions NodeActions) Allows(resource string, action string) bool { - for _, a := range(actions[""]) { - if a == action || a == "*" { - return true - } - } - - resource_actions, exists := actions[resource] - if exists == true { - for _, a := range(resource_actions) { - if a == action || a == "*" { - return true - } - } +func LoadAllNodesPolicy(ctx *Context, data []byte) (Policy, error) { + var policy AllNodesPolicy + err := json.Unmarshal(data, &policy) + if err != nil { + return policy, err } - - return false + return policy, nil } -func NewNodeActions(resource_actions NodeActions, wildcard_actions []string) NodeActions { - if resource_actions == nil { - resource_actions = NodeActions{} - } - // Wildcard actions, all actions in "" will be allowed on all resources - if wildcard_actions == nil { - wildcard_actions = []string{} - } - resource_actions[""] = wildcard_actions - return resource_actions +type AllNodesPolicy struct { + Actions []string `json:"actions"` } -type PerNodePolicy struct { - SimpleNode - Actions map[NodeID]NodeActions +func (policy AllNodesPolicy) Type() PolicyType { + return PolicyType("simple_policy") } -type PerNodePolicyJSON struct { - SimpleNodeJSON - Actions map[string]map[string][]string `json:"actions"` +func (policy AllNodesPolicy) Serialize() ([]byte, error) { + return json.MarshalIndent(&policy, "", " ") } -func (policy *PerNodePolicy) Type() NodeType { - return NodeType("per_node_policy") +// Extension to allow a node to hold ACL policies +type ACLPolicyExtension struct { + Policies map[PolicyType]Policy } -func (policy *PerNodePolicy) Serialize() ([]byte, error) { - actions := map[string]map[string][]string{} - for principal, resource_actions := range(policy.Actions) { - actions[principal.String()] = resource_actions - } - return json.MarshalIndent(&PerNodePolicyJSON{ - SimpleNodeJSON: NewSimpleNodeJSON(&policy.SimpleNode), - Actions: actions, - }, "", " ") +type PolicyLoadFunc func(*Context, []byte) (Policy, error) +type PolicyInfo struct { + Load PolicyLoadFunc + Type PolicyType } -func NewPerNodePolicy(id NodeID, actions map[NodeID]NodeActions) PerNodePolicy { - if actions == nil { - actions = map[NodeID]NodeActions{} - } - - return PerNodePolicy{ - SimpleNode: NewSimpleNode(id), - Actions: actions, - } +type ACLPolicyExtensionContext struct { + Types map[PolicyType]PolicyInfo } -var LoadPerNodePolicy = LoadJSONNode(func(id NodeID, j PerNodePolicyJSON) (Node, error) { - actions := map[NodeID]NodeActions{} - for principal_str, node_actions := range(j.Actions) { - principal_id, err := ParseID(principal_str) +func (ext ACLPolicyExtension) Serialize() ([]byte, error) { + policies := map[string][]byte{} + for name, policy := range(ext.Policies) { + ser, err := policy.Serialize() if err != nil { return nil, err } - - actions[principal_id] = node_actions + policies[string(name)] = ser } - policy := NewPerNodePolicy(id, actions) - return &policy, nil -}, func(ctx *Context, node Node, j PerNodePolicyJSON, nodes NodeMap) error { - return RestoreSimpleNode(ctx, node.NodeHandle(), j.SimpleNodeJSON, nodes) -}) - -func (policy *PerNodePolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool { - node_actions, exists := policy.Actions[principal.ID()] - if exists == false { - return false - } - - if node_actions.Allows(resource, action) == true { - return true - } - - return false -} - -type SimplePolicy struct { - SimpleNode - Actions NodeActions -} - -type SimplePolicyJSON struct { - SimpleNodeJSON - Actions map[string][]string `json:"actions"` -} - -func (policy *SimplePolicy) Type() NodeType { - return NodeType("simple_policy") -} - -func NewSimplePolicyJSON(policy *SimplePolicy) SimplePolicyJSON { - return SimplePolicyJSON{ - SimpleNodeJSON: NewSimpleNodeJSON(&policy.SimpleNode), - Actions: policy.Actions, - } -} - -func (policy *SimplePolicy) Serialize() ([]byte, error) { - j := NewSimplePolicyJSON(policy) - return json.MarshalIndent(&j, "", " ") -} - -func NewSimplePolicy(id NodeID, actions NodeActions) SimplePolicy { - if actions == nil { - actions = NodeActions{} - } - - return SimplePolicy{ - SimpleNode: NewSimpleNode(id), - Actions: actions, - } -} - -var LoadSimplePolicy = LoadJSONNode(func(id NodeID, j SimplePolicyJSON) (Node, error) { - policy := NewSimplePolicy(id, j.Actions) - return &policy, nil -}, func(ctx *Context, node Node, j SimplePolicyJSON, nodes NodeMap) error { - return RestoreSimpleNode(ctx, node.NodeHandle(), j.SimpleNodeJSON, nodes) -}) - -func (policy *SimplePolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool { - return policy.Actions.Allows(resource, action) -} - - -type DependencyPolicy struct { - SimplePolicy + return json.MarshalIndent(&struct{ + Policies map[string][]byte `json:"policies"` + }{ + Policies: policies, + }, "", " ") } - -func (policy *DependencyPolicy) Type() NodeType { - return NodeType("dependency_policy") +func (ext ACLPolicyExtension) Process(context *StateContext, node *Node, signal GraphSignal) error { + return nil } -func NewDependencyPolicy(id NodeID, actions NodeActions) DependencyPolicy { - return DependencyPolicy{ - SimplePolicy: NewSimplePolicy(id, actions), +func LoadACLPolicyExtension(ctx *Context, data []byte) (Extension, error) { + var j struct { + Policies map[string][]byte `json:"policies"` } -} - -func (policy *DependencyPolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool { - lockable, ok := node.(LockableNode) - if ok == false { - return false + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err } - for _, dep := range(lockable.LockableHandle().Dependencies) { - if dep.ID() == principal.ID() { - return policy.Actions.Allows(resource, action) + policies := map[PolicyType]Policy{} + acl_ctx := ctx.ExtByType(ACLPolicyExtType).Data.(ACLPolicyExtensionContext) + for name, ser := range(j.Policies) { + policy_def, exists := acl_ctx.Types[PolicyType(name)] + if exists == false { + return nil, fmt.Errorf("%s is not a known policy type", name) } - } - - return false -} - -type RequirementPolicy struct { - SimplePolicy -} - - -func (policy *RequirementPolicy) Type() NodeType { - return NodeType("dependency_policy") -} - -func NewRequirementPolicy(id NodeID, actions NodeActions) RequirementPolicy { - return RequirementPolicy{ - SimplePolicy: NewSimplePolicy(id, actions), - } -} - -func (policy *RequirementPolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool { - lockable_node, ok := node.(LockableNode) - if ok == false { - return false - } - lockable := lockable_node.LockableHandle() - - for _, req := range(lockable.Requirements) { - if req.ID() == principal.ID() { - return policy.Actions.Allows(resource, action) + policy, err := policy_def.Load(ctx, ser) + if err != nil { + return nil, err } - } - - return false -} - -type ParentPolicy struct { - SimplePolicy -} - -func (policy *ParentPolicy) Type() NodeType { - return NodeType("parent_policy") -} -func NewParentPolicy(id NodeID, actions NodeActions) ParentPolicy { - return ParentPolicy{ - SimplePolicy: NewSimplePolicy(id, actions), + policies[PolicyType(name)] = policy } -} - -func (policy *ParentPolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool { - thread_node, ok := node.(ThreadNode) - if ok == false { - return false - } - thread := thread_node.ThreadHandle() - - if thread.Owner != nil { - if thread.Owner.ID() == principal.ID() { - return policy.Actions.Allows(resource, action) - } - } - - return false -} -type ChildrenPolicy struct { - SimplePolicy + return ACLPolicyExtension{ + Policies: policies, + }, nil } - -func (policy *ChildrenPolicy) Type() NodeType { - return NodeType("children_policy") +const ACLPolicyExtType = ExtType("ACL_POLICIES") +func (ext ACLPolicyExtension) Type() ExtType { + return ACLPolicyExtType } -func NewChildrenPolicy(id NodeID, actions NodeActions) ChildrenPolicy { - return ChildrenPolicy{ - SimplePolicy: NewSimplePolicy(id, actions), - } -} - -func (policy *ChildrenPolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool { - thread_node, ok := node.(ThreadNode) - if ok == false { - return false - } - thread := thread_node.ThreadHandle() - - for _, info := range(thread.Children) { - if info.Child.ID() == principal.ID() { - return policy.Actions.Allows(resource, action) +// Check if the extension allows the principal to perform action on node +func (ext ACLPolicyExtension) Allows(context *StateContext, principal *Node, action string, node *Node) bool { + for _, policy := range(ext.Policies) { + if policy.Allows(context, principal, action, node) == true { + return true } } - return false } -type UserOfPolicy struct { - SimplePolicy - Target GroupNode -} - -type UserOfPolicyJSON struct { - SimplePolicyJSON - Target string `json:"target"` -} - -func (policy *UserOfPolicy) Type() NodeType { - return NodeType("user_of_policy") -} - -func (policy *UserOfPolicy) Serialize() ([]byte, error) { - target := "" - if policy.Target != nil { - target = policy.Target.ID().String() - } - return json.MarshalIndent(&UserOfPolicyJSON{ - SimplePolicyJSON: NewSimplePolicyJSON(&policy.SimplePolicy), - Target: target, - }, "", " ") -} - -func NewUserOfPolicy(id NodeID, actions NodeActions) UserOfPolicy { - return UserOfPolicy{ - SimplePolicy: NewSimplePolicy(id, actions), - Target: nil, - } -} - -var LoadUserOfPolicy = LoadJSONNode(func(id NodeID, j UserOfPolicyJSON) (Node, error) { - policy := NewUserOfPolicy(id, j.Actions) - return &policy, nil -}, func(ctx *Context, policy *UserOfPolicy, j UserOfPolicyJSON, nodes NodeMap) error { - if j.Target != "" { - target_id, err := ParseID(j.Target) - if err != nil { - return err - } - - target_node, err := LoadNodeRecurse(ctx, target_id, nodes) - if err != nil { - return err - } - - target, ok := target_node.(GroupNode) - if ok == false { - return fmt.Errorf("%s is not a GroupNode", target_node.ID()) - } - policy.Target = target - return nil - } - return RestoreSimpleNode(ctx, policy, j.SimpleNodeJSON, nodes) -}) - -func (policy *UserOfPolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool { - if policy.Target != nil { - allowed := false - err := UseStates(context, policy.Target, NewLockInfo(policy.Target, []string{"users"}), func(context *StateContext) error { - for _, user := range(policy.Target.Users()) { - if user.ID() == principal.ID() { - allowed = policy.Actions.Allows(resource, action) - return nil - } - } - return nil - }) - if err != nil { - return false +func (policy AllNodesPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { + for _, a := range(policy.Actions) { + if a == action { + return true } - return allowed } return false } + diff --git a/thread.go b/thread.go index 90e5ce0..9a7d672 100644 --- a/thread.go +++ b/thread.go @@ -8,24 +8,64 @@ import ( "encoding/json" ) +type ThreadExt struct { + Actions ThreadActions + Handlers ThreadHandlers + + SignalChan chan GraphSignal + TimeoutChan <-chan time.Time + + ChildWaits sync.WaitGroup + + ActiveLock sync.Mutex + Active bool + StateName string + + Parent *Node + Children map[NodeID]ChildInfo + + ActionQueue []QueuedAction + NextAction *QueuedAction +} + +func (ext *ThreadExt) Serialize() ([]byte, error) { + return nil, fmt.Errorf("NOT_IMPLEMENTED") +} + +const ThreadExtType = ExtType("THREAD") +func (ext *ThreadExt) Type() ExtType { + return ThreadExtType +} + +func (ext *ThreadExt) ChildList() []*Node { + ret := make([]*Node, len(ext.Children)) + i := 0 + for _, info := range(ext.Children) { + ret[i] = info.Child + i += 1 + } + + return ret +} + // Assumed that thread is already locked for signal -func (thread *Thread) Process(context *StateContext, signal GraphSignal) error { - context.Graph.Log.Logf("signal", "THREAD_PROCESS: %s", thread.ID()) +func (ext *ThreadExt) Process(context *StateContext, node *Node, signal GraphSignal) error { + context.Graph.Log.Logf("signal", "THREAD_PROCESS: %s", node.ID) var err error switch signal.Direction() { case Up: - err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error { - if thread.Parent != nil { - return Signal(context, thread.Parent, thread, signal) + err = UseStates(context, node, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error { + if ext.Parent != nil { + return Signal(context, ext.Parent, node, signal) } else { return nil } }) case Down: - err = UseStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error { - for _, info := range(thread.Children) { - err := Signal(context, info.Child, thread, signal) + err = UseStates(context, node, NewACLInfo(node, []string{"children"}), func(context *StateContext) error { + for _, info := range(ext.Children) { + err := Signal(context, info.Child, node, signal) if err != nil { return err } @@ -37,91 +77,121 @@ func (thread *Thread) Process(context *StateContext, signal GraphSignal) error { default: return fmt.Errorf("Invalid signal direction %d", signal.Direction()) } + ext.SignalChan <- signal + return err +} + +func UnlinkThreads(context *StateContext, principal *Node, thread *Node, child *Node) error { + thread_ext, err := GetExt[*ThreadExt](thread) if err != nil { return err } - thread.Chan <- signal - return thread.Lockable.Process(context, signal) -} - -// Requires thread and childs thread to be locked for write -func UnlinkThreads(ctx * Context, node ThreadNode, child_node ThreadNode) error { - thread := node.ThreadHandle() - child := child_node.ThreadHandle() - _, is_child := thread.Children[child_node.ID()] - if is_child == false { - return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID(), thread.ID()) + child_ext, err := GetExt[*ThreadExt](child) + if err != nil { + return err } - child.Parent = nil - delete(thread.Children, child.ID()) + return UpdateStates(context, principal, ACLMap{ + thread.ID: ACLInfo{thread, []string{"children"}}, + child.ID: ACLInfo{child, []string{"parent"}}, + }, func(context *StateContext) error { + _, is_child := thread_ext.Children[child.ID] + if is_child == false { + return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID, thread.ID) + } - return nil + delete(thread_ext.Children, child.ID) + child_ext.Parent = nil + return nil + }) } -func checkIfChild(context *StateContext, target ThreadNode, cur ThreadNode) bool { - for _, info := range(cur.ThreadHandle().Children) { - if info.Child.ID() == target.ID() { - return true +func checkIfChild(context *StateContext, id NodeID, cur *ThreadExt) (bool, error) { + for _, info := range(cur.Children) { + child := info.Child + if child.ID == id { + return true, nil } - is_child := false - UpdateStates(context, cur, NewLockMap( - NewLockInfo(info.Child, []string{"children"}), - ), func(context *StateContext) error { - is_child = checkIfChild(context, target, info.Child) - return nil + + child_ext, err := GetExt[*ThreadExt](child) + if err != nil { + return false, err + } + + var is_child bool + err = UpdateStates(context, child, NewACLInfo(child, []string{"children"}), func(context *StateContext) error { + is_child, err = checkIfChild(context, id, child_ext) + return err }) + if err != nil { + return false, err + } if is_child { - return true + return true, nil } } - return false + return false, nil } // Links child to parent with info as the associated info // Continues the write context with princ, getting children for thread and parent for child -func LinkThreads(context *StateContext, princ Node, thread_node ThreadNode, info ChildInfo) error { - if context == nil || thread_node == nil || info.Child == nil { +func LinkThreads(context *StateContext, principal *Node, thread *Node, info ChildInfo) error { + if context == nil || principal == nil || thread == nil || info.Child == nil { return fmt.Errorf("invalid input") } - thread := thread_node.ThreadHandle() - child := info.Child.ThreadHandle() - child_node := info.Child - if thread.ID() == child.ID() { - return fmt.Errorf("Will not link %s as a child of itself", thread.ID()) + child := info.Child + if thread.ID == child.ID { + return fmt.Errorf("Will not link %s as a child of itself", thread.ID) + } + + thread_ext, err := GetExt[*ThreadExt](thread) + if err != nil { + return err + } + + child_ext, err := GetExt[*ThreadExt](thread) + if err != nil { + return err } - return UpdateStates(context, princ, LockMap{ - child.ID(): LockInfo{Node: child_node, Resources: []string{"parent"}}, - thread.ID(): LockInfo{Node: thread_node, Resources: []string{"children"}}, + return UpdateStates(context, principal, ACLMap{ + child.ID: ACLInfo{Node: child, Resources: []string{"parent"}}, + thread.ID: ACLInfo{Node: thread, Resources: []string{"children"}}, }, func(context *StateContext) error { - if child.Parent != nil { - return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID()) + if child_ext.Parent != nil { + return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID) } - if checkIfChild(context, thread, child) == true { - return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID(), child.ID()) + is_child, err := checkIfChild(context, thread.ID, child_ext) + if err != nil { + return err + } else if is_child == true { + return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID, child.ID) } - if checkIfChild(context, child, thread) == true { - return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID()) + is_child, err = checkIfChild(context, child.ID, thread_ext) + if err != nil { + + return err + } else if is_child == true { + return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID, child.ID) } // TODO check for info types - thread.Children[child.ID()] = info - child.Parent = thread_node + thread_ext.Children[child.ID] = info + child_ext.Parent = thread return nil }) } -type ThreadAction func(*Context, ThreadNode)(string, error) +type ThreadAction func(*Context, *Node, *ThreadExt)(string, error) type ThreadActions map[string]ThreadAction -type ThreadHandler func(*Context, ThreadNode, GraphSignal)(string, error) +type ThreadHandler func(*Context, *Node, *ThreadExt, GraphSignal)(string, error) type ThreadHandlers map[string]ThreadHandler type InfoType string @@ -129,6 +199,10 @@ func (t InfoType) String() string { return string(t) } +type ThreadInfo interface { + Serializable[InfoType] +} + // Data required by a parent thread to restore it's children type ParentThreadInfo struct { Start bool `json:"start"` @@ -136,22 +210,23 @@ type ParentThreadInfo struct { RestoreAction string `json:"restore_action"` } -func NewParentThreadInfo(start bool, start_action string, restore_action string) ParentThreadInfo { - return ParentThreadInfo{ - Start: start, - StartAction: start_action, - RestoreAction: restore_action, - } +const ParentThreadInfoType = InfoType("PARENT") +func (info *ParentThreadInfo) Type() InfoType { + return ParentThreadInfoType +} + +func (info *ParentThreadInfo) Serialize() ([]byte, error) { + return json.MarshalIndent(info, "", " ") } type ChildInfo struct { - Child ThreadNode - Infos map[InfoType]interface{} + Child *Node + Infos map[InfoType]ThreadInfo } -func NewChildInfo(child ThreadNode, infos map[InfoType]interface{}) ChildInfo { +func NewChildInfo(child *Node, infos map[InfoType]ThreadInfo) ChildInfo { if infos == nil { - infos = map[InfoType]interface{}{} + infos = map[InfoType]ThreadInfo{} } return ChildInfo{ @@ -165,110 +240,21 @@ type QueuedAction struct { Action string `json:"action"` } -type ThreadNode interface { - LockableNode - ThreadHandle() *Thread -} - -type Thread struct { - Lockable - - Actions ThreadActions - Handlers ThreadHandlers - - TimeoutChan <-chan time.Time - Chan chan GraphSignal - ChildWaits sync.WaitGroup - Active bool - ActiveLock sync.Mutex - - StateName string - Parent ThreadNode - Children map[NodeID]ChildInfo - InfoTypes []InfoType - ActionQueue []QueuedAction - NextAction *QueuedAction -} - -func (thread *Thread) QueueAction(end time.Time, action string) { - thread.ActionQueue = append(thread.ActionQueue, QueuedAction{end, action}) - thread.NextAction, thread.TimeoutChan = thread.SoonestAction() -} - -func (thread *Thread) ClearActionQueue() { - thread.ActionQueue = []QueuedAction{} - thread.NextAction = nil - thread.TimeoutChan = nil -} - -func (thread *Thread) ThreadHandle() *Thread { - return thread -} - -func (thread *Thread) Type() NodeType { - return NodeType("thread") -} - -func (thread *Thread) Serialize() ([]byte, error) { - thread_json := NewThreadJSON(thread) - return json.MarshalIndent(&thread_json, "", " ") -} - -func (thread *Thread) ChildList() []ThreadNode { - ret := make([]ThreadNode, len(thread.Children)) - i := 0 - for _, info := range(thread.Children) { - ret[i] = info.Child - i += 1 - } - return ret -} - -type ThreadJSON struct { - LockableJSON - Parent string `json:"parent"` - Children map[string]map[string]interface{} `json:"children"` - ActionQueue []QueuedAction `json:"action_queue"` - StateName string `json:"state_name"` - InfoTypes []InfoType `json:"info_types"` +func (ext *ThreadExt) QueueAction(end time.Time, action string) { + ext.ActionQueue = append(ext.ActionQueue, QueuedAction{end, action}) + ext.NextAction, ext.TimeoutChan = ext.SoonestAction() } -func NewThreadJSON(thread *Thread) ThreadJSON { - children := map[string]map[string]interface{}{} - for id, info := range(thread.Children) { - tmp := map[string]interface{}{} - for name, i := range(info.Infos) { - tmp[name.String()] = i - } - children[id.String()] = tmp - } - - parent_id := "" - if thread.Parent != nil { - parent_id = thread.Parent.ID().String() - } - - lockable_json := NewLockableJSON(&thread.Lockable) - - return ThreadJSON{ - Parent: parent_id, - Children: children, - ActionQueue: thread.ActionQueue, - StateName: thread.StateName, - LockableJSON: lockable_json, - InfoTypes: thread.InfoTypes, - } +func (ext *ThreadExt) ClearActionQueue() { + ext.ActionQueue = []QueuedAction{} + ext.NextAction = nil + ext.TimeoutChan = nil } -var LoadThread = LoadJSONNode(func(id NodeID, j ThreadJSON) (Node, error) { - thread := NewThread(id, j.Name, j.StateName, j.InfoTypes, BaseThreadActions, BaseThreadHandlers) - return &thread, nil -}, RestoreThread) - -func (thread *Thread) SoonestAction() (*QueuedAction, <-chan time.Time) { +func (ext *ThreadExt) SoonestAction() (*QueuedAction, <-chan time.Time) { var soonest_action *QueuedAction var soonest_time time.Time - for _, action := range(thread.ActionQueue) { + for _, action := range(ext.ActionQueue) { if action.Timeout.Compare(soonest_time) == -1 || soonest_action == nil { soonest_action = &action soonest_time = action.Timeout @@ -281,55 +267,6 @@ func (thread *Thread) SoonestAction() (*QueuedAction, <-chan time.Time) { } } -func RestoreThread(ctx *Context, thread ThreadNode, j ThreadJSON, nodes NodeMap) error { - thread_ptr := thread.ThreadHandle() - - thread_ptr.ActionQueue = j.ActionQueue - thread_ptr.NextAction, thread_ptr.TimeoutChan = thread_ptr.SoonestAction() - - if j.Parent != "" { - parent_id, err := ParseID(j.Parent) - if err != nil { - return err - } - p, err := LoadNodeRecurse(ctx, parent_id, nodes) - if err != nil { - return err - } - p_t, ok := p.(ThreadNode) - if ok == false { - return err - } - thread_ptr.Parent = p_t - } - - for id_str, info_raw := range(j.Children) { - id, err := ParseID(id_str) - if err != nil { - return err - } - - child_node, err := LoadNodeRecurse(ctx, id, nodes) - if err != nil { - return err - } - - child_t, ok := child_node.(ThreadNode) - if ok == false { - return fmt.Errorf("%+v is not a Thread as expected", child_node) - } - - parsed_info, err := DeserializeChildInfo(ctx, info_raw) - if err != nil { - return err - } - - thread_ptr.Children[id] = ChildInfo{child_t, parsed_info} - } - - return RestoreLockable(ctx, thread, j.LockableJSON, nodes) -} - var deserializers = map[InfoType]func(interface{})(interface{}, error) { "parent": func(raw interface{})(interface{}, error) { m, ok := raw.(map[string]interface{}) @@ -357,141 +294,133 @@ var deserializers = map[InfoType]func(interface{})(interface{}, error) { }, } -func DeserializeChildInfo(ctx *Context, infos_raw map[string]interface{}) (map[InfoType]interface{}, error) { - ret := map[InfoType]interface{}{} - for type_str, info_raw := range(infos_raw) { - info_type := InfoType(type_str) - deserializer, exists := deserializers[info_type] - if exists == false { - return nil, fmt.Errorf("No deserializer for %s", info_type) - } - var err error - ret[info_type], err = deserializer(info_raw) - if err != nil { - return nil, err - } - } - - return ret, nil -} - -const THREAD_SIGNAL_BUFFER_SIZE = 128 -func NewThread(id NodeID, name string, state_name string, info_types []InfoType, actions ThreadActions, handlers ThreadHandlers) Thread { - return Thread{ - Lockable: NewLockable(id, name), - InfoTypes: info_types, - StateName: state_name, - Chan: make(chan GraphSignal, THREAD_SIGNAL_BUFFER_SIZE), - Children: map[NodeID]ChildInfo{}, +func NewThreadExt(buffer int, name string, state_name string, actions ThreadActions, handlers ThreadHandlers) ThreadExt { + return ThreadExt{ Actions: actions, Handlers: handlers, + SignalChan: make(chan GraphSignal, buffer), + TimeoutChan: nil, + Active: false, + StateName: state_name, + Parent: nil, + Children: map[NodeID]ChildInfo{}, + ActionQueue: []QueuedAction{}, + NextAction: nil, } } -func (thread *Thread) SetActive(active bool) error { - thread.ActiveLock.Lock() - defer thread.ActiveLock.Unlock() - if thread.Active == true && active == true { - return fmt.Errorf("%s is active, cannot set active", thread.ID()) - } else if thread.Active == false && active == false { - return fmt.Errorf("%s is already inactive, canot set inactive", thread.ID()) +func (ext *ThreadExt) SetActive(active bool) error { + ext.ActiveLock.Lock() + defer ext.ActiveLock.Unlock() + if ext.Active == true && active == true { + return fmt.Errorf("alreday active, cannot set active") + } else if ext.Active == false && active == false { + return fmt.Errorf("already inactive, canot set inactive") } - thread.Active = active + ext.Active = active return nil } -func (thread *Thread) SetState(state string) error { - thread.StateName = state +func (ext *ThreadExt) SetState(state string) error { + ext.StateName = state return nil } // Requires the read permission of threads children -func FindChild(context *StateContext, princ Node, node ThreadNode, id NodeID) ThreadNode { - if node == nil { +func FindChild(context *StateContext, principal *Node, thread *Node, id NodeID) (*Node, error) { + if thread == nil { panic("cannot recurse through nil") } - thread := node.ThreadHandle() - if id == thread.ID() { - return thread + + if id == thread.ID { + return thread, nil } - for _, info := range thread.Children { - var result ThreadNode - UseStates(context, princ, NewLockInfo(info.Child, []string{"children"}), func(context *StateContext) error { - result = FindChild(context, princ, info.Child, id) - return nil - }) - if result != nil { - return result - } + thread_ext, err := GetExt[*ThreadExt](thread) + if err != nil { + return nil, err } - return nil + var found *Node = nil + err = UseStates(context, principal, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error { + for _, info := range(thread_ext.Children) { + found, err = FindChild(context, principal, info.Child, id) + if err != nil { + return err + } + if found != nil { + return nil + } + } + return nil + }) + + return found, err } -func ChildGo(ctx * Context, thread *Thread, child ThreadNode, first_action string) { - thread.ChildWaits.Add(1) - go func(child ThreadNode) { - ctx.Log.Logf("thread", "THREAD_START_CHILD: %s from %s", thread.ID(), child.ID()) - defer thread.ChildWaits.Done() +func ChildGo(ctx * Context, thread_ext *ThreadExt, child *Node, first_action string) { + thread_ext.ChildWaits.Add(1) + go func(child *Node) { + defer thread_ext.ChildWaits.Done() err := ThreadLoop(ctx, child, first_action) if err != nil { - ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %s", child.ID(), err) + ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %s", child.ID, err) } else { - ctx.Log.Logf("thread", "THREAD_CHILD_RUN_DONE: %s", child.ID()) + ctx.Log.Logf("thread", "THREAD_CHILD_RUN_DONE: %s", child.ID) } }(child) } // Main Loop for Threads, starts a write context, so cannot be called from a write or read context -func ThreadLoop(ctx * Context, node ThreadNode, first_action string) error { - // Start the thread, error if double-started - thread := node.ThreadHandle() - ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID(), first_action) - err := thread.SetActive(true) +func ThreadLoop(ctx * Context, thread *Node, first_action string) error { + thread_ext, err := GetExt[*ThreadExt](thread) + if err != nil { + return err + } + + ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID, first_action) + + err = thread_ext.SetActive(true) if err != nil { ctx.Log.Logf("thread", "THREAD_LOOP_START_ERR: %e", err) return err } next_action := first_action for next_action != "" { - action, exists := thread.Actions[next_action] + action, exists := thread_ext.Actions[next_action] if exists == false { error_str := fmt.Sprintf("%s is not a valid action", next_action) return errors.New(error_str) } - ctx.Log.Logf("thread", "THREAD_ACTION: %s - %s", thread.ID(), next_action) - next_action, err = action(ctx, node) + ctx.Log.Logf("thread", "THREAD_ACTION: %s - %s", thread.ID, next_action) + next_action, err = action(ctx, thread, thread_ext) if err != nil { return err } } - err = thread.SetActive(false) + err = thread_ext.SetActive(false) if err != nil { ctx.Log.Logf("thread", "THREAD_LOOP_STOP_ERR: %e", err) return err } - ctx.Log.Logf("thread", "THREAD_LOOP_DONE: %s", thread.ID()) + ctx.Log.Logf("thread", "THREAD_LOOP_DONE: %s", thread.ID) return nil } -func ThreadChildLinked(ctx *Context, node ThreadNode, signal GraphSignal) (string, error) { - thread := node.ThreadHandle() - ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %+v", signal) +func ThreadChildLinked(ctx *Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) { + ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %s - %+v", thread.ID, signal) context := NewWriteContext(ctx) - err := UpdateStates(context, node, NewLockMap( - NewLockInfo(node, []string{"children"}), - ), func(context *StateContext) error { + err := UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error { sig, ok := signal.(IDSignal) if ok == false { ctx.Log.Logf("thread", "THREAD_NODE_LINKED_BAD_CAST") return nil } - info, exists := thread.Children[sig.ID] + info, exists := thread_ext.Children[sig.ID] if exists == false { ctx.Log.Logf("thread", "THREAD_NODE_LINKED: %s is not a child of %s", sig.ID) return nil @@ -502,7 +431,7 @@ func ThreadChildLinked(ctx *Context, node ThreadNode, signal GraphSignal) (strin } if parent_info.Start == true { - ChildGo(ctx, thread, info.Child, parent_info.StartAction) + ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction) } return nil }) @@ -517,28 +446,26 @@ func ThreadChildLinked(ctx *Context, node ThreadNode, signal GraphSignal) (strin // Helper function to start a child from a thread during a signal handler // Starts a write context, so cannot be called from either a write or read context -func ThreadStartChild(ctx *Context, node ThreadNode, signal GraphSignal) (string, error) { +func ThreadStartChild(ctx *Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) { sig, ok := signal.(StartChildSignal) if ok == false { return "wait", nil } - thread := node.ThreadHandle() - context := NewWriteContext(ctx) - return "wait", UpdateStates(context, node, NewLockInfo(node, []string{"children"}), func(context *StateContext) error { - info, exists:= thread.Children[sig.ID] + return "wait", UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error { + info, exists:= thread_ext.Children[sig.ID] if exists == false { - return fmt.Errorf("%s is not a child of %s", sig.ID, thread.ID()) + return fmt.Errorf("%s is not a child of %s", sig.ID, thread.ID) } - return UpdateStates(context, node, NewLockInfo(info.Child, []string{"start"}), func(context *StateContext) error { + return UpdateStates(context, thread, NewACLInfo(info.Child, []string{"start"}), func(context *StateContext) error { parent_info, exists := info.Infos["parent"].(*ParentThreadInfo) if exists == false { return fmt.Errorf("Called ThreadStartChild from a thread that doesn't require parent child info") } parent_info.Start = true - ChildGo(ctx, thread, info.Child, sig.Action) + ChildGo(ctx, thread_ext, info.Child, sig.Action) return nil }) @@ -547,19 +474,23 @@ func ThreadStartChild(ctx *Context, node ThreadNode, signal GraphSignal) (string // Helper function to restore threads that should be running from a parents restore action // Starts a write context, so cannot be called from either a write or read context -func ThreadRestore(ctx * Context, node ThreadNode, start bool) error { - thread := node.ThreadHandle() +func ThreadRestore(ctx * Context, thread *Node, thread_ext *ThreadExt, start bool) error { context := NewWriteContext(ctx) - return UpdateStates(context, node, NewLockInfo(node, []string{"children"}), func(context *StateContext) error { - return UpdateStates(context, node, LockList(thread.ChildList(), []string{"start"}), func(context *StateContext) error { - for _, info := range(thread.Children) { + return UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error { + return UpdateStates(context, thread, ACLList(thread_ext.ChildList(), []string{"start", "state"}), func(context *StateContext) error { + for _, info := range(thread_ext.Children) { + child_ext, err := GetExt[*ThreadExt](info.Child) + if err != nil { + return err + } + parent_info := info.Infos["parent"].(*ParentThreadInfo) - if parent_info.Start == true && info.Child.ThreadHandle().StateName != "finished" { - ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID(), info.Child.ID()) + if parent_info.Start == true && child_ext.StateName != "finished" { + ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID, info.Child.ID) if start == true { - ChildGo(ctx, thread, info.Child, parent_info.StartAction) + ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction) } else { - ChildGo(ctx, thread, info.Child, parent_info.RestoreAction) + ChildGo(ctx, thread_ext, info.Child, parent_info.RestoreAction) } } } @@ -571,73 +502,70 @@ func ThreadRestore(ctx * Context, node ThreadNode, start bool) error { // Helper function to be called during a threads start action, sets the thread state to started // Starts a write context, so cannot be called from either a write or read context // Returns "wait", nil on success, so the first return value can be ignored safely -func ThreadStart(ctx * Context, node ThreadNode) (string, error) { - thread := node.ThreadHandle() +func ThreadStart(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) { context := NewWriteContext(ctx) - err := UpdateStates(context, node, NewLockInfo(node, []string{"state"}), func(context *StateContext) error { - err := LockLockables(context, map[NodeID]LockableNode{node.ID(): node}, node) + err := UpdateStates(context, thread, NewACLInfo(thread, []string{"state"}), func(context *StateContext) error { + err := LockLockables(context, map[NodeID]*Node{thread.ID: thread}, thread) if err != nil { return err } - return thread.SetState("started") + return thread_ext.SetState("started") }) if err != nil { return "", err } context = NewReadContext(ctx) - return "wait", Signal(context, node, node, NewStatusSignal("started", node.ID())) + return "wait", Signal(context, thread, thread, NewStatusSignal("started", thread.ID)) } -func ThreadWait(ctx * Context, node ThreadNode) (string, error) { - thread := node.ThreadHandle() - ctx.Log.Logf("thread", "THREAD_WAIT: %s - %+v", thread.ID(), thread.ActionQueue) +func ThreadWait(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) { + ctx.Log.Logf("thread", "THREAD_WAIT: %s - %+v", thread.ID, thread_ext.ActionQueue) for { select { - case signal := <- thread.Chan: - ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal) - signal_fn, exists := thread.Handlers[signal.Type()] + case signal := <- thread_ext.SignalChan: + ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID, signal) + signal_fn, exists := thread_ext.Handlers[signal.Type()] if exists == true { - ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID(), signal.Type()) - return signal_fn(ctx, node, signal) + ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID, signal.Type()) + return signal_fn(ctx, thread, thread_ext, signal) } else { - ctx.Log.Logf("thread", "THREAD_NOHANDLER: %s - %s", thread.ID(), signal.Type()) + ctx.Log.Logf("thread", "THREAD_NOHANDLER: %s - %s", thread.ID, signal.Type()) } - case <- thread.TimeoutChan: + case <- thread_ext.TimeoutChan: timeout_action := "" context := NewWriteContext(ctx) - err := UpdateStates(context, node, NewLockMap(NewLockInfo(node, []string{"timeout"})), func(context *StateContext) error { - timeout_action = thread.NextAction.Action - thread.NextAction, thread.TimeoutChan = thread.SoonestAction() + err := UpdateStates(context, thread, NewACLMap(NewACLInfo(thread, []string{"timeout"})), func(context *StateContext) error { + timeout_action = thread_ext.NextAction.Action + thread_ext.NextAction, thread_ext.TimeoutChan = thread_ext.SoonestAction() return nil }) if err != nil { - ctx.Log.Logf("thread", "THREAD_TIMEOUT_ERR: %s - %e", thread.ID(), err) + ctx.Log.Logf("thread", "THREAD_TIMEOUT_ERR: %s - %e", thread.ID, err) } - ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID(), timeout_action) + ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID, timeout_action) return timeout_action, nil } } } -func ThreadFinish(ctx *Context, node ThreadNode) (string, error) { - thread := node.ThreadHandle() +func ThreadFinish(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) { context := NewWriteContext(ctx) - return "", UpdateStates(context, node, NewLockInfo(node, []string{"state"}), func(context *StateContext) error { - err := thread.SetState("finished") + return "", UpdateStates(context, thread, NewACLInfo(thread, []string{"state"}), func(context *StateContext) error { + err := thread_ext.SetState("finished") if err != nil { return err } - return UnlockLockables(context, map[NodeID]LockableNode{node.ID(): node}, node) + return UnlockLockables(context, map[NodeID]*Node{thread.ID: thread}, thread) }) } var ThreadAbortedError = errors.New("Thread aborted by signal") // Default thread action function for "abort", sends a signal and returns a ThreadAbortedError -func ThreadAbort(ctx * Context, node ThreadNode, signal GraphSignal) (string, error) { +func ThreadAbort(ctx * Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) { context := NewReadContext(ctx) - err := Signal(context, node, node, NewStatusSignal("aborted", node.ID())) + err := Signal(context, thread, thread, NewStatusSignal("aborted", thread.ID)) if err != nil { return "", err } @@ -645,9 +573,9 @@ func ThreadAbort(ctx * Context, node ThreadNode, signal GraphSignal) (string, er } // Default thread action for "stop", sends a signal and returns no error -func ThreadStop(ctx * Context, node ThreadNode, signal GraphSignal) (string, error) { +func ThreadStop(ctx * Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) { context := NewReadContext(ctx) - err := Signal(context, node, node, NewStatusSignal("stopped", node.ID())) + err := Signal(context, thread, thread, NewStatusSignal("stopped", thread.ID)) return "finish", err }