Moved from inheritance to extensions

graph-rework-2
noah metz 2023-07-25 21:43:15 -06:00
parent b3f6ea67c9
commit ff813d6c2b
6 changed files with 966 additions and 1414 deletions

@ -1,31 +1,15 @@
package graphvent package graphvent
import ( import (
"github.com/graphql-go/graphql"
badger "github.com/dgraph-io/badger/v3" badger "github.com/dgraph-io/badger/v3"
"reflect"
"fmt" "fmt"
) )
// NodeLoadFunc is the footprint of the function used to create a new node in memory from persisted bytes type ExtensionLoadFunc func(*Context, []byte) (Extension, error)
type NodeLoadFunc func(*Context, NodeID, []byte, NodeMap)(Node, error) type ExtensionInfo struct {
Load ExtensionLoadFunc
// A NodeDef is a description of a node that can be added to a Context Type ExtType
type NodeDef struct { Data interface{}
Load NodeLoadFunc
Type NodeType
GQLType *graphql.Object
Reflect reflect.Type
}
// Create a new Node def, extracting the Type and Reflect from example
func NewNodeDef(example Node, load_func NodeLoadFunc, gql_type *graphql.Object) NodeDef {
return NodeDef{
Type: example.Type(),
Load: load_func,
GQLType: gql_type,
Reflect: reflect.TypeOf(example),
}
} }
// A Context is all the data needed to run a graphvent // A Context is all the data needed to run a graphvent
@ -34,211 +18,55 @@ type Context struct {
DB * badger.DB DB * badger.DB
// Log is an interface used to record events happening // Log is an interface used to record events happening
Log Logger Log Logger
// A mapping between type hashes and their corresponding node definitions // A mapping between type hashes and their corresponding extension definitions
Types map[uint64]NodeDef Extensions map[uint64]ExtensionInfo
// GQL substructure // All loaded Nodes
GQL GQLContext Nodes map[NodeID]*Node
}
// Recreate the GQL schema after making changes
func (ctx * Context) RebuildSchema() error {
schemaConfig := graphql.SchemaConfig{
Types: ctx.GQL.TypeList,
Query: ctx.GQL.Query,
Mutation: ctx.GQL.Mutation,
Subscription: ctx.GQL.Subscription,
}
schema, err := graphql.NewSchema(schemaConfig)
if err != nil {
return err
}
ctx.GQL.Schema = schema
return nil
} }
// Add a non-node type to the gql context func (ctx *Context) ExtByType(ext_type ExtType) ExtensionInfo {
func (ctx * Context) AddGQLType(gql_type graphql.Type) { type_hash := ext_type.Hash()
ctx.GQL.TypeList = append(ctx.GQL.TypeList, gql_type) ext, _ := ctx.Extensions[type_hash]
return ext
} }
// Add a node to a context, returns an error if the def is invalid or already exists in the context // Add a node to a context, returns an error if the def is invalid or already exists in the context
func (ctx * Context) RegisterNodeType(def NodeDef) error { func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFunc) error {
if def.Load == nil { if load_fn == nil {
return fmt.Errorf("Cannot register a node without a load function: %s", def.Type) return fmt.Errorf("def has no load function")
}
if def.Reflect == nil {
return fmt.Errorf("Cannot register a node without a reflect type: %s", def.Type)
}
if def.GQLType == nil {
return fmt.Errorf("Cannot register a node without a gql type: %s", def.Type)
} }
type_hash := def.Type.Hash() type_hash := ext_type.Hash()
_, exists := ctx.Types[type_hash] _, exists := ctx.Extensions[type_hash]
if exists == true { if exists == true {
return fmt.Errorf("Cannot register node of type %s, type already exists in context", def.Type) return fmt.Errorf("Cannot register extension of type %s, type already exists in context", ext_type)
} }
ctx.Types[type_hash] = def ctx.Extensions[type_hash] = ExtensionInfo{
Load: load_fn,
node_type := reflect.TypeOf((*Node)(nil)).Elem() Type: ext_type,
lockable_type := reflect.TypeOf((*LockableNode)(nil)).Elem()
thread_type := reflect.TypeOf((*ThreadNode)(nil)).Elem()
if def.Reflect.Implements(node_type) {
ctx.GQL.ValidNodes[def.Reflect] = def.GQLType
}
if def.Reflect.Implements(lockable_type) {
ctx.GQL.ValidLockables[def.Reflect] = def.GQLType
}
if def.Reflect.Implements(thread_type) {
ctx.GQL.ValidThreads[def.Reflect] = def.GQLType
} }
ctx.GQL.TypeList = append(ctx.GQL.TypeList, def.GQLType)
return nil return nil
} }
// Map of go types to graphql types
type ObjTypeMap map[reflect.Type]*graphql.Object
// GQL Specific Context information
type GQLContext struct {
// Generated GQL schema
Schema graphql.Schema
// List of GQL types
TypeList []graphql.Type
// Interface type maps to map go types of specific interfaces to gql types
ValidNodes ObjTypeMap
ValidLockables ObjTypeMap
ValidThreads ObjTypeMap
BaseNodeType *graphql.Object
BaseLockableType *graphql.Object
BaseThreadType *graphql.Object
Query *graphql.Object
Mutation *graphql.Object
Subscription *graphql.Object
}
// Create a new GQL context without any content
func NewGQLContext() GQLContext {
query := graphql.NewObject(graphql.ObjectConfig{
Name: "Query",
Fields: graphql.Fields{},
})
mutation := graphql.NewObject(graphql.ObjectConfig{
Name: "Mutation",
Fields: graphql.Fields{},
})
subscription := graphql.NewObject(graphql.ObjectConfig{
Name: "Subscription",
Fields: graphql.Fields{},
})
ctx := GQLContext{
Schema: graphql.Schema{},
TypeList: []graphql.Type{},
ValidNodes: ObjTypeMap{},
ValidThreads: ObjTypeMap{},
ValidLockables: ObjTypeMap{},
Query: query,
Mutation: mutation,
Subscription: subscription,
BaseNodeType: GQLTypeSimpleNode.Type,
BaseLockableType: GQLTypeSimpleLockable.Type,
BaseThreadType: GQLTypeSimpleThread.Type,
}
return ctx
}
// Create a new Context with all the library content added // Create a new Context with all the library content added
func NewContext(db * badger.DB, log Logger) * Context { func NewContext(db * badger.DB, log Logger) (*Context, error) {
ctx := &Context{ ctx := &Context{
GQL: NewGQLContext(),
DB: db, DB: db,
Log: log, Log: log,
Types: map[uint64]NodeDef{}, Extensions: map[uint64]ExtensionInfo{},
Nodes: map[NodeID]*Node{},
} }
err := ctx.RegisterNodeType(NewNodeDef((*SimpleNode)(nil), LoadSimpleNode, GQLTypeSimpleNode.Type)) err := ctx.RegisterExtension(ACLExtType, LoadACLExtension)
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*Lockable)(nil), LoadLockable, GQLTypeSimpleLockable.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*Listener)(nil), LoadListener, GQLTypeSimpleLockable.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*Thread)(nil), LoadThread, GQLTypeSimpleThread.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*GQLThread)(nil), LoadGQLThread, GQLTypeGQLThread.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*User)(nil), LoadUser, GQLTypeUser.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*Group)(nil), LoadGroup, GQLTypeSimpleLockable.Type))
if err != nil { if err != nil {
panic(err) return nil, err
} }
err = ctx.RegisterNodeType(NewNodeDef((*PerNodePolicy)(nil), LoadPerNodePolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*SimplePolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*DependencyPolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*ParentPolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*ChildrenPolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*UserOfPolicy)(nil), LoadUserOfPolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
ctx.AddGQLType(GQLTypeSignal.Type)
ctx.GQL.Query.AddFieldConfig("Self", GQLQuerySelf)
ctx.GQL.Query.AddFieldConfig("User", GQLQueryUser)
ctx.GQL.Subscription.AddFieldConfig("Update", GQLSubscriptionUpdate)
ctx.GQL.Subscription.AddFieldConfig("Self", GQLSubscriptionSelf)
ctx.GQL.Mutation.AddFieldConfig("abort", GQLMutationAbort)
ctx.GQL.Mutation.AddFieldConfig("startChild", GQLMutationStartChild)
err = ctx.RebuildSchema() err = ctx.RegisterExtension(ACLPolicyExtType, LoadACLPolicyExtension)
if err != nil { if err != nil {
panic(err) return nil, err
} }
return ctx return ctx, nil
} }

@ -58,7 +58,9 @@ func logTestContext(t * testing.T, components []string) * Context {
t.Fatal(err) t.Fatal(err)
} }
return NewContext(db, NewConsoleLogger(components)) ctx, err := NewContext(db, NewConsoleLogger(components))
fatalErr(t, err)
return ctx
} }
func testContext(t * testing.T) * Context { func testContext(t * testing.T) * Context {
@ -67,7 +69,9 @@ func testContext(t * testing.T) * Context {
t.Fatal(err) t.Fatal(err)
} }
return NewContext(db, NewConsoleLogger([]string{})) ctx, err := NewContext(db, NewConsoleLogger([]string{}))
fatalErr(t, err)
return ctx
} }
func fatalErr(t * testing.T, err error) { func fatalErr(t * testing.T, err error) {

@ -2,171 +2,121 @@ package graphvent
import ( import (
"fmt" "fmt"
"reflect"
"encoding/json" "encoding/json"
) )
type Listener struct { type ListenerExt struct {
Lockable
Chan chan GraphSignal Chan chan GraphSignal
} }
func (node *Listener) Type() NodeType { func NewListenerExt(buffer int) ListenerExt {
return NodeType("listener") return ListenerExt{
Chan: make(chan GraphSignal, buffer),
} }
func (node *Listener) Process(context *StateContext, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "LISTENER_PROCESS: %s", node.ID())
select {
case node.Chan <- signal:
default:
return fmt.Errorf("LISTENER_OVERFLOW: %s - %s", node.ID(), signal)
}
return node.Lockable.Process(context, signal)
} }
const LISTENER_CHANNEL_BUFFER = 1024 const ListenerExtType = ExtType("LISTENER")
func NewListener(id NodeID, name string) Listener { func (listener ListenerExt) Type() ExtType {
return Listener{ return ListenerExtType
Lockable: NewLockable(id, name),
Chan: make(chan GraphSignal, LISTENER_CHANNEL_BUFFER),
}
} }
var LoadListener = LoadJSONNode(func(id NodeID, j LockableJSON) (Node, error) { func (ext ListenerExt) Process(context *StateContext, signal GraphSignal) error {
listener := NewListener(id, j.Name) select {
return &listener, nil case ext.Chan <- signal:
}, RestoreLockable) default:
return fmt.Errorf("LISTENER_OVERFLOW - %+v", signal)
type LockableNode interface {
Node
LockableHandle() *Lockable
}
// Lockable is a simple Lockable implementation that can be embedded into more complex structures
type Lockable struct {
SimpleNode
Name string
Owner LockableNode
Requirements map[NodeID]LockableNode
Dependencies map[NodeID]LockableNode
LocksHeld map[NodeID]LockableNode
} }
return nil
func (lockable *Lockable) LockableHandle() *Lockable {
return lockable
} }
func (lockable *Lockable) Type() NodeType { func (node ListenerExt) Serialize() ([]byte, error) {
return NodeType("lockable") return []byte{}, nil
} }
type LockableJSON struct { type LockableExt struct {
SimpleNodeJSON Owner *Node
Name string `json:"name"` Requirements map[NodeID]*Node
Owner string `json:"owner"` Dependencies map[NodeID]*Node
Dependencies []string `json:"dependencies"` LocksHeld map[NodeID]*Node
Requirements []string `json:"requirements"`
LocksHeld map[string]string `json:"locks_held"`
} }
func (lockable *Lockable) Serialize() ([]byte, error) { const LockableExtType = ExtType("LOCKABLE")
lockable_json := NewLockableJSON(lockable) func (ext *LockableExt) Type() ExtType {
return json.MarshalIndent(&lockable_json, "", " ") return LockableExtType
} }
func NewLockableJSON(lockable *Lockable) LockableJSON { func (ext *LockableExt) Serialize() ([]byte, error) {
requirement_ids := make([]string, len(lockable.Requirements)) requirements := make([]string, len(ext.Requirements))
req_n := 0 req_n := 0
for id, _ := range(lockable.Requirements) { for id, _ := range(ext.Requirements) {
requirement_ids[req_n] = id.String() requirements[req_n] = id.String()
req_n++ req_n++
} }
dependency_ids := make([]string, len(lockable.Dependencies)) dependencies := make([]string, len(ext.Dependencies))
dep_n := 0 dep_n := 0
for id, _ := range(lockable.Dependencies) { for id, _ := range(ext.Dependencies) {
dependency_ids[dep_n] = id.String() dependencies[dep_n] = id.String()
dep_n++ dep_n++
} }
owner_id := "" owner := ""
if lockable.Owner != nil { if ext.Owner != nil {
owner_id = lockable.Owner.ID().String() owner = ext.Owner.ID.String()
} }
locks_held := map[string]string{} locks_held := map[string]string{}
for lockable_id, node := range(lockable.LocksHeld) { for lockable_id, node := range(ext.LocksHeld) {
if node == nil { if node == nil {
locks_held[lockable_id.String()] = "" locks_held[lockable_id.String()] = ""
} else { } else {
locks_held[lockable_id.String()] = node.ID().String() locks_held[lockable_id.String()] = node.ID.String()
} }
} }
node_json := NewSimpleNodeJSON(&lockable.SimpleNode) return json.MarshalIndent(&struct{
Owner string `json:"owner"`
return LockableJSON{ Requirements []string `json:"requirements"`
SimpleNodeJSON: node_json, Dependencies []string `json:"dependencies"`
Name: lockable.Name, LocksHeld map[string]string `json:"locks_held"`
Owner: owner_id, }{
Dependencies: dependency_ids, Owner: owner,
Requirements: requirement_ids, Requirements: requirements,
Dependencies: dependencies,
LocksHeld: locks_held, LocksHeld: locks_held,
} }, "", " ")
}
func (lockable *Lockable) RecordUnlock(l LockableNode) LockableNode {
lockable_id := l.ID()
last_owner, exists := lockable.LocksHeld[lockable_id]
if exists == false {
panic("Attempted to take a get the original lock holder of a lockable we don't own")
}
delete(lockable.LocksHeld, lockable_id)
return last_owner
}
func (lockable *Lockable) RecordLock(l LockableNode, last_owner LockableNode) {
lockable_id := l.ID()
_, exists := lockable.LocksHeld[lockable_id]
if exists == true {
panic("Attempted to lock a lockable we're already holding(lock cycle)")
}
lockable.LocksHeld[lockable_id] = last_owner
} }
// Assumed that lockable is already locked for signal func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
func (lockable *Lockable) Process(context *StateContext, signal GraphSignal) error { context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID)
context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", lockable.ID())
var err error var err error
switch signal.Direction() { switch signal.Direction() {
case Up: case Up:
err = UseStates(context, lockable, err = UseStates(context, node,
NewLockInfo(lockable, []string{"dependencies", "owner"}), func(context *StateContext) error { NewACLInfo(node, []string{"dependencies", "owner"}), func(context *StateContext) error {
owner_sent := false owner_sent := false
for _, dependency := range(lockable.Dependencies) { for _, dependency := range(ext.Dependencies) {
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID()) context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", node.ID, dependency.ID)
Signal(context, dependency, lockable, signal) Signal(context, dependency, node, signal)
if lockable.Owner != nil { if ext.Owner != nil {
if dependency.ID() == lockable.Owner.ID() { if dependency.ID == ext.Owner.ID {
owner_sent = true owner_sent = true
} }
} }
} }
if lockable.Owner != nil && owner_sent == false { if ext.Owner != nil && owner_sent == false {
if lockable.Owner.ID() != lockable.ID() { if ext.Owner.ID != node.ID {
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.Owner.ID()) context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", node.ID, ext.Owner.ID)
return Signal(context, lockable.Owner, lockable, signal) return Signal(context, ext.Owner, node, signal)
} }
} }
return nil return nil
}) })
case Down: case Down:
err = UseStates(context, lockable, NewLockInfo(lockable, []string{"requirements"}), func(context *StateContext) error { err = UseStates(context, node, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error {
for _, requirement := range(lockable.Requirements) { for _, requirement := range(ext.Requirements) {
err := Signal(context, requirement, lockable, signal) err := Signal(context, requirement, node, signal)
if err != nil { if err != nil {
return err return err
} }
@ -176,112 +126,154 @@ func (lockable *Lockable) Process(context *StateContext, signal GraphSignal) err
case Direct: case Direct:
err = nil err = nil
default: default:
return fmt.Errorf("invalid signal direction %d", signal.Direction()) err = fmt.Errorf("invalid signal direction %d", signal.Direction())
} }
if err != nil { if err != nil {
return err return err
} }
return lockable.SimpleNode.Process(context, signal) return nil
}
func (ext *LockableExt) RecordUnlock(node *Node) *Node {
last_owner, exists := ext.LocksHeld[node.ID]
if exists == false {
panic("Attempted to take a get the original lock holder of a lockable we don't own")
}
delete(ext.LocksHeld, node.ID)
return last_owner
}
func (ext *LockableExt) RecordLock(node *Node, last_owner *Node) {
_, exists := ext.LocksHeld[node.ID]
if exists == true {
panic("Attempted to lock a lockable we're already holding(lock cycle)")
}
ext.LocksHeld[node.ID] = last_owner
} }
// Removes requirement as a requirement from lockable // Removes requirement as a requirement from lockable
// Continues the write context with princ, getting requirents for lockable and dependencies for requirement func UnlinkLockables(context *StateContext, princ *Node, lockable *Node, requirement *Node) error {
// Assumes that an active write context exists with princ locked so that princ's state can be used in checks lockable_ext, err := GetExt[*LockableExt](lockable)
func UnlinkLockables(context *StateContext, princ Node, lockable LockableNode, requirement LockableNode) error { if err != nil {
return UpdateStates(context, princ, LockMap{ return err
lockable.ID(): LockInfo{Node: lockable, Resources: []string{"requirements"}}, }
requirement.ID(): LockInfo{Node: requirement, Resources: []string{"dependencies"}}, requirement_ext, err := GetExt[*LockableExt](requirement)
if err != nil {
return err
}
return UpdateStates(context, princ, ACLMap{
lockable.ID: ACLInfo{Node: lockable, Resources: []string{"requirements"}},
requirement.ID: ACLInfo{Node: requirement, Resources: []string{"dependencies"}},
}, func(context *StateContext) error { }, func(context *StateContext) error {
var found Node = nil var found *Node = nil
for _, req := range(lockable.LockableHandle().Requirements) { for _, req := range(lockable_ext.Requirements) {
if requirement.ID() == req.ID() { if requirement.ID == req.ID {
found = req found = req
break break
} }
} }
if found == nil { if found == nil {
return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID(), lockable.ID()) return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID, lockable.ID)
} }
delete(requirement.LockableHandle().Dependencies, lockable.ID()) delete(requirement_ext.Dependencies, lockable.ID)
delete(lockable.LockableHandle().Requirements, requirement.ID()) delete(lockable_ext.Requirements, requirement.ID)
return nil return nil
}) })
} }
// Link requirements as requirements to lockable // Link requirements as requirements to lockable
// Continues the wrtie context with princ, getting requirements for lockable and dependencies for requirements func LinkLockables(context *StateContext, princ *Node, lockable *Node, requirements []*Node) error {
func LinkLockables(context *StateContext, princ Node, lockable_node LockableNode, requirements []LockableNode) error { if lockable == nil {
if lockable_node == nil {
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements") return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements")
} }
lockable := lockable_node.LockableHandle()
if len(requirements) == 0 { if len(requirements) == 0 {
return nil return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link no lockables in call")
}
lockable_ext, err := GetExt[*LockableExt](lockable)
if err != nil {
return err
} }
found := map[NodeID]bool{} req_exts := map[NodeID]*LockableExt{}
for _, requirement := range(requirements) { for _, requirement := range(requirements) {
if requirement == nil { if requirement == nil {
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link nil to a Lockable as a requirement") return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link nil to a Lockable as a requirement")
} }
if lockable.ID() == requirement.ID() { if lockable.ID == requirement.ID {
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s to itself", lockable.ID()) return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s to itself", lockable.ID)
} }
_, exists := found[requirement.ID()] _, exists := req_exts[requirement.ID]
if exists == true { if exists == true {
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s twice", requirement.ID()) return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s twice", requirement.ID)
}
ext, err := GetExt[*LockableExt](requirement)
if err != nil {
return err
} }
found[requirement.ID()] = true req_exts[requirement.ID] = ext
} }
return UpdateStates(context, princ, NewLockMap( return UpdateStates(context, princ, NewACLMap(
NewLockInfo(lockable_node, []string{"requirements"}), NewACLInfo(lockable, []string{"requirements"}),
LockList(requirements, []string{"dependencies"}), ACLList(requirements, []string{"dependencies"}),
), func(context *StateContext) error { ), func(context *StateContext) error {
// Check that all the requirements can be added // Check that all the requirements can be added
// If the lockable is already locked, need to lock this resource as well before we can add it // If the lockable is already locked, need to lock this resource as well before we can add it
for _, requirement_node := range(requirements) { for _, requirement := range(requirements) {
requirement := requirement_node.LockableHandle() requirement_ext := req_exts[requirement.ID]
for _, req_node := range(requirements) { for _, req := range(requirements) {
req := req_node.LockableHandle() if req.ID == requirement.ID {
if req.ID() == requirement.ID() {
continue continue
} }
if checkIfRequirement(context, req, requirement) == true {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependenyc of %s so cannot add the same dependency", req.ID(), requirement.ID()) is_req, err := checkIfRequirement(context, req.ID, requirement_ext)
if err != nil {
return err
} else if is_req {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot add the same dependency", req.ID, requirement.ID)
} }
} }
if checkIfRequirement(context, lockable, requirement) == true {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID(), lockable.ID()) is_req, err := checkIfRequirement(context, lockable.ID, requirement_ext)
if err != nil {
return err
} else if is_req {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID, lockable.ID)
} }
if checkIfRequirement(context, requirement, lockable) == true { is_req, err = checkIfRequirement(context, requirement.ID, lockable_ext)
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID()) if err != nil {
return err
} else if is_req {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID, requirement.ID)
} }
if lockable.Owner == nil {
if lockable_ext.Owner == nil {
// If the new owner isn't locked, we can add the requirement // If the new owner isn't locked, we can add the requirement
} else if requirement.Owner == nil { } else if requirement_ext.Owner == nil {
// if the new requirement isn't already locked but the owner is, the requirement needs to be locked first // if the new requirement isn't already locked but the owner is, the requirement needs to be locked first
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID(), requirement.ID()) return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID, requirement.ID)
} else { } else {
// If the new requirement is already locked and the owner is already locked, their owners need to match // If the new requirement is already locked and the owner is already locked, their owners need to match
if requirement.Owner.ID() != lockable.Owner.ID() { if requirement_ext.Owner.ID != lockable_ext.Owner.ID {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID(), lockable.ID()) return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID, lockable.ID)
} }
} }
} }
// Update the states of the requirements // Update the states of the requirements
for _, requirement_node := range(requirements) { for _, requirement := range(requirements) {
requirement := requirement_node.LockableHandle() requirement_ext := req_exts[requirement.ID]
requirement.Dependencies[lockable.ID()] = lockable_node requirement_ext.Dependencies[lockable.ID] = lockable
lockable.Requirements[lockable.ID()] = requirement_node lockable_ext.Requirements[lockable.ID] = requirement
context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID(), lockable.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID, lockable.ID)
} }
// Return no error // Return no error
@ -289,74 +281,91 @@ func LinkLockables(context *StateContext, princ Node, lockable_node LockableNode
}) })
} }
// Must be called withing update context func checkIfRequirement(context *StateContext, id NodeID, cur *LockableExt) (bool, error) {
func checkIfRequirement(context *StateContext, r LockableNode, cur LockableNode) bool { for _, req := range(cur.Requirements) {
for _, c := range(cur.LockableHandle().Requirements) { if req.ID == id {
if c.ID() == r.ID() { return true, nil
return true
} }
is_requirement := false
UpdateStates(context, cur, NewLockMap(NewLockInfo(c, []string{"requirements"})), func(context *StateContext) error {
is_requirement = checkIfRequirement(context, cur, c)
return nil
})
if is_requirement { req_ext, err := GetExt[*LockableExt](req)
return true if err != nil {
return false, err
}
var is_req bool
err = UpdateStates(context, req, NewACLInfo(req, []string{"requirements"}), func(context *StateContext) error {
is_req, err = checkIfRequirement(context, id, req_ext)
return err
})
if err != nil {
return false, err
}
if is_req == true {
return true, nil
} }
} }
return false return false, nil
} }
// Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error // Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error
// Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call // Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call
func LockLockables(context *StateContext, to_lock map[NodeID]LockableNode, new_owner_node LockableNode) error { func LockLockables(context *StateContext, to_lock NodeMap, new_owner *Node) error {
if to_lock == nil { if to_lock == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: no list provided") return fmt.Errorf("LOCKABLE_LOCK_ERR: no map provided")
} }
req_exts := map[NodeID]*LockableExt{}
for _, l := range(to_lock) { for _, l := range(to_lock) {
var err error
if l == nil { if l == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: Can not lock nil") return fmt.Errorf("LOCKABLE_LOCK_ERR: Can not lock nil")
} }
req_exts[l.ID], err = GetExt[*LockableExt](l)
if err != nil {
return err
}
} }
if new_owner_node == nil { if new_owner == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks") return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks")
} }
new_owner := new_owner_node.LockableHandle() new_owner_ext, err := GetExt[*LockableExt](new_owner)
if err != nil {
return err
}
// Called with no requirements to lock, success // Called with no requirements to lock, success
if len(to_lock) == 0 { if len(to_lock) == 0 {
return nil return nil
} }
return UpdateStates(context, new_owner, NewLockMap( return UpdateStates(context, new_owner, NewACLMap(
LockListM(to_lock, []string{"lock"}), ACLListM(to_lock, []string{"lock"}),
NewLockInfo(new_owner, nil), NewACLInfo(new_owner, nil),
), func(context *StateContext) error { ), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements // First loop is to check that the states can be locked, and locks all requirements
for _, req_node := range(to_lock) { for _, req := range(to_lock) {
req := req_node.LockableHandle() req_ext := req_exts[req.ID]
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID, new_owner.ID)
// If req is alreay locked, check that we can pass the lock // If req is alreay locked, check that we can pass the lock
if req.Owner != nil { if req_ext.Owner != nil {
owner := req.Owner owner := req_ext.Owner
if owner.ID() == new_owner.ID() { if owner.ID == new_owner.ID {
continue continue
} else { } else {
err := UpdateStates(context, new_owner, NewLockInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){ err := UpdateStates(context, new_owner, NewACLInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){
return LockLockables(context, req.Requirements, req) return LockLockables(context, req_ext.Requirements, req)
}) })
if err != nil { if err != nil {
return err return err
} }
} }
} else { } else {
err := LockLockables(context, req.Requirements, req) err := LockLockables(context, req_ext.Requirements, req)
if err != nil { if err != nil {
return err return err
} }
@ -364,22 +373,22 @@ func LockLockables(context *StateContext, to_lock map[NodeID]LockableNode, new_o
} }
// At this point state modification will be started, so no errors can be returned // At this point state modification will be started, so no errors can be returned
for _, req_node := range(to_lock) { for _, req := range(to_lock) {
req := req_node.LockableHandle() req_ext := req_exts[req.ID]
old_owner := req.Owner old_owner := req_ext.Owner
// If the lockable was previously unowned, update the state // If the lockable was previously unowned, update the state
if old_owner == nil { if old_owner == nil {
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID(), req.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID, req.ID)
req.Owner = new_owner_node req_ext.Owner = new_owner
new_owner.RecordLock(req, old_owner) new_owner_ext.RecordLock(req, old_owner)
// Otherwise if the new owner already owns it, no need to update state // Otherwise if the new owner already owns it, no need to update state
} else if old_owner.ID() == new_owner.ID() { } else if old_owner.ID == new_owner.ID {
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID(), req.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID, req.ID)
// Otherwise update the state // Otherwise update the state
} else { } else {
req.Owner = new_owner req_ext.Owner = new_owner
new_owner.RecordLock(req, old_owner) new_owner_ext.RecordLock(req, old_owner)
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID(), req.ID(), old_owner.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID, req.ID, old_owner.ID)
} }
} }
return nil return nil
@ -387,61 +396,72 @@ func LockLockables(context *StateContext, to_lock map[NodeID]LockableNode, new_o
} }
func UnlockLockables(context *StateContext, to_unlock map[NodeID]LockableNode, old_owner_node LockableNode) error { func UnlockLockables(context *StateContext, to_unlock NodeMap, old_owner *Node) error {
if to_unlock == nil { if to_unlock == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided") return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided")
} }
req_exts := map[NodeID]*LockableExt{}
for _, l := range(to_unlock) { for _, l := range(to_unlock) {
if l == nil { if l == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: Can not unlock nil") return fmt.Errorf("LOCKABLE_UNLOCK_ERR: Can not unlock nil")
} }
var err error
req_exts[l.ID], err = GetExt[*LockableExt](l)
if err != nil {
return err
}
} }
if old_owner_node == nil { if old_owner == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks") return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks")
} }
old_owner := old_owner_node.LockableHandle() old_owner_ext, err := GetExt[*LockableExt](old_owner)
if err != nil {
return err
}
// Called with no requirements to unlock, success // Called with no requirements to unlock, success
if len(to_unlock) == 0 { if len(to_unlock) == 0 {
return nil return nil
} }
return UpdateStates(context, old_owner, NewLockMap( return UpdateStates(context, old_owner, NewACLMap(
LockListM(to_unlock, []string{"lock"}), ACLListM(to_unlock, []string{"lock"}),
NewLockInfo(old_owner, nil), NewACLInfo(old_owner, nil),
), func(context *StateContext) error { ), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements // First loop is to check that the states can be locked, and locks all requirements
for _, req_node := range(to_unlock) { for _, req := range(to_unlock) {
req := req_node.LockableHandle() req_ext := req_exts[req.ID]
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID, old_owner.ID)
// Check if the owner is correct // Check if the owner is correct
if req.Owner != nil { if req_ext.Owner != nil {
if req.Owner.ID() != old_owner.ID() { if req_ext.Owner.ID != old_owner.ID {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID(), old_owner.ID()) return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID, old_owner.ID)
} }
} else { } else {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID()) return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID)
} }
err := UnlockLockables(context, req.Requirements, req) err := UnlockLockables(context, req_ext.Requirements, req)
if err != nil { if err != nil {
return err return err
} }
} }
// At this point state modification will be started, so no errors can be returned // At this point state modification will be started, so no errors can be returned
for _, req_node := range(to_unlock) { for _, req := range(to_unlock) {
req := req_node.LockableHandle() req_ext := req_exts[req.ID]
new_owner := old_owner.RecordUnlock(req) new_owner := old_owner_ext.RecordUnlock(req)
req.Owner = new_owner req_ext.Owner = new_owner
if new_owner == nil { if new_owner == nil {
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID(), req.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID, req.ID)
} else { } else {
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID(), req.ID(), new_owner.ID()) context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID, req.ID, new_owner.ID)
} }
} }
@ -449,103 +469,55 @@ func UnlockLockables(context *StateContext, to_unlock map[NodeID]LockableNode, o
}) })
} }
var LoadLockable = LoadJSONNode(func(id NodeID, j LockableJSON) (Node, error) { func RestoreNode(ctx *Context, id_str string) (*Node, error) {
lockable := NewLockable(id, j.Name) id, err := ParseID(id_str)
return &lockable, nil
}, RestoreLockable)
func NewLockable(id NodeID, name string) Lockable {
return Lockable{
SimpleNode: NewSimpleNode(id),
Name: name,
Owner: nil,
Requirements: map[NodeID]LockableNode{},
Dependencies: map[NodeID]LockableNode{},
LocksHeld: map[NodeID]LockableNode{},
}
}
// Helper function to load links when loading a struct that embeds Lockable
func RestoreLockable(ctx * Context, lockable LockableNode, j LockableJSON, nodes NodeMap) error {
lockable_ptr := lockable.LockableHandle()
if j.Owner != "" {
owner_id, err := ParseID(j.Owner)
if err != nil { if err != nil {
return err return nil, err
}
owner_node, err := LoadNodeRecurse(ctx, owner_id, nodes)
if err != nil {
return err
}
owner, ok := owner_node.(LockableNode)
if ok == false {
return fmt.Errorf("%s is not a Lockable", j.Owner)
}
lockable_ptr.Owner = owner
} }
for _, dep_str := range(j.Dependencies) { return LoadNode(ctx, id)
dep_id, err := ParseID(dep_str)
if err != nil {
return err
} }
dep_node, err := LoadNodeRecurse(ctx, dep_id, nodes)
func RestoreNodeMap(ctx *Context, ids map[string]string) (NodeMap, error) {
nodes := NodeMap{}
for id_str_1, id_str_2 := range(ids) {
id_1, err := ParseID(id_str_1)
if err != nil { if err != nil {
return err return nil, err
}
dep, ok := dep_node.(LockableNode)
if ok == false {
return fmt.Errorf("%+v is not a Lockable as expected", dep_node)
}
ctx.Log.Logf("db", "LOCKABLE_LOAD_DEPENDENCY: %s - %s - %+v", lockable.ID(), dep_id, reflect.TypeOf(dep))
lockable_ptr.Dependencies[dep_id] = dep
} }
for _, req_str := range(j.Requirements) { id_2, err := ParseID(id_str_2)
req_id, err := ParseID(req_str)
if err != nil { if err != nil {
return err return nil, err
} }
req_node, err := LoadNodeRecurse(ctx, req_id, nodes)
node_1, err := LoadNode(ctx, id_1)
if err != nil { if err != nil {
return err return nil, err
}
req, ok := req_node.(LockableNode)
if ok == false {
return fmt.Errorf("%+v is not a Lockable as expected", req_node)
}
lockable_ptr.Requirements[req_id] = req
} }
for l_id_str, h_str := range(j.LocksHeld) { node_2, err := LoadNode(ctx, id_2)
l_id, err := ParseID(l_id_str)
l, err := LoadNodeRecurse(ctx, l_id, nodes)
if err != nil { if err != nil {
return err return nil, err
} }
l_l, ok := l.(LockableNode)
if ok == false { nodes[node_1.ID] = node_2
return fmt.Errorf("%s is not a Lockable", l.ID())
} }
var h_l LockableNode return nodes, nil
if h_str != "" {
h_id, err := ParseID(h_str)
if err != nil {
return err
} }
h_node, err := LoadNodeRecurse(ctx, h_id, nodes)
func RestoreNodeList(ctx *Context, ids []string) (NodeMap, error) {
nodes := NodeMap{}
for _, id_str := range(ids) {
node, err := RestoreNode(ctx, id_str)
if err != nil { if err != nil {
return err return nil, err
} }
h, ok := h_node.(LockableNode) nodes[node.ID] = node
if ok == false {
return err
}
h_l = h
}
lockable_ptr.RecordLock(l_l, h_l)
} }
return RestoreSimpleNode(ctx, lockable, j.SimpleNodeJSON, nodes) return nodes, nil
} }

@ -2,6 +2,7 @@ package graphvent
import ( import (
"sync" "sync"
"reflect"
"github.com/google/uuid" "github.com/google/uuid"
badger "github.com/dgraph-io/badger/v3" badger "github.com/dgraph-io/badger/v3"
"fmt" "fmt"
@ -31,6 +32,12 @@ func (id NodeID) String() string {
return (uuid.UUID)(id).String() return (uuid.UUID)(id).String()
} }
// Ignore the error since we're enforcing 16 byte length at compile time
func IDFromBytes(bytes [16]byte) NodeID {
id, _ := uuid.FromBytes(bytes[:])
return NodeID(id)
}
func ParseID(str string) (NodeID, error) { func ParseID(str string) (NodeID, error) {
id_uuid, err := uuid.Parse(str) id_uuid, err := uuid.Parse(str)
if err != nil { if err != nil {
@ -45,230 +52,291 @@ func KeyID(pub *ecdsa.PublicKey) NodeID {
return NodeID(str) return NodeID(str)
} }
// Types are how nodes are associated with structs at runtime(and from the DB)
type NodeType string
func (node_type NodeType) Hash() uint64 {
hash := sha512.Sum512([]byte(fmt.Sprintf("NODE: %s", node_type)))
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
}
// Generate a random NodeID // Generate a random NodeID
func RandID() NodeID { func RandID() NodeID {
return NodeID(uuid.New()) return NodeID(uuid.New())
} }
type Node interface { type Serializable[I comparable] interface {
ID() NodeID Type() I
Type() NodeType
Serialize() ([]byte, error) Serialize() ([]byte, error)
LockState(write bool)
UnlockState(write bool)
Process(context *StateContext, signal GraphSignal) error
Policies() []Policy
NodeHandle() *SimpleNode
} }
type SimpleNode struct { // NodeExtensions are additional data that can be attached to nodes, and used in node functions
Id NodeID type Extension interface {
state_mutex sync.RWMutex Serializable[ExtType]
PolicyMap map[NodeID]Policy // Send a signal to this extension to process,
// this typically triggers signals to be sent to nodes linked in the extension
Process(context *StateContext, node *Node, signal GraphSignal) error
} }
func (node *SimpleNode) NodeHandle() *SimpleNode { // Nodes represent an addressible group of extensions
return node type Node struct {
ID NodeID
Lock sync.RWMutex
ExtensionMap map[ExtType]Extension
} }
func NewSimpleNode(id NodeID) SimpleNode { func GetExt[T Extension](node *Node) (T, error) {
return SimpleNode{ var zero T
Id: id, ext_type := zero.Type()
PolicyMap: map[NodeID]Policy{}, ext, exists := node.ExtensionMap[ext_type]
} if exists == false {
return zero, fmt.Errorf("%s does not have %s extension", node.ID, ext_type)
} }
type SimpleNodeJSON struct { ret, ok := ext.(T)
Policies []string `json:"policies"` if ok == false {
return zero, fmt.Errorf("%s in %s is wrong type(%+v), expecting %+v", ext_type, node.ID, reflect.TypeOf(ext), reflect.TypeOf(zero))
} }
func (node *SimpleNode) Process(context *StateContext, signal GraphSignal) error { return ret, nil
context.Graph.Log.Logf("signal", "SIMPLE_NODE_SIGNAL: %s - %s", node.Id, signal)
return nil
} }
func (node *SimpleNode) ID() NodeID { // The ACL extension stores a map of nodes to delegate ACL to, and a list of policies
return node.Id type ACLExtension struct {
Delegations NodeMap
} }
func (node *SimpleNode) Type() NodeType { func (ext ACLExtension) Process(context *StateContext, node *Node, signal GraphSignal) error {
return NodeType("simple_node") return nil
} }
func (node *SimpleNode) Serialize() ([]byte, error) { func LoadACLExtension(ctx *Context, data []byte) (Extension, error) {
j := NewSimpleNodeJSON(node) var j struct {
return json.MarshalIndent(&j, "", " ") Delegations []string `json:"delegation"`
} }
func (node *SimpleNode) LockState(write bool) { err := json.Unmarshal(data, &j)
if write == true { if err != nil {
node.state_mutex.Lock() return nil, err
} else {
node.state_mutex.RLock()
}
} }
func (node *SimpleNode) UnlockState(write bool) { delegations := NodeMap{}
if write == true { for _, str := range(j.Delegations) {
node.state_mutex.Unlock() id, err := ParseID(str)
} else { if err != nil {
node.state_mutex.RUnlock() return nil, err
}
} }
func NewSimpleNodeJSON(node *SimpleNode) SimpleNodeJSON { node, err := LoadNode(ctx, id)
policy_ids := make([]string, len(node.PolicyMap)) if err != nil {
i := 0 return nil, err
for id, _ := range(node.PolicyMap) {
policy_ids[i] = id.String()
i += 1
} }
return SimpleNodeJSON{ delegations[id] = node
Policies: policy_ids,
}
} }
func RestoreSimpleNode(ctx *Context, node Node, j SimpleNodeJSON, nodes NodeMap) error { return ACLExtension{
node_ptr := node.NodeHandle() Delegations: delegations,
for _, policy_str := range(j.Policies) { }, nil
policy_id, err := ParseID(policy_str)
if err != nil {
return err
} }
policy_ptr, err := LoadNodeRecurse(ctx, policy_id, nodes) func (ext ACLExtension) Serialize() ([]byte, error) {
if err != nil { delegations := make([]string, len(ext.Delegations))
return err i := 0
for id, _ := range(ext.Delegations) {
delegations[i] = id.String()
i += 1
} }
policy, ok := policy_ptr.(Policy) return json.MarshalIndent(&struct{
if ok == false { Delegations []string `json:"delegations"`
return fmt.Errorf("%s is not a Policy", policy_id) }{
} Delegations: delegations,
node_ptr.PolicyMap[policy_id] = policy }, "", " ")
} }
return nil const ACLExtType = ExtType("ACL")
func (extension ACLExtension) Type() ExtType {
return ACLExtType
} }
func LoadJSONNode[J any, N Node](init_func func(NodeID, J)(Node, error), restore_func func(*Context, N, J, NodeMap)error)func(*Context, NodeID, []byte, NodeMap)(Node, error) { func (node *Node) Serialize() ([]byte, error) {
return func(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { extensions := make([]ExtensionDB, len(node.ExtensionMap))
var j J node_db := NodeDB{
err := json.Unmarshal(data, &j) Header: NodeDBHeader{
if err != nil { Magic: NODE_DB_MAGIC,
return nil, err NumExtensions: uint32(len(extensions)),
},
Extensions: extensions,
} }
node, err := init_func(id, j) i := 0
if err != nil { for ext_type, info := range(node.ExtensionMap) {
return nil, err ser, err := info.Serialize()
}
nodes[id] = node
err = restore_func(ctx, node.(N), j, nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node_db.Extensions[i] = ExtensionDB{
return node, nil Header: ExtensionDBHeader{
TypeHash: ext_type.Hash(),
Length: uint64(len(ser)),
},
Data: ser,
} }
i += 1
} }
var LoadSimpleNode = LoadJSONNode(func(id NodeID, j SimpleNodeJSON) (Node, error) { return node_db.Serialize(), nil
node := NewSimpleNode(id)
return &node, nil
}, RestoreSimpleNode)
func (node *SimpleNode) Policies() []Policy {
ret := make([]Policy, len(node.PolicyMap))
i := 0
for _, policy := range(node.PolicyMap) {
ret[i] = policy
i += 1
} }
return ret func NewNode(id NodeID) Node {
return Node{
ID: id,
ExtensionMap: map[ExtType]Extension{},
}
} }
func Allowed(context *StateContext, policies []Policy, node Node, resource string, action string, princ Node) error { func Allowed(context *StateContext, principal *Node, action string, node *Node) error {
if princ == nil { if principal == nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s", principal.ID, node.ID, action)
return fmt.Errorf("nil is not allowed to perform any actions") return fmt.Errorf("nil is not allowed to perform any actions")
} }
if node.ID() == princ.ID() {
return nil ext, exists := node.ExtensionMap[ACLExtType]
if exists == false {
return fmt.Errorf("%s does not have ACL extension, other nodes cannot perform actions on it", node.ID)
}
acl_ext := ext.(ACLExtension)
for _, policy_node := range(acl_ext.Delegations) {
ext, exists := policy_node.ExtensionMap[ACLPolicyExtType]
if exists == false {
context.Graph.Log.Logf("policy", "WARNING: %s has dependency %s which doesn't have ACLPolicyExtension")
continue
} }
for _, policy := range(policies) { policy_ext := ext.(ACLPolicyExtension)
if policy.Allows(context, node, resource, action, princ) == true { if policy_ext.Allows(context, principal, action, node) == true {
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s", principal.ID, node.ID, action)
return nil return nil
} }
} }
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", princ.ID(), node.ID(), resource, action) context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", principal.ID, node.ID, action)
return fmt.Errorf("%s is not allowed to perform %s.%s on %s", princ.ID(), resource, action, node.ID()) return fmt.Errorf("%s is not allowed to perform %s on %s", principal.ID, action, node.ID)
} }
// Check that princ is allowed to signal this action,
// then send the signal to all the extensions of the node
func Signal(context *StateContext, node *Node, princ *Node, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID, signal.String())
// Propagate the signal to registered listeners, if a listener isn't ready to receive the update err := UseStates(context, princ, NewACLInfo(node, []string{}), func(context *StateContext) error {
// send it a notification that it was closed and then close it return Allowed(context, princ, fmt.Sprintf("signal.%s", signal.Type()), node)
func Signal(context *StateContext, node Node, princ Node, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String())
err := UseStates(context, princ, NewLockInfo(node, []string{}), func(context *StateContext) error {
return Allowed(context, node.Policies(), node, "signal", signal.Type(), princ)
}) })
for _, ext := range(node.ExtensionMap) {
err = ext.Process(context, node, signal)
if err != nil { if err != nil {
return nil return nil
} }
return node.Process(context, signal)
} }
func AttachPolicies(ctx *Context, node Node, policies ...Policy) error {
context := NewWriteContext(ctx)
return UpdateStates(context, node, NewLockMap(NewLockInfo(node, []string{"policies"}), LockList(policies, nil)), func(context *StateContext) error {
for _, policy := range(policies) {
node.NodeHandle().PolicyMap[policy.ID()] = policy
}
return nil return nil
})
} }
// Magic first four bytes of serialized DB content, stored big endian // Magic first four bytes of serialized DB content, stored big endian
const NODE_DB_MAGIC = 0x2491df14 const NODE_DB_MAGIC = 0x2491df14
// Total length of the node database header, has magic to verify and type_hash to map to load function // Total length of the node database header, has magic to verify and type_hash to map to load function
const NODE_DB_HEADER_LEN = 12 const NODE_DB_HEADER_LEN = 8
// A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node // A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node
type DBHeader struct { type NodeDBHeader struct {
Magic uint32 Magic uint32
TypeHash uint64 NumExtensions uint32
}
type NodeDB struct {
Header NodeDBHeader
Extensions []ExtensionDB
}
//TODO: add size safety checks
func NewNodeDB(data []byte) (NodeDB, error) {
var zero NodeDB
ptr := 0
magic := binary.BigEndian.Uint32(data[0:4])
num_extensions := binary.BigEndian.Uint32(data[4:8])
ptr += NODE_DB_HEADER_LEN
if magic != NODE_DB_MAGIC {
return zero, fmt.Errorf("header has incorrect magic 0x%x", magic)
}
extensions := make([]ExtensionDB, num_extensions)
for i, _ := range(extensions) {
cur := data[ptr:]
type_hash := binary.BigEndian.Uint64(cur[0:8])
length := binary.BigEndian.Uint64(cur[8:16])
data_start := uint64(EXTENSION_DB_HEADER_LEN)
data_end := data_start + length
ext_data := cur[data_start:data_end]
extensions[i] = ExtensionDB{
Header: ExtensionDBHeader{
TypeHash: type_hash,
Length: length,
},
Data: ext_data,
} }
func (header DBHeader) Serialize() []byte { ptr += int(EXTENSION_DB_HEADER_LEN + length)
}
return NodeDB{
Header: NodeDBHeader{
Magic: magic,
NumExtensions: num_extensions,
},
Extensions: extensions,
}, nil
}
func (header NodeDBHeader) Serialize() []byte {
if header.Magic != NODE_DB_MAGIC { if header.Magic != NODE_DB_MAGIC {
panic(fmt.Sprintf("Serializing header with invalid magic %0x", header.Magic)) panic(fmt.Sprintf("Serializing header with invalid magic %0x", header.Magic))
} }
ret := make([]byte, NODE_DB_HEADER_LEN) ret := make([]byte, NODE_DB_HEADER_LEN)
binary.BigEndian.PutUint32(ret[0:4], header.Magic) binary.BigEndian.PutUint32(ret[0:4], header.Magic)
binary.BigEndian.PutUint64(ret[4:12], header.TypeHash) binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions)
return ret return ret
} }
func NewDBHeader(node_type NodeType) DBHeader { func (node NodeDB) Serialize() []byte {
return DBHeader{ ser := node.Header.Serialize()
Magic: NODE_DB_MAGIC, for _, extension := range(node.Extensions) {
TypeHash: node_type.Hash(), ser = append(ser, extension.Serialize()...)
} }
return ser
}
func (header ExtensionDBHeader) Serialize() []byte {
ret := make([]byte, EXTENSION_DB_HEADER_LEN)
binary.BigEndian.PutUint64(ret[0:8], header.TypeHash)
binary.BigEndian.PutUint64(ret[8:16], header.Length)
return ret
}
func (extension ExtensionDB) Serialize() []byte {
header_bytes := extension.Header.Serialize()
return append(header_bytes, extension.Data...)
}
const EXTENSION_DB_HEADER_LEN = 16
type ExtensionDBHeader struct {
TypeHash uint64
Length uint64
}
type ExtensionDB struct {
Header ExtensionDBHeader
Data []byte
} }
// Write multiple nodes to the database in a single transaction // Write multiple nodes to the database in a single transaction
@ -283,27 +351,21 @@ func WriteNodes(context *StateContext) error {
serialized_bytes := make([][]byte, len(context.Locked)) serialized_bytes := make([][]byte, len(context.Locked))
serialized_ids := make([][]byte, len(context.Locked)) serialized_ids := make([][]byte, len(context.Locked))
i := 0 i := 0
for _, node := range(context.Locked) { // TODO, just write states from the context, and store the current states in the context
for id, _ := range(context.Locked) {
node, _ := context.Graph.Nodes[id]
if node == nil { if node == nil {
return fmt.Errorf("DB_SERIALIZE_ERROR: cannot serialize nil node") return fmt.Errorf("DB_SERIALIZE_ERROR: cannot serialize nil node, maybe node isn't in the context")
} }
ser, err := node.Serialize() ser, err := node.Serialize()
if err != nil { if err != nil {
return fmt.Errorf("DB_SERIALIZE_ERROR: %s", err) return fmt.Errorf("DB_SERIALIZE_ERROR: %s", err)
} }
header := NewDBHeader(node.Type()) id_ser := node.ID.Serialize()
db_data := append(header.Serialize(), ser...)
context.Graph.Log.Logf("db", "DB_WRITING_TYPE: %s - %+v %+v: %+v", node.ID(), node.Type(), header, node)
if err != nil {
return err
}
id_ser := node.ID().Serialize()
serialized_bytes[i] = db_data serialized_bytes[i] = ser
serialized_ids[i] = id_ser serialized_ids[i] = id_ser
i++ i++
@ -320,8 +382,13 @@ func WriteNodes(context *StateContext) error {
}) })
} }
// Get the bytes associates with `id` from the database after unwrapping the header, or error // Recursively load a node from the database.
func readNodeBytes(ctx * Context, id NodeID) (uint64, []byte, error) { func LoadNode(ctx * Context, id NodeID) (*Node, error) {
node, exists := ctx.Nodes[id]
if exists == true {
return node,nil
}
var bytes []byte var bytes []byte
err := ctx.DB.View(func(txn *badger.Txn) error { err := ctx.DB.View(func(txn *badger.Txn) error {
item, err := txn.Get(id.Serialize()) item, err := txn.Get(id.Serialize())
@ -334,80 +401,51 @@ func readNodeBytes(ctx * Context, id NodeID) (uint64, []byte, error) {
return nil return nil
}) })
}) })
if err != nil { if err != nil {
ctx.Log.Logf("db", "DB_READ_ERR: %s - %e", id, err) return nil, err
return 0, nil, err
}
if len(bytes) < NODE_DB_HEADER_LEN {
return 0, nil, fmt.Errorf("header for %s is %d/%d bytes", id, len(bytes), NODE_DB_HEADER_LEN)
}
header := DBHeader{}
header.Magic = binary.BigEndian.Uint32(bytes[0:4])
header.TypeHash = binary.BigEndian.Uint64(bytes[4:12])
if header.Magic != NODE_DB_MAGIC {
return 0, nil, fmt.Errorf("header for %s, invalid magic 0x%x", id, header.Magic)
}
node_bytes := make([]byte, len(bytes) - NODE_DB_HEADER_LEN)
copy(node_bytes, bytes[NODE_DB_HEADER_LEN:])
ctx.Log.Logf("db", "DB_READ: %s %+v - %s", id, header, string(bytes))
return header.TypeHash, node_bytes, nil
}
// Load a Node from the database by ID
func LoadNode(ctx * Context, id NodeID) (Node, error) {
nodes := NodeMap{}
return LoadNodeRecurse(ctx, id, nodes)
} }
// Parse the bytes from the DB
// Recursively load a node from the database. node_db, err := NewNodeDB(bytes)
// It's expected that node_type.Load adds the newly loaded node to nodes before calling LoadNodeRecurse again.
func LoadNodeRecurse(ctx * Context, id NodeID, nodes NodeMap) (Node, error) {
node, exists := nodes[id]
if exists == false {
type_hash, bytes, err := readNodeBytes(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node_type, exists := ctx.Types[type_hash] // Create the blank node with the ID, and add it to the context
ctx.Log.Logf("db", "DB_LOADING_TYPE: %s - %+v", id, node_type) new_node := NewNode(id)
if exists == false { node = &new_node
return nil, fmt.Errorf("0x%x is not a known node type: %+s", type_hash, bytes) ctx.Nodes[id] = node
}
if node_type.Load == nil { // Parse each of the extensions from the db
return nil, fmt.Errorf("0x%x is an invalid node type, nil Load", type_hash) for _, ext_db := range(node_db.Extensions) {
type_hash := ext_db.Header.TypeHash
def, known := ctx.Extensions[type_hash]
if known == false {
return nil, fmt.Errorf("%s tried to load extension 0x%x, which is not a known extension type", id, type_hash)
} }
extension, err := def.Load(ctx, ext_db.Data)
node, err = node_type.Load(ctx, id, bytes, nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node.ExtensionMap[def.Type] = extension
ctx.Log.Logf("db", "DB_EXTENSION_LOADED: %s - 0x%x", id, type_hash)
}
ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id) ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id)
}
return node, nil return node, nil
} }
func NewLockInfo(node Node, resources []string) LockMap { func NewACLInfo(node *Node, resources []string) ACLMap {
return LockMap{ return ACLMap{
node.ID(): LockInfo{ node.ID: ACLInfo{
Node: node, Node: node,
Resources: resources, Resources: resources,
}, },
} }
} }
func NewLockMap(requests ...LockMap) LockMap { func NewACLMap(requests ...ACLMap) ACLMap {
reqs := LockMap{} reqs := ACLMap{}
for _, req := range(requests) { for _, req := range(requests) {
for id, info := range(req) { for id, info := range(req) {
reqs[id] = info reqs[id] = info
@ -416,10 +454,10 @@ func NewLockMap(requests ...LockMap) LockMap {
return reqs return reqs
} }
func LockListM[K Node](m map[NodeID]K, resources[]string) LockMap { func ACLListM(m map[NodeID]*Node, resources[]string) ACLMap {
reqs := LockMap{} reqs := ACLMap{}
for _, node := range(m) { for _, node := range(m) {
reqs[node.ID()] = LockInfo{ reqs[node.ID] = ACLInfo{
Node: node, Node: node,
Resources: resources, Resources: resources,
} }
@ -427,10 +465,10 @@ func LockListM[K Node](m map[NodeID]K, resources[]string) LockMap {
return reqs return reqs
} }
func LockList[K Node](list []K, resources []string) LockMap { func ACLList(list []*Node, resources []string) ACLMap {
reqs := LockMap{} reqs := ACLMap{}
for _, node := range(list) { for _, node := range(list) {
reqs[node.ID()] = LockInfo{ reqs[node.ID] = ACLInfo{
Node: node, Node: node,
Resources: resources, Resources: resources,
} }
@ -438,21 +476,40 @@ func LockList[K Node](list []K, resources []string) LockMap {
return reqs return reqs
} }
type PolicyType string
func (policy PolicyType) Hash() uint64 {
hash := sha512.Sum512([]byte(fmt.Sprintf("POLICY: %s", string(policy))))
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
}
type ExtType string
func (ext ExtType) Hash() uint64 {
hash := sha512.Sum512([]byte(fmt.Sprintf("EXTENSION: %s", string(ext))))
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
}
type NodeMap map[NodeID]Node type NodeMap map[NodeID]*Node
type LockInfo struct { type ACLInfo struct {
Node Node Node *Node
Resources []string Resources []string
} }
type LockMap map[NodeID]LockInfo type ACLMap map[NodeID]ACLInfo
type ExtMap map[uint64]Extension
// Context of running state usage(read/write)
type StateContext struct { type StateContext struct {
// Type of the state context
Type string Type string
// The wrapped graph context
Graph *Context Graph *Context
Permissions map[NodeID]LockMap // Granted permissions in the context
Locked NodeMap Permissions map[NodeID]ACLMap
// Locked extensions in the context
Locked map[NodeID]*Node
// Context state for validation
Started bool Started bool
Finished bool Finished bool
} }
@ -477,8 +534,8 @@ func NewReadContext(ctx *Context) *StateContext {
return &StateContext{ return &StateContext{
Type: "read", Type: "read",
Graph: ctx, Graph: ctx,
Permissions: map[NodeID]LockMap{}, Permissions: map[NodeID]ACLMap{},
Locked: NodeMap{}, Locked: map[NodeID]*Node{},
Started: false, Started: false,
Finished: false, Finished: false,
} }
@ -488,8 +545,8 @@ func NewWriteContext(ctx *Context) *StateContext {
return &StateContext{ return &StateContext{
Type: "write", Type: "write",
Graph: ctx, Graph: ctx,
Permissions: map[NodeID]LockMap{}, Permissions: map[NodeID]ACLMap{},
Locked: NodeMap{}, Locked: map[NodeID]*Node{},
Started: false, Started: false,
Finished: false, Finished: false,
} }
@ -515,8 +572,8 @@ func del[K comparable](list []K, val K) []K {
// Add nodes to an existing read context and call nodes_fn with new_nodes locked for read // Add nodes to an existing read context and call nodes_fn with new_nodes locked for read
// Check that the node has read permissions for the nodes, then add them to the read context and call nodes_fn with the nodes locked for read // Check that the node has read permissions for the nodes, then add them to the read context and call nodes_fn with the nodes locked for read
func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn StateFn) error { func UseStates(context *StateContext, principal *Node, new_nodes ACLMap, state_fn StateFn) error {
if princ == nil || new_nodes == nil || state_fn == nil { if principal == nil || new_nodes == nil || state_fn == nil {
return fmt.Errorf("nil passed to UseStates") return fmt.Errorf("nil passed to UseStates")
} }
@ -529,16 +586,16 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
context.Started = true context.Started = true
} }
new_locks := []Node{} new_locks := []*Node{}
_, princ_locked := context.Locked[princ.ID()] _, princ_locked := context.Locked[principal.ID]
if princ_locked == false { if princ_locked == false {
new_locks = append(new_locks, princ) new_locks = append(new_locks, principal)
context.Graph.Log.Logf("mutex", "RLOCKING_PRINC %s", princ.ID().String()) context.Graph.Log.Logf("mutex", "RLOCKING_PRINC %s", principal.ID.String())
princ.LockState(false) principal.Lock.RLock()
} }
princ_permissions, princ_exists := context.Permissions[princ.ID()] princ_permissions, princ_exists := context.Permissions[principal.ID]
new_permissions := LockMap{} new_permissions := ACLMap{}
if princ_exists == true { if princ_exists == true {
for id, info := range(princ_permissions) { for id, info := range(princ_permissions) {
new_permissions[id] = info new_permissions[id] = info
@ -550,20 +607,20 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
if node == nil { if node == nil {
return fmt.Errorf("node in request list is nil") return fmt.Errorf("node in request list is nil")
} }
id := node.ID() id := node.ID
if id != princ.ID() { if id != principal.ID {
_, locked := context.Locked[id] _, locked := context.Locked[id]
if locked == false { if locked == false {
new_locks = append(new_locks, node) new_locks = append(new_locks, node)
context.Graph.Log.Logf("mutex", "RLOCKING %s", id.String()) context.Graph.Log.Logf("mutex", "RLOCKING %s", id.String())
node.LockState(false) node.Lock.RLock()
} }
} }
node_permissions, node_exists := new_permissions[id] node_permissions, node_exists := new_permissions[id]
if node_exists == false { if node_exists == false {
node_permissions = LockInfo{Node: node, Resources: []string{}} node_permissions = ACLInfo{Node: node, Resources: []string{}}
} }
for _, resource := range(request.Resources) { for _, resource := range(request.Resources) {
@ -575,11 +632,11 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
} }
if already_granted == false { if already_granted == false {
err := Allowed(context, node.Policies(), node, resource, "read", princ) err := Allowed(context, principal, fmt.Sprintf("%s.read", resource), node)
if err != nil { if err != nil {
for _, n := range(new_locks) { for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String()) context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String())
n.UnlockState(false) n.Lock.RUnlock()
} }
return err return err
} }
@ -589,19 +646,19 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
} }
for _, node := range(new_locks) { for _, node := range(new_locks) {
context.Locked[node.ID()] = node context.Locked[node.ID] = node
} }
context.Permissions[princ.ID()] = new_permissions context.Permissions[principal.ID] = new_permissions
err = state_fn(context) err = state_fn(context)
context.Permissions[princ.ID()] = princ_permissions context.Permissions[principal.ID] = princ_permissions
for _, node := range(new_locks) { for _, node := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING %s", node.ID().String()) context.Graph.Log.Logf("mutex", "RUNLOCKING %s", node.ID.String())
delete(context.Locked, node.ID()) delete(context.Locked, node.ID)
node.UnlockState(false) node.Lock.RUnlock()
} }
return err return err
@ -609,8 +666,8 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
// Add nodes to an existing write context and call nodes_fn with nodes locked for read // Add nodes to an existing write context and call nodes_fn with nodes locked for read
// If context is nil // If context is nil
func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn StateFn) error { func UpdateStates(context *StateContext, principal *Node, new_nodes ACLMap, state_fn StateFn) error {
if princ == nil || new_nodes == nil || state_fn == nil { if principal == nil || new_nodes == nil || state_fn == nil {
return fmt.Errorf("nil passed to UpdateStates") return fmt.Errorf("nil passed to UpdateStates")
} }
@ -625,16 +682,16 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
final = true final = true
} }
new_locks := []Node{} new_locks := []*Node{}
_, princ_locked := context.Locked[princ.ID()] _, princ_locked := context.Locked[principal.ID]
if princ_locked == false { if princ_locked == false {
new_locks = append(new_locks, princ) new_locks = append(new_locks, principal)
context.Graph.Log.Logf("mutex", "LOCKING_PRINC %s", princ.ID().String()) context.Graph.Log.Logf("mutex", "LOCKING_PRINC %s", principal.ID.String())
princ.LockState(true) principal.Lock.Lock()
} }
princ_permissions, princ_exists := context.Permissions[princ.ID()] princ_permissions, princ_exists := context.Permissions[principal.ID]
new_permissions := LockMap{} new_permissions := ACLMap{}
if princ_exists == true { if princ_exists == true {
for id, info := range(princ_permissions) { for id, info := range(princ_permissions) {
new_permissions[id] = info new_permissions[id] = info
@ -646,20 +703,20 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
if node == nil { if node == nil {
return fmt.Errorf("node in request list is nil") return fmt.Errorf("node in request list is nil")
} }
id := node.ID() id := node.ID
if id != princ.ID() { if id != principal.ID {
_, locked := context.Locked[id] _, locked := context.Locked[id]
if locked == false { if locked == false {
new_locks = append(new_locks, node) new_locks = append(new_locks, node)
context.Graph.Log.Logf("mutex", "LOCKING %s", id.String()) context.Graph.Log.Logf("mutex", "LOCKING %s", id.String())
node.LockState(true) node.Lock.Lock()
} }
} }
node_permissions, node_exists := new_permissions[id] node_permissions, node_exists := new_permissions[id]
if node_exists == false { if node_exists == false {
node_permissions = LockInfo{Node: node, Resources: []string{}} node_permissions = ACLInfo{Node: node, Resources: []string{}}
} }
for _, resource := range(request.Resources) { for _, resource := range(request.Resources) {
@ -671,11 +728,11 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
} }
if already_granted == false { if already_granted == false {
err := Allowed(context, node.Policies(), node, resource, "write", princ) err := Allowed(context, principal, fmt.Sprintf("%s.write", resource), node)
if err != nil { if err != nil {
for _, n := range(new_locks) { for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String()) context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String())
n.UnlockState(true) n.Lock.Unlock()
} }
return err return err
} }
@ -685,10 +742,10 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
} }
for _, node := range(new_locks) { for _, node := range(new_locks) {
context.Locked[node.ID()] = node context.Locked[node.ID] = node
} }
context.Permissions[princ.ID()] = new_permissions context.Permissions[principal.ID] = new_permissions
err = state_fn(context) err = state_fn(context)
@ -699,7 +756,7 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
} }
for id, node := range(context.Locked) { for id, node := range(context.Locked) {
context.Graph.Log.Logf("mutex", "UNLOCKING %s", id.String()) context.Graph.Log.Logf("mutex", "UNLOCKING %s", id.String())
node.UnlockState(true) node.Lock.Unlock()
} }
} }

@ -5,356 +5,119 @@ import (
"fmt" "fmt"
) )
// A policy represents a set of rules attached to a Node that allow principals to perform actions on it
type Policy interface { type Policy interface {
Node Serialize() ([]byte, error)
// Returns true if the principal is allowed to perform the action on the resource Allows(context *StateContext, principal *Node, action string, node *Node) bool
Allows(context *StateContext, node Node, resource string, action string, principal Node) bool
} }
type NodeActions map[string][]string func LoadAllNodesPolicy(ctx *Context, data []byte) (Policy, error) {
func (actions NodeActions) Allows(resource string, action string) bool { var policy AllNodesPolicy
for _, a := range(actions[""]) { err := json.Unmarshal(data, &policy)
if a == action || a == "*" { if err != nil {
return true return policy, err
}
}
resource_actions, exists := actions[resource]
if exists == true {
for _, a := range(resource_actions) {
if a == action || a == "*" {
return true
}
}
}
return false
}
func NewNodeActions(resource_actions NodeActions, wildcard_actions []string) NodeActions {
if resource_actions == nil {
resource_actions = NodeActions{}
}
// Wildcard actions, all actions in "" will be allowed on all resources
if wildcard_actions == nil {
wildcard_actions = []string{}
} }
resource_actions[""] = wildcard_actions return policy, nil
return resource_actions
} }
type PerNodePolicy struct { type AllNodesPolicy struct {
SimpleNode Actions []string `json:"actions"`
Actions map[NodeID]NodeActions
} }
type PerNodePolicyJSON struct { func (policy AllNodesPolicy) Type() PolicyType {
SimpleNodeJSON return PolicyType("simple_policy")
Actions map[string]map[string][]string `json:"actions"`
} }
func (policy *PerNodePolicy) Type() NodeType { func (policy AllNodesPolicy) Serialize() ([]byte, error) {
return NodeType("per_node_policy") return json.MarshalIndent(&policy, "", " ")
} }
func (policy *PerNodePolicy) Serialize() ([]byte, error) { // Extension to allow a node to hold ACL policies
actions := map[string]map[string][]string{} type ACLPolicyExtension struct {
for principal, resource_actions := range(policy.Actions) { Policies map[PolicyType]Policy
actions[principal.String()] = resource_actions
} }
return json.MarshalIndent(&PerNodePolicyJSON{
SimpleNodeJSON: NewSimpleNodeJSON(&policy.SimpleNode),
Actions: actions,
}, "", " ")
}
func NewPerNodePolicy(id NodeID, actions map[NodeID]NodeActions) PerNodePolicy { type PolicyLoadFunc func(*Context, []byte) (Policy, error)
if actions == nil { type PolicyInfo struct {
actions = map[NodeID]NodeActions{} Load PolicyLoadFunc
Type PolicyType
} }
return PerNodePolicy{ type ACLPolicyExtensionContext struct {
SimpleNode: NewSimpleNode(id), Types map[PolicyType]PolicyInfo
Actions: actions,
}
} }
var LoadPerNodePolicy = LoadJSONNode(func(id NodeID, j PerNodePolicyJSON) (Node, error) { func (ext ACLPolicyExtension) Serialize() ([]byte, error) {
actions := map[NodeID]NodeActions{} policies := map[string][]byte{}
for principal_str, node_actions := range(j.Actions) { for name, policy := range(ext.Policies) {
principal_id, err := ParseID(principal_str) ser, err := policy.Serialize()
if err != nil { if err != nil {
return nil, err return nil, err
} }
policies[string(name)] = ser
actions[principal_id] = node_actions
}
policy := NewPerNodePolicy(id, actions)
return &policy, nil
}, func(ctx *Context, node Node, j PerNodePolicyJSON, nodes NodeMap) error {
return RestoreSimpleNode(ctx, node.NodeHandle(), j.SimpleNodeJSON, nodes)
})
func (policy *PerNodePolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool {
node_actions, exists := policy.Actions[principal.ID()]
if exists == false {
return false
}
if node_actions.Allows(resource, action) == true {
return true
}
return false
}
type SimplePolicy struct {
SimpleNode
Actions NodeActions
}
type SimplePolicyJSON struct {
SimpleNodeJSON
Actions map[string][]string `json:"actions"`
}
func (policy *SimplePolicy) Type() NodeType {
return NodeType("simple_policy")
}
func NewSimplePolicyJSON(policy *SimplePolicy) SimplePolicyJSON {
return SimplePolicyJSON{
SimpleNodeJSON: NewSimpleNodeJSON(&policy.SimpleNode),
Actions: policy.Actions,
}
}
func (policy *SimplePolicy) Serialize() ([]byte, error) {
j := NewSimplePolicyJSON(policy)
return json.MarshalIndent(&j, "", " ")
}
func NewSimplePolicy(id NodeID, actions NodeActions) SimplePolicy {
if actions == nil {
actions = NodeActions{}
}
return SimplePolicy{
SimpleNode: NewSimpleNode(id),
Actions: actions,
}
}
var LoadSimplePolicy = LoadJSONNode(func(id NodeID, j SimplePolicyJSON) (Node, error) {
policy := NewSimplePolicy(id, j.Actions)
return &policy, nil
}, func(ctx *Context, node Node, j SimplePolicyJSON, nodes NodeMap) error {
return RestoreSimpleNode(ctx, node.NodeHandle(), j.SimpleNodeJSON, nodes)
})
func (policy *SimplePolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool {
return policy.Actions.Allows(resource, action)
}
type DependencyPolicy struct {
SimplePolicy
}
func (policy *DependencyPolicy) Type() NodeType {
return NodeType("dependency_policy")
}
func NewDependencyPolicy(id NodeID, actions NodeActions) DependencyPolicy {
return DependencyPolicy{
SimplePolicy: NewSimplePolicy(id, actions),
}
}
func (policy *DependencyPolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool {
lockable, ok := node.(LockableNode)
if ok == false {
return false
}
for _, dep := range(lockable.LockableHandle().Dependencies) {
if dep.ID() == principal.ID() {
return policy.Actions.Allows(resource, action)
}
}
return false
}
type RequirementPolicy struct {
SimplePolicy
}
func (policy *RequirementPolicy) Type() NodeType {
return NodeType("dependency_policy")
}
func NewRequirementPolicy(id NodeID, actions NodeActions) RequirementPolicy {
return RequirementPolicy{
SimplePolicy: NewSimplePolicy(id, actions),
}
}
func (policy *RequirementPolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool {
lockable_node, ok := node.(LockableNode)
if ok == false {
return false
}
lockable := lockable_node.LockableHandle()
for _, req := range(lockable.Requirements) {
if req.ID() == principal.ID() {
return policy.Actions.Allows(resource, action)
}
}
return false
} }
type ParentPolicy struct { return json.MarshalIndent(&struct{
SimplePolicy Policies map[string][]byte `json:"policies"`
}{
Policies: policies,
}, "", " ")
} }
func (policy *ParentPolicy) Type() NodeType { func (ext ACLPolicyExtension) Process(context *StateContext, node *Node, signal GraphSignal) error {
return NodeType("parent_policy") return nil
} }
func NewParentPolicy(id NodeID, actions NodeActions) ParentPolicy { func LoadACLPolicyExtension(ctx *Context, data []byte) (Extension, error) {
return ParentPolicy{ var j struct {
SimplePolicy: NewSimplePolicy(id, actions), Policies map[string][]byte `json:"policies"`
} }
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
} }
func (policy *ParentPolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool { policies := map[PolicyType]Policy{}
thread_node, ok := node.(ThreadNode) acl_ctx := ctx.ExtByType(ACLPolicyExtType).Data.(ACLPolicyExtensionContext)
if ok == false { for name, ser := range(j.Policies) {
return false policy_def, exists := acl_ctx.Types[PolicyType(name)]
} if exists == false {
thread := thread_node.ThreadHandle() return nil, fmt.Errorf("%s is not a known policy type", name)
if thread.Owner != nil {
if thread.Owner.ID() == principal.ID() {
return policy.Actions.Allows(resource, action)
} }
policy, err := policy_def.Load(ctx, ser)
if err != nil {
return nil, err
} }
return false policies[PolicyType(name)] = policy
} }
type ChildrenPolicy struct { return ACLPolicyExtension{
SimplePolicy Policies: policies,
}, nil
} }
const ACLPolicyExtType = ExtType("ACL_POLICIES")
func (policy *ChildrenPolicy) Type() NodeType { func (ext ACLPolicyExtension) Type() ExtType {
return NodeType("children_policy") return ACLPolicyExtType
} }
func NewChildrenPolicy(id NodeID, actions NodeActions) ChildrenPolicy { // Check if the extension allows the principal to perform action on node
return ChildrenPolicy{ func (ext ACLPolicyExtension) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
SimplePolicy: NewSimplePolicy(id, actions), for _, policy := range(ext.Policies) {
if policy.Allows(context, principal, action, node) == true {
return true
} }
} }
func (policy *ChildrenPolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool {
thread_node, ok := node.(ThreadNode)
if ok == false {
return false return false
} }
thread := thread_node.ThreadHandle()
for _, info := range(thread.Children) { func (policy AllNodesPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
if info.Child.ID() == principal.ID() { for _, a := range(policy.Actions) {
return policy.Actions.Allows(resource, action) if a == action {
return true
} }
} }
return false return false
} }
type UserOfPolicy struct {
SimplePolicy
Target GroupNode
}
type UserOfPolicyJSON struct {
SimplePolicyJSON
Target string `json:"target"`
}
func (policy *UserOfPolicy) Type() NodeType {
return NodeType("user_of_policy")
}
func (policy *UserOfPolicy) Serialize() ([]byte, error) {
target := ""
if policy.Target != nil {
target = policy.Target.ID().String()
}
return json.MarshalIndent(&UserOfPolicyJSON{
SimplePolicyJSON: NewSimplePolicyJSON(&policy.SimplePolicy),
Target: target,
}, "", " ")
}
func NewUserOfPolicy(id NodeID, actions NodeActions) UserOfPolicy {
return UserOfPolicy{
SimplePolicy: NewSimplePolicy(id, actions),
Target: nil,
}
}
var LoadUserOfPolicy = LoadJSONNode(func(id NodeID, j UserOfPolicyJSON) (Node, error) {
policy := NewUserOfPolicy(id, j.Actions)
return &policy, nil
}, func(ctx *Context, policy *UserOfPolicy, j UserOfPolicyJSON, nodes NodeMap) error {
if j.Target != "" {
target_id, err := ParseID(j.Target)
if err != nil {
return err
}
target_node, err := LoadNodeRecurse(ctx, target_id, nodes)
if err != nil {
return err
}
target, ok := target_node.(GroupNode)
if ok == false {
return fmt.Errorf("%s is not a GroupNode", target_node.ID())
}
policy.Target = target
return nil
}
return RestoreSimpleNode(ctx, policy, j.SimpleNodeJSON, nodes)
})
func (policy *UserOfPolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool {
if policy.Target != nil {
allowed := false
err := UseStates(context, policy.Target, NewLockInfo(policy.Target, []string{"users"}), func(context *StateContext) error {
for _, user := range(policy.Target.Users()) {
if user.ID() == principal.ID() {
allowed = policy.Actions.Allows(resource, action)
return nil
}
}
return nil
})
if err != nil {
return false
}
return allowed
}
return false
}

@ -8,24 +8,64 @@ import (
"encoding/json" "encoding/json"
) )
type ThreadExt struct {
Actions ThreadActions
Handlers ThreadHandlers
SignalChan chan GraphSignal
TimeoutChan <-chan time.Time
ChildWaits sync.WaitGroup
ActiveLock sync.Mutex
Active bool
StateName string
Parent *Node
Children map[NodeID]ChildInfo
ActionQueue []QueuedAction
NextAction *QueuedAction
}
func (ext *ThreadExt) Serialize() ([]byte, error) {
return nil, fmt.Errorf("NOT_IMPLEMENTED")
}
const ThreadExtType = ExtType("THREAD")
func (ext *ThreadExt) Type() ExtType {
return ThreadExtType
}
func (ext *ThreadExt) ChildList() []*Node {
ret := make([]*Node, len(ext.Children))
i := 0
for _, info := range(ext.Children) {
ret[i] = info.Child
i += 1
}
return ret
}
// Assumed that thread is already locked for signal // Assumed that thread is already locked for signal
func (thread *Thread) Process(context *StateContext, signal GraphSignal) error { func (ext *ThreadExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "THREAD_PROCESS: %s", thread.ID()) context.Graph.Log.Logf("signal", "THREAD_PROCESS: %s", node.ID)
var err error var err error
switch signal.Direction() { switch signal.Direction() {
case Up: case Up:
err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error { err = UseStates(context, node, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error {
if thread.Parent != nil { if ext.Parent != nil {
return Signal(context, thread.Parent, thread, signal) return Signal(context, ext.Parent, node, signal)
} else { } else {
return nil return nil
} }
}) })
case Down: case Down:
err = UseStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error { err = UseStates(context, node, NewACLInfo(node, []string{"children"}), func(context *StateContext) error {
for _, info := range(thread.Children) { for _, info := range(ext.Children) {
err := Signal(context, info.Child, thread, signal) err := Signal(context, info.Child, node, signal)
if err != nil { if err != nil {
return err return err
} }
@ -37,91 +77,121 @@ func (thread *Thread) Process(context *StateContext, signal GraphSignal) error {
default: default:
return fmt.Errorf("Invalid signal direction %d", signal.Direction()) return fmt.Errorf("Invalid signal direction %d", signal.Direction())
} }
ext.SignalChan <- signal
return err
}
func UnlinkThreads(context *StateContext, principal *Node, thread *Node, child *Node) error {
thread_ext, err := GetExt[*ThreadExt](thread)
if err != nil { if err != nil {
return err return err
} }
thread.Chan <- signal child_ext, err := GetExt[*ThreadExt](child)
return thread.Lockable.Process(context, signal) if err != nil {
return err
} }
// Requires thread and childs thread to be locked for write return UpdateStates(context, principal, ACLMap{
func UnlinkThreads(ctx * Context, node ThreadNode, child_node ThreadNode) error { thread.ID: ACLInfo{thread, []string{"children"}},
thread := node.ThreadHandle() child.ID: ACLInfo{child, []string{"parent"}},
child := child_node.ThreadHandle() }, func(context *StateContext) error {
_, is_child := thread.Children[child_node.ID()] _, is_child := thread_ext.Children[child.ID]
if is_child == false { if is_child == false {
return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID(), thread.ID()) return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID, thread.ID)
} }
child.Parent = nil delete(thread_ext.Children, child.ID)
delete(thread.Children, child.ID()) child_ext.Parent = nil
return nil return nil
})
} }
func checkIfChild(context *StateContext, target ThreadNode, cur ThreadNode) bool { func checkIfChild(context *StateContext, id NodeID, cur *ThreadExt) (bool, error) {
for _, info := range(cur.ThreadHandle().Children) { for _, info := range(cur.Children) {
if info.Child.ID() == target.ID() { child := info.Child
return true if child.ID == id {
return true, nil
} }
is_child := false
UpdateStates(context, cur, NewLockMap( child_ext, err := GetExt[*ThreadExt](child)
NewLockInfo(info.Child, []string{"children"}), if err != nil {
), func(context *StateContext) error { return false, err
is_child = checkIfChild(context, target, info.Child) }
return nil
var is_child bool
err = UpdateStates(context, child, NewACLInfo(child, []string{"children"}), func(context *StateContext) error {
is_child, err = checkIfChild(context, id, child_ext)
return err
}) })
if err != nil {
return false, err
}
if is_child { if is_child {
return true return true, nil
} }
} }
return false return false, nil
} }
// Links child to parent with info as the associated info // Links child to parent with info as the associated info
// Continues the write context with princ, getting children for thread and parent for child // Continues the write context with princ, getting children for thread and parent for child
func LinkThreads(context *StateContext, princ Node, thread_node ThreadNode, info ChildInfo) error { func LinkThreads(context *StateContext, principal *Node, thread *Node, info ChildInfo) error {
if context == nil || thread_node == nil || info.Child == nil { if context == nil || principal == nil || thread == nil || info.Child == nil {
return fmt.Errorf("invalid input") return fmt.Errorf("invalid input")
} }
thread := thread_node.ThreadHandle()
child := info.Child.ThreadHandle()
child_node := info.Child
if thread.ID() == child.ID() { child := info.Child
return fmt.Errorf("Will not link %s as a child of itself", thread.ID()) if thread.ID == child.ID {
return fmt.Errorf("Will not link %s as a child of itself", thread.ID)
}
thread_ext, err := GetExt[*ThreadExt](thread)
if err != nil {
return err
} }
return UpdateStates(context, princ, LockMap{ child_ext, err := GetExt[*ThreadExt](thread)
child.ID(): LockInfo{Node: child_node, Resources: []string{"parent"}}, if err != nil {
thread.ID(): LockInfo{Node: thread_node, Resources: []string{"children"}}, return err
}
return UpdateStates(context, principal, ACLMap{
child.ID: ACLInfo{Node: child, Resources: []string{"parent"}},
thread.ID: ACLInfo{Node: thread, Resources: []string{"children"}},
}, func(context *StateContext) error { }, func(context *StateContext) error {
if child.Parent != nil { if child_ext.Parent != nil {
return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID()) return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID)
} }
if checkIfChild(context, thread, child) == true { is_child, err := checkIfChild(context, thread.ID, child_ext)
return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID(), child.ID()) if err != nil {
return err
} else if is_child == true {
return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID, child.ID)
} }
if checkIfChild(context, child, thread) == true { is_child, err = checkIfChild(context, child.ID, thread_ext)
return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID()) if err != nil {
return err
} else if is_child == true {
return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID, child.ID)
} }
// TODO check for info types // TODO check for info types
thread.Children[child.ID()] = info thread_ext.Children[child.ID] = info
child.Parent = thread_node child_ext.Parent = thread
return nil return nil
}) })
} }
type ThreadAction func(*Context, ThreadNode)(string, error) type ThreadAction func(*Context, *Node, *ThreadExt)(string, error)
type ThreadActions map[string]ThreadAction type ThreadActions map[string]ThreadAction
type ThreadHandler func(*Context, ThreadNode, GraphSignal)(string, error) type ThreadHandler func(*Context, *Node, *ThreadExt, GraphSignal)(string, error)
type ThreadHandlers map[string]ThreadHandler type ThreadHandlers map[string]ThreadHandler
type InfoType string type InfoType string
@ -129,6 +199,10 @@ func (t InfoType) String() string {
return string(t) return string(t)
} }
type ThreadInfo interface {
Serializable[InfoType]
}
// Data required by a parent thread to restore it's children // Data required by a parent thread to restore it's children
type ParentThreadInfo struct { type ParentThreadInfo struct {
Start bool `json:"start"` Start bool `json:"start"`
@ -136,22 +210,23 @@ type ParentThreadInfo struct {
RestoreAction string `json:"restore_action"` RestoreAction string `json:"restore_action"`
} }
func NewParentThreadInfo(start bool, start_action string, restore_action string) ParentThreadInfo { const ParentThreadInfoType = InfoType("PARENT")
return ParentThreadInfo{ func (info *ParentThreadInfo) Type() InfoType {
Start: start, return ParentThreadInfoType
StartAction: start_action,
RestoreAction: restore_action,
} }
func (info *ParentThreadInfo) Serialize() ([]byte, error) {
return json.MarshalIndent(info, "", " ")
} }
type ChildInfo struct { type ChildInfo struct {
Child ThreadNode Child *Node
Infos map[InfoType]interface{} Infos map[InfoType]ThreadInfo
} }
func NewChildInfo(child ThreadNode, infos map[InfoType]interface{}) ChildInfo { func NewChildInfo(child *Node, infos map[InfoType]ThreadInfo) ChildInfo {
if infos == nil { if infos == nil {
infos = map[InfoType]interface{}{} infos = map[InfoType]ThreadInfo{}
} }
return ChildInfo{ return ChildInfo{
@ -165,110 +240,21 @@ type QueuedAction struct {
Action string `json:"action"` Action string `json:"action"`
} }
type ThreadNode interface { func (ext *ThreadExt) QueueAction(end time.Time, action string) {
LockableNode ext.ActionQueue = append(ext.ActionQueue, QueuedAction{end, action})
ThreadHandle() *Thread ext.NextAction, ext.TimeoutChan = ext.SoonestAction()
}
type Thread struct {
Lockable
Actions ThreadActions
Handlers ThreadHandlers
TimeoutChan <-chan time.Time
Chan chan GraphSignal
ChildWaits sync.WaitGroup
Active bool
ActiveLock sync.Mutex
StateName string
Parent ThreadNode
Children map[NodeID]ChildInfo
InfoTypes []InfoType
ActionQueue []QueuedAction
NextAction *QueuedAction
}
func (thread *Thread) QueueAction(end time.Time, action string) {
thread.ActionQueue = append(thread.ActionQueue, QueuedAction{end, action})
thread.NextAction, thread.TimeoutChan = thread.SoonestAction()
}
func (thread *Thread) ClearActionQueue() {
thread.ActionQueue = []QueuedAction{}
thread.NextAction = nil
thread.TimeoutChan = nil
}
func (thread *Thread) ThreadHandle() *Thread {
return thread
}
func (thread *Thread) Type() NodeType {
return NodeType("thread")
} }
func (thread *Thread) Serialize() ([]byte, error) { func (ext *ThreadExt) ClearActionQueue() {
thread_json := NewThreadJSON(thread) ext.ActionQueue = []QueuedAction{}
return json.MarshalIndent(&thread_json, "", " ") ext.NextAction = nil
ext.TimeoutChan = nil
} }
func (thread *Thread) ChildList() []ThreadNode { func (ext *ThreadExt) SoonestAction() (*QueuedAction, <-chan time.Time) {
ret := make([]ThreadNode, len(thread.Children))
i := 0
for _, info := range(thread.Children) {
ret[i] = info.Child
i += 1
}
return ret
}
type ThreadJSON struct {
LockableJSON
Parent string `json:"parent"`
Children map[string]map[string]interface{} `json:"children"`
ActionQueue []QueuedAction `json:"action_queue"`
StateName string `json:"state_name"`
InfoTypes []InfoType `json:"info_types"`
}
func NewThreadJSON(thread *Thread) ThreadJSON {
children := map[string]map[string]interface{}{}
for id, info := range(thread.Children) {
tmp := map[string]interface{}{}
for name, i := range(info.Infos) {
tmp[name.String()] = i
}
children[id.String()] = tmp
}
parent_id := ""
if thread.Parent != nil {
parent_id = thread.Parent.ID().String()
}
lockable_json := NewLockableJSON(&thread.Lockable)
return ThreadJSON{
Parent: parent_id,
Children: children,
ActionQueue: thread.ActionQueue,
StateName: thread.StateName,
LockableJSON: lockable_json,
InfoTypes: thread.InfoTypes,
}
}
var LoadThread = LoadJSONNode(func(id NodeID, j ThreadJSON) (Node, error) {
thread := NewThread(id, j.Name, j.StateName, j.InfoTypes, BaseThreadActions, BaseThreadHandlers)
return &thread, nil
}, RestoreThread)
func (thread *Thread) SoonestAction() (*QueuedAction, <-chan time.Time) {
var soonest_action *QueuedAction var soonest_action *QueuedAction
var soonest_time time.Time var soonest_time time.Time
for _, action := range(thread.ActionQueue) { for _, action := range(ext.ActionQueue) {
if action.Timeout.Compare(soonest_time) == -1 || soonest_action == nil { if action.Timeout.Compare(soonest_time) == -1 || soonest_action == nil {
soonest_action = &action soonest_action = &action
soonest_time = action.Timeout soonest_time = action.Timeout
@ -281,55 +267,6 @@ func (thread *Thread) SoonestAction() (*QueuedAction, <-chan time.Time) {
} }
} }
func RestoreThread(ctx *Context, thread ThreadNode, j ThreadJSON, nodes NodeMap) error {
thread_ptr := thread.ThreadHandle()
thread_ptr.ActionQueue = j.ActionQueue
thread_ptr.NextAction, thread_ptr.TimeoutChan = thread_ptr.SoonestAction()
if j.Parent != "" {
parent_id, err := ParseID(j.Parent)
if err != nil {
return err
}
p, err := LoadNodeRecurse(ctx, parent_id, nodes)
if err != nil {
return err
}
p_t, ok := p.(ThreadNode)
if ok == false {
return err
}
thread_ptr.Parent = p_t
}
for id_str, info_raw := range(j.Children) {
id, err := ParseID(id_str)
if err != nil {
return err
}
child_node, err := LoadNodeRecurse(ctx, id, nodes)
if err != nil {
return err
}
child_t, ok := child_node.(ThreadNode)
if ok == false {
return fmt.Errorf("%+v is not a Thread as expected", child_node)
}
parsed_info, err := DeserializeChildInfo(ctx, info_raw)
if err != nil {
return err
}
thread_ptr.Children[id] = ChildInfo{child_t, parsed_info}
}
return RestoreLockable(ctx, thread, j.LockableJSON, nodes)
}
var deserializers = map[InfoType]func(interface{})(interface{}, error) { var deserializers = map[InfoType]func(interface{})(interface{}, error) {
"parent": func(raw interface{})(interface{}, error) { "parent": func(raw interface{})(interface{}, error) {
m, ok := raw.(map[string]interface{}) m, ok := raw.(map[string]interface{})
@ -357,141 +294,133 @@ var deserializers = map[InfoType]func(interface{})(interface{}, error) {
}, },
} }
func DeserializeChildInfo(ctx *Context, infos_raw map[string]interface{}) (map[InfoType]interface{}, error) { func NewThreadExt(buffer int, name string, state_name string, actions ThreadActions, handlers ThreadHandlers) ThreadExt {
ret := map[InfoType]interface{}{} return ThreadExt{
for type_str, info_raw := range(infos_raw) {
info_type := InfoType(type_str)
deserializer, exists := deserializers[info_type]
if exists == false {
return nil, fmt.Errorf("No deserializer for %s", info_type)
}
var err error
ret[info_type], err = deserializer(info_raw)
if err != nil {
return nil, err
}
}
return ret, nil
}
const THREAD_SIGNAL_BUFFER_SIZE = 128
func NewThread(id NodeID, name string, state_name string, info_types []InfoType, actions ThreadActions, handlers ThreadHandlers) Thread {
return Thread{
Lockable: NewLockable(id, name),
InfoTypes: info_types,
StateName: state_name,
Chan: make(chan GraphSignal, THREAD_SIGNAL_BUFFER_SIZE),
Children: map[NodeID]ChildInfo{},
Actions: actions, Actions: actions,
Handlers: handlers, Handlers: handlers,
SignalChan: make(chan GraphSignal, buffer),
TimeoutChan: nil,
Active: false,
StateName: state_name,
Parent: nil,
Children: map[NodeID]ChildInfo{},
ActionQueue: []QueuedAction{},
NextAction: nil,
} }
} }
func (thread *Thread) SetActive(active bool) error { func (ext *ThreadExt) SetActive(active bool) error {
thread.ActiveLock.Lock() ext.ActiveLock.Lock()
defer thread.ActiveLock.Unlock() defer ext.ActiveLock.Unlock()
if thread.Active == true && active == true { if ext.Active == true && active == true {
return fmt.Errorf("%s is active, cannot set active", thread.ID()) return fmt.Errorf("alreday active, cannot set active")
} else if thread.Active == false && active == false { } else if ext.Active == false && active == false {
return fmt.Errorf("%s is already inactive, canot set inactive", thread.ID()) return fmt.Errorf("already inactive, canot set inactive")
} }
thread.Active = active ext.Active = active
return nil return nil
} }
func (thread *Thread) SetState(state string) error { func (ext *ThreadExt) SetState(state string) error {
thread.StateName = state ext.StateName = state
return nil return nil
} }
// Requires the read permission of threads children // Requires the read permission of threads children
func FindChild(context *StateContext, princ Node, node ThreadNode, id NodeID) ThreadNode { func FindChild(context *StateContext, principal *Node, thread *Node, id NodeID) (*Node, error) {
if node == nil { if thread == nil {
panic("cannot recurse through nil") panic("cannot recurse through nil")
} }
thread := node.ThreadHandle()
if id == thread.ID() { if id == thread.ID {
return thread return thread, nil
} }
for _, info := range thread.Children { thread_ext, err := GetExt[*ThreadExt](thread)
var result ThreadNode if err != nil {
UseStates(context, princ, NewLockInfo(info.Child, []string{"children"}), func(context *StateContext) error { return nil, err
result = FindChild(context, princ, info.Child, id) }
var found *Node = nil
err = UseStates(context, principal, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
for _, info := range(thread_ext.Children) {
found, err = FindChild(context, principal, info.Child, id)
if err != nil {
return err
}
if found != nil {
return nil return nil
})
if result != nil {
return result
} }
} }
return nil return nil
})
return found, err
} }
func ChildGo(ctx * Context, thread *Thread, child ThreadNode, first_action string) { func ChildGo(ctx * Context, thread_ext *ThreadExt, child *Node, first_action string) {
thread.ChildWaits.Add(1) thread_ext.ChildWaits.Add(1)
go func(child ThreadNode) { go func(child *Node) {
ctx.Log.Logf("thread", "THREAD_START_CHILD: %s from %s", thread.ID(), child.ID()) defer thread_ext.ChildWaits.Done()
defer thread.ChildWaits.Done()
err := ThreadLoop(ctx, child, first_action) err := ThreadLoop(ctx, child, first_action)
if err != nil { if err != nil {
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %s", child.ID(), err) ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %s", child.ID, err)
} else { } else {
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_DONE: %s", child.ID()) ctx.Log.Logf("thread", "THREAD_CHILD_RUN_DONE: %s", child.ID)
} }
}(child) }(child)
} }
// Main Loop for Threads, starts a write context, so cannot be called from a write or read context // Main Loop for Threads, starts a write context, so cannot be called from a write or read context
func ThreadLoop(ctx * Context, node ThreadNode, first_action string) error { func ThreadLoop(ctx * Context, thread *Node, first_action string) error {
// Start the thread, error if double-started thread_ext, err := GetExt[*ThreadExt](thread)
thread := node.ThreadHandle() if err != nil {
ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID(), first_action) return err
err := thread.SetActive(true) }
ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID, first_action)
err = thread_ext.SetActive(true)
if err != nil { if err != nil {
ctx.Log.Logf("thread", "THREAD_LOOP_START_ERR: %e", err) ctx.Log.Logf("thread", "THREAD_LOOP_START_ERR: %e", err)
return err return err
} }
next_action := first_action next_action := first_action
for next_action != "" { for next_action != "" {
action, exists := thread.Actions[next_action] action, exists := thread_ext.Actions[next_action]
if exists == false { if exists == false {
error_str := fmt.Sprintf("%s is not a valid action", next_action) error_str := fmt.Sprintf("%s is not a valid action", next_action)
return errors.New(error_str) return errors.New(error_str)
} }
ctx.Log.Logf("thread", "THREAD_ACTION: %s - %s", thread.ID(), next_action) ctx.Log.Logf("thread", "THREAD_ACTION: %s - %s", thread.ID, next_action)
next_action, err = action(ctx, node) next_action, err = action(ctx, thread, thread_ext)
if err != nil { if err != nil {
return err return err
} }
} }
err = thread.SetActive(false) err = thread_ext.SetActive(false)
if err != nil { if err != nil {
ctx.Log.Logf("thread", "THREAD_LOOP_STOP_ERR: %e", err) ctx.Log.Logf("thread", "THREAD_LOOP_STOP_ERR: %e", err)
return err return err
} }
ctx.Log.Logf("thread", "THREAD_LOOP_DONE: %s", thread.ID()) ctx.Log.Logf("thread", "THREAD_LOOP_DONE: %s", thread.ID)
return nil return nil
} }
func ThreadChildLinked(ctx *Context, node ThreadNode, signal GraphSignal) (string, error) { func ThreadChildLinked(ctx *Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) {
thread := node.ThreadHandle() ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %s - %+v", thread.ID, signal)
ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %+v", signal)
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
err := UpdateStates(context, node, NewLockMap( err := UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
NewLockInfo(node, []string{"children"}),
), func(context *StateContext) error {
sig, ok := signal.(IDSignal) sig, ok := signal.(IDSignal)
if ok == false { if ok == false {
ctx.Log.Logf("thread", "THREAD_NODE_LINKED_BAD_CAST") ctx.Log.Logf("thread", "THREAD_NODE_LINKED_BAD_CAST")
return nil return nil
} }
info, exists := thread.Children[sig.ID] info, exists := thread_ext.Children[sig.ID]
if exists == false { if exists == false {
ctx.Log.Logf("thread", "THREAD_NODE_LINKED: %s is not a child of %s", sig.ID) ctx.Log.Logf("thread", "THREAD_NODE_LINKED: %s is not a child of %s", sig.ID)
return nil return nil
@ -502,7 +431,7 @@ func ThreadChildLinked(ctx *Context, node ThreadNode, signal GraphSignal) (strin
} }
if parent_info.Start == true { if parent_info.Start == true {
ChildGo(ctx, thread, info.Child, parent_info.StartAction) ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction)
} }
return nil return nil
}) })
@ -517,28 +446,26 @@ func ThreadChildLinked(ctx *Context, node ThreadNode, signal GraphSignal) (strin
// Helper function to start a child from a thread during a signal handler // Helper function to start a child from a thread during a signal handler
// Starts a write context, so cannot be called from either a write or read context // Starts a write context, so cannot be called from either a write or read context
func ThreadStartChild(ctx *Context, node ThreadNode, signal GraphSignal) (string, error) { func ThreadStartChild(ctx *Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) {
sig, ok := signal.(StartChildSignal) sig, ok := signal.(StartChildSignal)
if ok == false { if ok == false {
return "wait", nil return "wait", nil
} }
thread := node.ThreadHandle()
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
return "wait", UpdateStates(context, node, NewLockInfo(node, []string{"children"}), func(context *StateContext) error { return "wait", UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
info, exists:= thread.Children[sig.ID] info, exists:= thread_ext.Children[sig.ID]
if exists == false { if exists == false {
return fmt.Errorf("%s is not a child of %s", sig.ID, thread.ID()) return fmt.Errorf("%s is not a child of %s", sig.ID, thread.ID)
} }
return UpdateStates(context, node, NewLockInfo(info.Child, []string{"start"}), func(context *StateContext) error { return UpdateStates(context, thread, NewACLInfo(info.Child, []string{"start"}), func(context *StateContext) error {
parent_info, exists := info.Infos["parent"].(*ParentThreadInfo) parent_info, exists := info.Infos["parent"].(*ParentThreadInfo)
if exists == false { if exists == false {
return fmt.Errorf("Called ThreadStartChild from a thread that doesn't require parent child info") return fmt.Errorf("Called ThreadStartChild from a thread that doesn't require parent child info")
} }
parent_info.Start = true parent_info.Start = true
ChildGo(ctx, thread, info.Child, sig.Action) ChildGo(ctx, thread_ext, info.Child, sig.Action)
return nil return nil
}) })
@ -547,19 +474,23 @@ func ThreadStartChild(ctx *Context, node ThreadNode, signal GraphSignal) (string
// Helper function to restore threads that should be running from a parents restore action // Helper function to restore threads that should be running from a parents restore action
// Starts a write context, so cannot be called from either a write or read context // Starts a write context, so cannot be called from either a write or read context
func ThreadRestore(ctx * Context, node ThreadNode, start bool) error { func ThreadRestore(ctx * Context, thread *Node, thread_ext *ThreadExt, start bool) error {
thread := node.ThreadHandle()
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
return UpdateStates(context, node, NewLockInfo(node, []string{"children"}), func(context *StateContext) error { return UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
return UpdateStates(context, node, LockList(thread.ChildList(), []string{"start"}), func(context *StateContext) error { return UpdateStates(context, thread, ACLList(thread_ext.ChildList(), []string{"start", "state"}), func(context *StateContext) error {
for _, info := range(thread.Children) { for _, info := range(thread_ext.Children) {
child_ext, err := GetExt[*ThreadExt](info.Child)
if err != nil {
return err
}
parent_info := info.Infos["parent"].(*ParentThreadInfo) parent_info := info.Infos["parent"].(*ParentThreadInfo)
if parent_info.Start == true && info.Child.ThreadHandle().StateName != "finished" { if parent_info.Start == true && child_ext.StateName != "finished" {
ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID(), info.Child.ID()) ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID, info.Child.ID)
if start == true { if start == true {
ChildGo(ctx, thread, info.Child, parent_info.StartAction) ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction)
} else { } else {
ChildGo(ctx, thread, info.Child, parent_info.RestoreAction) ChildGo(ctx, thread_ext, info.Child, parent_info.RestoreAction)
} }
} }
} }
@ -571,73 +502,70 @@ func ThreadRestore(ctx * Context, node ThreadNode, start bool) error {
// Helper function to be called during a threads start action, sets the thread state to started // Helper function to be called during a threads start action, sets the thread state to started
// Starts a write context, so cannot be called from either a write or read context // Starts a write context, so cannot be called from either a write or read context
// Returns "wait", nil on success, so the first return value can be ignored safely // Returns "wait", nil on success, so the first return value can be ignored safely
func ThreadStart(ctx * Context, node ThreadNode) (string, error) { func ThreadStart(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
thread := node.ThreadHandle()
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
err := UpdateStates(context, node, NewLockInfo(node, []string{"state"}), func(context *StateContext) error { err := UpdateStates(context, thread, NewACLInfo(thread, []string{"state"}), func(context *StateContext) error {
err := LockLockables(context, map[NodeID]LockableNode{node.ID(): node}, node) err := LockLockables(context, map[NodeID]*Node{thread.ID: thread}, thread)
if err != nil { if err != nil {
return err return err
} }
return thread.SetState("started") return thread_ext.SetState("started")
}) })
if err != nil { if err != nil {
return "", err return "", err
} }
context = NewReadContext(ctx) context = NewReadContext(ctx)
return "wait", Signal(context, node, node, NewStatusSignal("started", node.ID())) return "wait", Signal(context, thread, thread, NewStatusSignal("started", thread.ID))
} }
func ThreadWait(ctx * Context, node ThreadNode) (string, error) { func ThreadWait(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
thread := node.ThreadHandle() ctx.Log.Logf("thread", "THREAD_WAIT: %s - %+v", thread.ID, thread_ext.ActionQueue)
ctx.Log.Logf("thread", "THREAD_WAIT: %s - %+v", thread.ID(), thread.ActionQueue)
for { for {
select { select {
case signal := <- thread.Chan: case signal := <- thread_ext.SignalChan:
ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal) ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID, signal)
signal_fn, exists := thread.Handlers[signal.Type()] signal_fn, exists := thread_ext.Handlers[signal.Type()]
if exists == true { if exists == true {
ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID(), signal.Type()) ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID, signal.Type())
return signal_fn(ctx, node, signal) return signal_fn(ctx, thread, thread_ext, signal)
} else { } else {
ctx.Log.Logf("thread", "THREAD_NOHANDLER: %s - %s", thread.ID(), signal.Type()) ctx.Log.Logf("thread", "THREAD_NOHANDLER: %s - %s", thread.ID, signal.Type())
} }
case <- thread.TimeoutChan: case <- thread_ext.TimeoutChan:
timeout_action := "" timeout_action := ""
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
err := UpdateStates(context, node, NewLockMap(NewLockInfo(node, []string{"timeout"})), func(context *StateContext) error { err := UpdateStates(context, thread, NewACLMap(NewACLInfo(thread, []string{"timeout"})), func(context *StateContext) error {
timeout_action = thread.NextAction.Action timeout_action = thread_ext.NextAction.Action
thread.NextAction, thread.TimeoutChan = thread.SoonestAction() thread_ext.NextAction, thread_ext.TimeoutChan = thread_ext.SoonestAction()
return nil return nil
}) })
if err != nil { if err != nil {
ctx.Log.Logf("thread", "THREAD_TIMEOUT_ERR: %s - %e", thread.ID(), err) ctx.Log.Logf("thread", "THREAD_TIMEOUT_ERR: %s - %e", thread.ID, err)
} }
ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID(), timeout_action) ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID, timeout_action)
return timeout_action, nil return timeout_action, nil
} }
} }
} }
func ThreadFinish(ctx *Context, node ThreadNode) (string, error) { func ThreadFinish(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) {
thread := node.ThreadHandle()
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
return "", UpdateStates(context, node, NewLockInfo(node, []string{"state"}), func(context *StateContext) error { return "", UpdateStates(context, thread, NewACLInfo(thread, []string{"state"}), func(context *StateContext) error {
err := thread.SetState("finished") err := thread_ext.SetState("finished")
if err != nil { if err != nil {
return err return err
} }
return UnlockLockables(context, map[NodeID]LockableNode{node.ID(): node}, node) return UnlockLockables(context, map[NodeID]*Node{thread.ID: thread}, thread)
}) })
} }
var ThreadAbortedError = errors.New("Thread aborted by signal") var ThreadAbortedError = errors.New("Thread aborted by signal")
// Default thread action function for "abort", sends a signal and returns a ThreadAbortedError // Default thread action function for "abort", sends a signal and returns a ThreadAbortedError
func ThreadAbort(ctx * Context, node ThreadNode, signal GraphSignal) (string, error) { func ThreadAbort(ctx * Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) {
context := NewReadContext(ctx) context := NewReadContext(ctx)
err := Signal(context, node, node, NewStatusSignal("aborted", node.ID())) err := Signal(context, thread, thread, NewStatusSignal("aborted", thread.ID))
if err != nil { if err != nil {
return "", err return "", err
} }
@ -645,9 +573,9 @@ func ThreadAbort(ctx * Context, node ThreadNode, signal GraphSignal) (string, er
} }
// Default thread action for "stop", sends a signal and returns no error // Default thread action for "stop", sends a signal and returns no error
func ThreadStop(ctx * Context, node ThreadNode, signal GraphSignal) (string, error) { func ThreadStop(ctx * Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) {
context := NewReadContext(ctx) context := NewReadContext(ctx)
err := Signal(context, node, node, NewStatusSignal("stopped", node.ID())) err := Signal(context, thread, thread, NewStatusSignal("stopped", thread.ID))
return "finish", err return "finish", err
} }