Moved from inheritance to extensions

graph-rework-2
noah metz 2023-07-25 21:43:15 -06:00
parent b3f6ea67c9
commit ff813d6c2b
6 changed files with 966 additions and 1414 deletions

@ -1,31 +1,15 @@
package graphvent
import (
"github.com/graphql-go/graphql"
badger "github.com/dgraph-io/badger/v3"
"reflect"
"fmt"
)
// NodeLoadFunc is the footprint of the function used to create a new node in memory from persisted bytes
type NodeLoadFunc func(*Context, NodeID, []byte, NodeMap)(Node, error)
// A NodeDef is a description of a node that can be added to a Context
type NodeDef struct {
Load NodeLoadFunc
Type NodeType
GQLType *graphql.Object
Reflect reflect.Type
}
// Create a new Node def, extracting the Type and Reflect from example
func NewNodeDef(example Node, load_func NodeLoadFunc, gql_type *graphql.Object) NodeDef {
return NodeDef{
Type: example.Type(),
Load: load_func,
GQLType: gql_type,
Reflect: reflect.TypeOf(example),
}
type ExtensionLoadFunc func(*Context, []byte) (Extension, error)
type ExtensionInfo struct {
Load ExtensionLoadFunc
Type ExtType
Data interface{}
}
// A Context is all the data needed to run a graphvent
@ -34,211 +18,55 @@ type Context struct {
DB * badger.DB
// Log is an interface used to record events happening
Log Logger
// A mapping between type hashes and their corresponding node definitions
Types map[uint64]NodeDef
// GQL substructure
GQL GQLContext
}
// Recreate the GQL schema after making changes
func (ctx * Context) RebuildSchema() error {
schemaConfig := graphql.SchemaConfig{
Types: ctx.GQL.TypeList,
Query: ctx.GQL.Query,
Mutation: ctx.GQL.Mutation,
Subscription: ctx.GQL.Subscription,
}
schema, err := graphql.NewSchema(schemaConfig)
if err != nil {
return err
}
ctx.GQL.Schema = schema
return nil
// A mapping between type hashes and their corresponding extension definitions
Extensions map[uint64]ExtensionInfo
// All loaded Nodes
Nodes map[NodeID]*Node
}
// Add a non-node type to the gql context
func (ctx * Context) AddGQLType(gql_type graphql.Type) {
ctx.GQL.TypeList = append(ctx.GQL.TypeList, gql_type)
func (ctx *Context) ExtByType(ext_type ExtType) ExtensionInfo {
type_hash := ext_type.Hash()
ext, _ := ctx.Extensions[type_hash]
return ext
}
// Add a node to a context, returns an error if the def is invalid or already exists in the context
func (ctx * Context) RegisterNodeType(def NodeDef) error {
if def.Load == nil {
return fmt.Errorf("Cannot register a node without a load function: %s", def.Type)
}
if def.Reflect == nil {
return fmt.Errorf("Cannot register a node without a reflect type: %s", def.Type)
}
if def.GQLType == nil {
return fmt.Errorf("Cannot register a node without a gql type: %s", def.Type)
func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFunc) error {
if load_fn == nil {
return fmt.Errorf("def has no load function")
}
type_hash := def.Type.Hash()
_, exists := ctx.Types[type_hash]
type_hash := ext_type.Hash()
_, exists := ctx.Extensions[type_hash]
if exists == true {
return fmt.Errorf("Cannot register node of type %s, type already exists in context", def.Type)
return fmt.Errorf("Cannot register extension of type %s, type already exists in context", ext_type)
}
ctx.Types[type_hash] = def
node_type := reflect.TypeOf((*Node)(nil)).Elem()
lockable_type := reflect.TypeOf((*LockableNode)(nil)).Elem()
thread_type := reflect.TypeOf((*ThreadNode)(nil)).Elem()
if def.Reflect.Implements(node_type) {
ctx.GQL.ValidNodes[def.Reflect] = def.GQLType
}
if def.Reflect.Implements(lockable_type) {
ctx.GQL.ValidLockables[def.Reflect] = def.GQLType
}
if def.Reflect.Implements(thread_type) {
ctx.GQL.ValidThreads[def.Reflect] = def.GQLType
ctx.Extensions[type_hash] = ExtensionInfo{
Load: load_fn,
Type: ext_type,
}
ctx.GQL.TypeList = append(ctx.GQL.TypeList, def.GQLType)
return nil
}
// Map of go types to graphql types
type ObjTypeMap map[reflect.Type]*graphql.Object
// GQL Specific Context information
type GQLContext struct {
// Generated GQL schema
Schema graphql.Schema
// List of GQL types
TypeList []graphql.Type
// Interface type maps to map go types of specific interfaces to gql types
ValidNodes ObjTypeMap
ValidLockables ObjTypeMap
ValidThreads ObjTypeMap
BaseNodeType *graphql.Object
BaseLockableType *graphql.Object
BaseThreadType *graphql.Object
Query *graphql.Object
Mutation *graphql.Object
Subscription *graphql.Object
}
// Create a new GQL context without any content
func NewGQLContext() GQLContext {
query := graphql.NewObject(graphql.ObjectConfig{
Name: "Query",
Fields: graphql.Fields{},
})
mutation := graphql.NewObject(graphql.ObjectConfig{
Name: "Mutation",
Fields: graphql.Fields{},
})
subscription := graphql.NewObject(graphql.ObjectConfig{
Name: "Subscription",
Fields: graphql.Fields{},
})
ctx := GQLContext{
Schema: graphql.Schema{},
TypeList: []graphql.Type{},
ValidNodes: ObjTypeMap{},
ValidThreads: ObjTypeMap{},
ValidLockables: ObjTypeMap{},
Query: query,
Mutation: mutation,
Subscription: subscription,
BaseNodeType: GQLTypeSimpleNode.Type,
BaseLockableType: GQLTypeSimpleLockable.Type,
BaseThreadType: GQLTypeSimpleThread.Type,
}
return ctx
}
// Create a new Context with all the library content added
func NewContext(db * badger.DB, log Logger) * Context {
func NewContext(db * badger.DB, log Logger) (*Context, error) {
ctx := &Context{
GQL: NewGQLContext(),
DB: db,
Log: log,
Types: map[uint64]NodeDef{},
Extensions: map[uint64]ExtensionInfo{},
Nodes: map[NodeID]*Node{},
}
err := ctx.RegisterNodeType(NewNodeDef((*SimpleNode)(nil), LoadSimpleNode, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*Lockable)(nil), LoadLockable, GQLTypeSimpleLockable.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*Listener)(nil), LoadListener, GQLTypeSimpleLockable.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*Thread)(nil), LoadThread, GQLTypeSimpleThread.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*GQLThread)(nil), LoadGQLThread, GQLTypeGQLThread.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*User)(nil), LoadUser, GQLTypeUser.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*Group)(nil), LoadGroup, GQLTypeSimpleLockable.Type))
err := ctx.RegisterExtension(ACLExtType, LoadACLExtension)
if err != nil {
panic(err)
return nil, err
}
err = ctx.RegisterNodeType(NewNodeDef((*PerNodePolicy)(nil), LoadPerNodePolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*SimplePolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*DependencyPolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*ParentPolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*ChildrenPolicy)(nil), LoadSimplePolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
err = ctx.RegisterNodeType(NewNodeDef((*UserOfPolicy)(nil), LoadUserOfPolicy, GQLTypeSimpleNode.Type))
if err != nil {
panic(err)
}
ctx.AddGQLType(GQLTypeSignal.Type)
ctx.GQL.Query.AddFieldConfig("Self", GQLQuerySelf)
ctx.GQL.Query.AddFieldConfig("User", GQLQueryUser)
ctx.GQL.Subscription.AddFieldConfig("Update", GQLSubscriptionUpdate)
ctx.GQL.Subscription.AddFieldConfig("Self", GQLSubscriptionSelf)
ctx.GQL.Mutation.AddFieldConfig("abort", GQLMutationAbort)
ctx.GQL.Mutation.AddFieldConfig("startChild", GQLMutationStartChild)
err = ctx.RebuildSchema()
err = ctx.RegisterExtension(ACLPolicyExtType, LoadACLPolicyExtension)
if err != nil {
panic(err)
return nil, err
}
return ctx
return ctx, nil
}

@ -58,7 +58,9 @@ func logTestContext(t * testing.T, components []string) * Context {
t.Fatal(err)
}
return NewContext(db, NewConsoleLogger(components))
ctx, err := NewContext(db, NewConsoleLogger(components))
fatalErr(t, err)
return ctx
}
func testContext(t * testing.T) * Context {
@ -67,7 +69,9 @@ func testContext(t * testing.T) * Context {
t.Fatal(err)
}
return NewContext(db, NewConsoleLogger([]string{}))
ctx, err := NewContext(db, NewConsoleLogger([]string{}))
fatalErr(t, err)
return ctx
}
func fatalErr(t * testing.T, err error) {

@ -2,171 +2,121 @@ package graphvent
import (
"fmt"
"reflect"
"encoding/json"
)
type Listener struct {
Lockable
type ListenerExt struct {
Chan chan GraphSignal
}
func (node *Listener) Type() NodeType {
return NodeType("listener")
func NewListenerExt(buffer int) ListenerExt {
return ListenerExt{
Chan: make(chan GraphSignal, buffer),
}
func (node *Listener) Process(context *StateContext, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "LISTENER_PROCESS: %s", node.ID())
select {
case node.Chan <- signal:
default:
return fmt.Errorf("LISTENER_OVERFLOW: %s - %s", node.ID(), signal)
}
return node.Lockable.Process(context, signal)
}
const LISTENER_CHANNEL_BUFFER = 1024
func NewListener(id NodeID, name string) Listener {
return Listener{
Lockable: NewLockable(id, name),
Chan: make(chan GraphSignal, LISTENER_CHANNEL_BUFFER),
}
const ListenerExtType = ExtType("LISTENER")
func (listener ListenerExt) Type() ExtType {
return ListenerExtType
}
var LoadListener = LoadJSONNode(func(id NodeID, j LockableJSON) (Node, error) {
listener := NewListener(id, j.Name)
return &listener, nil
}, RestoreLockable)
type LockableNode interface {
Node
LockableHandle() *Lockable
}
// Lockable is a simple Lockable implementation that can be embedded into more complex structures
type Lockable struct {
SimpleNode
Name string
Owner LockableNode
Requirements map[NodeID]LockableNode
Dependencies map[NodeID]LockableNode
LocksHeld map[NodeID]LockableNode
func (ext ListenerExt) Process(context *StateContext, signal GraphSignal) error {
select {
case ext.Chan <- signal:
default:
return fmt.Errorf("LISTENER_OVERFLOW - %+v", signal)
}
func (lockable *Lockable) LockableHandle() *Lockable {
return lockable
return nil
}
func (lockable *Lockable) Type() NodeType {
return NodeType("lockable")
func (node ListenerExt) Serialize() ([]byte, error) {
return []byte{}, nil
}
type LockableJSON struct {
SimpleNodeJSON
Name string `json:"name"`
Owner string `json:"owner"`
Dependencies []string `json:"dependencies"`
Requirements []string `json:"requirements"`
LocksHeld map[string]string `json:"locks_held"`
type LockableExt struct {
Owner *Node
Requirements map[NodeID]*Node
Dependencies map[NodeID]*Node
LocksHeld map[NodeID]*Node
}
func (lockable *Lockable) Serialize() ([]byte, error) {
lockable_json := NewLockableJSON(lockable)
return json.MarshalIndent(&lockable_json, "", " ")
const LockableExtType = ExtType("LOCKABLE")
func (ext *LockableExt) Type() ExtType {
return LockableExtType
}
func NewLockableJSON(lockable *Lockable) LockableJSON {
requirement_ids := make([]string, len(lockable.Requirements))
func (ext *LockableExt) Serialize() ([]byte, error) {
requirements := make([]string, len(ext.Requirements))
req_n := 0
for id, _ := range(lockable.Requirements) {
requirement_ids[req_n] = id.String()
for id, _ := range(ext.Requirements) {
requirements[req_n] = id.String()
req_n++
}
dependency_ids := make([]string, len(lockable.Dependencies))
dependencies := make([]string, len(ext.Dependencies))
dep_n := 0
for id, _ := range(lockable.Dependencies) {
dependency_ids[dep_n] = id.String()
for id, _ := range(ext.Dependencies) {
dependencies[dep_n] = id.String()
dep_n++
}
owner_id := ""
if lockable.Owner != nil {
owner_id = lockable.Owner.ID().String()
owner := ""
if ext.Owner != nil {
owner = ext.Owner.ID.String()
}
locks_held := map[string]string{}
for lockable_id, node := range(lockable.LocksHeld) {
for lockable_id, node := range(ext.LocksHeld) {
if node == nil {
locks_held[lockable_id.String()] = ""
} else {
locks_held[lockable_id.String()] = node.ID().String()
locks_held[lockable_id.String()] = node.ID.String()
}
}
node_json := NewSimpleNodeJSON(&lockable.SimpleNode)
return LockableJSON{
SimpleNodeJSON: node_json,
Name: lockable.Name,
Owner: owner_id,
Dependencies: dependency_ids,
Requirements: requirement_ids,
return json.MarshalIndent(&struct{
Owner string `json:"owner"`
Requirements []string `json:"requirements"`
Dependencies []string `json:"dependencies"`
LocksHeld map[string]string `json:"locks_held"`
}{
Owner: owner,
Requirements: requirements,
Dependencies: dependencies,
LocksHeld: locks_held,
}
}
func (lockable *Lockable) RecordUnlock(l LockableNode) LockableNode {
lockable_id := l.ID()
last_owner, exists := lockable.LocksHeld[lockable_id]
if exists == false {
panic("Attempted to take a get the original lock holder of a lockable we don't own")
}
delete(lockable.LocksHeld, lockable_id)
return last_owner
}
func (lockable *Lockable) RecordLock(l LockableNode, last_owner LockableNode) {
lockable_id := l.ID()
_, exists := lockable.LocksHeld[lockable_id]
if exists == true {
panic("Attempted to lock a lockable we're already holding(lock cycle)")
}
lockable.LocksHeld[lockable_id] = last_owner
}, "", " ")
}
// Assumed that lockable is already locked for signal
func (lockable *Lockable) Process(context *StateContext, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", lockable.ID())
func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID)
var err error
switch signal.Direction() {
case Up:
err = UseStates(context, lockable,
NewLockInfo(lockable, []string{"dependencies", "owner"}), func(context *StateContext) error {
err = UseStates(context, node,
NewACLInfo(node, []string{"dependencies", "owner"}), func(context *StateContext) error {
owner_sent := false
for _, dependency := range(lockable.Dependencies) {
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", lockable.ID(), dependency.ID())
Signal(context, dependency, lockable, signal)
if lockable.Owner != nil {
if dependency.ID() == lockable.Owner.ID() {
for _, dependency := range(ext.Dependencies) {
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", node.ID, dependency.ID)
Signal(context, dependency, node, signal)
if ext.Owner != nil {
if dependency.ID == ext.Owner.ID {
owner_sent = true
}
}
}
if lockable.Owner != nil && owner_sent == false {
if lockable.Owner.ID() != lockable.ID() {
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", lockable.ID(), lockable.Owner.ID())
return Signal(context, lockable.Owner, lockable, signal)
if ext.Owner != nil && owner_sent == false {
if ext.Owner.ID != node.ID {
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", node.ID, ext.Owner.ID)
return Signal(context, ext.Owner, node, signal)
}
}
return nil
})
case Down:
err = UseStates(context, lockable, NewLockInfo(lockable, []string{"requirements"}), func(context *StateContext) error {
for _, requirement := range(lockable.Requirements) {
err := Signal(context, requirement, lockable, signal)
err = UseStates(context, node, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error {
for _, requirement := range(ext.Requirements) {
err := Signal(context, requirement, node, signal)
if err != nil {
return err
}
@ -176,112 +126,154 @@ func (lockable *Lockable) Process(context *StateContext, signal GraphSignal) err
case Direct:
err = nil
default:
return fmt.Errorf("invalid signal direction %d", signal.Direction())
err = fmt.Errorf("invalid signal direction %d", signal.Direction())
}
if err != nil {
return err
}
return lockable.SimpleNode.Process(context, signal)
return nil
}
func (ext *LockableExt) RecordUnlock(node *Node) *Node {
last_owner, exists := ext.LocksHeld[node.ID]
if exists == false {
panic("Attempted to take a get the original lock holder of a lockable we don't own")
}
delete(ext.LocksHeld, node.ID)
return last_owner
}
func (ext *LockableExt) RecordLock(node *Node, last_owner *Node) {
_, exists := ext.LocksHeld[node.ID]
if exists == true {
panic("Attempted to lock a lockable we're already holding(lock cycle)")
}
ext.LocksHeld[node.ID] = last_owner
}
// Removes requirement as a requirement from lockable
// Continues the write context with princ, getting requirents for lockable and dependencies for requirement
// Assumes that an active write context exists with princ locked so that princ's state can be used in checks
func UnlinkLockables(context *StateContext, princ Node, lockable LockableNode, requirement LockableNode) error {
return UpdateStates(context, princ, LockMap{
lockable.ID(): LockInfo{Node: lockable, Resources: []string{"requirements"}},
requirement.ID(): LockInfo{Node: requirement, Resources: []string{"dependencies"}},
func UnlinkLockables(context *StateContext, princ *Node, lockable *Node, requirement *Node) error {
lockable_ext, err := GetExt[*LockableExt](lockable)
if err != nil {
return err
}
requirement_ext, err := GetExt[*LockableExt](requirement)
if err != nil {
return err
}
return UpdateStates(context, princ, ACLMap{
lockable.ID: ACLInfo{Node: lockable, Resources: []string{"requirements"}},
requirement.ID: ACLInfo{Node: requirement, Resources: []string{"dependencies"}},
}, func(context *StateContext) error {
var found Node = nil
for _, req := range(lockable.LockableHandle().Requirements) {
if requirement.ID() == req.ID() {
var found *Node = nil
for _, req := range(lockable_ext.Requirements) {
if requirement.ID == req.ID {
found = req
break
}
}
if found == nil {
return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID(), lockable.ID())
return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID, lockable.ID)
}
delete(requirement.LockableHandle().Dependencies, lockable.ID())
delete(lockable.LockableHandle().Requirements, requirement.ID())
delete(requirement_ext.Dependencies, lockable.ID)
delete(lockable_ext.Requirements, requirement.ID)
return nil
})
}
// Link requirements as requirements to lockable
// Continues the wrtie context with princ, getting requirements for lockable and dependencies for requirements
func LinkLockables(context *StateContext, princ Node, lockable_node LockableNode, requirements []LockableNode) error {
if lockable_node == nil {
func LinkLockables(context *StateContext, princ *Node, lockable *Node, requirements []*Node) error {
if lockable == nil {
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements")
}
lockable := lockable_node.LockableHandle()
if len(requirements) == 0 {
return nil
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link no lockables in call")
}
lockable_ext, err := GetExt[*LockableExt](lockable)
if err != nil {
return err
}
found := map[NodeID]bool{}
req_exts := map[NodeID]*LockableExt{}
for _, requirement := range(requirements) {
if requirement == nil {
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link nil to a Lockable as a requirement")
}
if lockable.ID() == requirement.ID() {
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s to itself", lockable.ID())
if lockable.ID == requirement.ID {
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s to itself", lockable.ID)
}
_, exists := found[requirement.ID()]
_, exists := req_exts[requirement.ID]
if exists == true {
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s twice", requirement.ID())
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s twice", requirement.ID)
}
ext, err := GetExt[*LockableExt](requirement)
if err != nil {
return err
}
found[requirement.ID()] = true
req_exts[requirement.ID] = ext
}
return UpdateStates(context, princ, NewLockMap(
NewLockInfo(lockable_node, []string{"requirements"}),
LockList(requirements, []string{"dependencies"}),
return UpdateStates(context, princ, NewACLMap(
NewACLInfo(lockable, []string{"requirements"}),
ACLList(requirements, []string{"dependencies"}),
), func(context *StateContext) error {
// Check that all the requirements can be added
// If the lockable is already locked, need to lock this resource as well before we can add it
for _, requirement_node := range(requirements) {
requirement := requirement_node.LockableHandle()
for _, req_node := range(requirements) {
req := req_node.LockableHandle()
if req.ID() == requirement.ID() {
for _, requirement := range(requirements) {
requirement_ext := req_exts[requirement.ID]
for _, req := range(requirements) {
if req.ID == requirement.ID {
continue
}
if checkIfRequirement(context, req, requirement) == true {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependenyc of %s so cannot add the same dependency", req.ID(), requirement.ID())
is_req, err := checkIfRequirement(context, req.ID, requirement_ext)
if err != nil {
return err
} else if is_req {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot add the same dependency", req.ID, requirement.ID)
}
}
if checkIfRequirement(context, lockable, requirement) == true {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID(), lockable.ID())
is_req, err := checkIfRequirement(context, lockable.ID, requirement_ext)
if err != nil {
return err
} else if is_req {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID, lockable.ID)
}
if checkIfRequirement(context, requirement, lockable) == true {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID())
is_req, err = checkIfRequirement(context, requirement.ID, lockable_ext)
if err != nil {
return err
} else if is_req {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID, requirement.ID)
}
if lockable.Owner == nil {
if lockable_ext.Owner == nil {
// If the new owner isn't locked, we can add the requirement
} else if requirement.Owner == nil {
} else if requirement_ext.Owner == nil {
// if the new requirement isn't already locked but the owner is, the requirement needs to be locked first
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID(), requirement.ID())
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID, requirement.ID)
} else {
// If the new requirement is already locked and the owner is already locked, their owners need to match
if requirement.Owner.ID() != lockable.Owner.ID() {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID(), lockable.ID())
if requirement_ext.Owner.ID != lockable_ext.Owner.ID {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID, lockable.ID)
}
}
}
// Update the states of the requirements
for _, requirement_node := range(requirements) {
requirement := requirement_node.LockableHandle()
requirement.Dependencies[lockable.ID()] = lockable_node
lockable.Requirements[lockable.ID()] = requirement_node
context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID(), lockable.ID())
for _, requirement := range(requirements) {
requirement_ext := req_exts[requirement.ID]
requirement_ext.Dependencies[lockable.ID] = lockable
lockable_ext.Requirements[lockable.ID] = requirement
context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID, lockable.ID)
}
// Return no error
@ -289,74 +281,91 @@ func LinkLockables(context *StateContext, princ Node, lockable_node LockableNode
})
}
// Must be called withing update context
func checkIfRequirement(context *StateContext, r LockableNode, cur LockableNode) bool {
for _, c := range(cur.LockableHandle().Requirements) {
if c.ID() == r.ID() {
return true
func checkIfRequirement(context *StateContext, id NodeID, cur *LockableExt) (bool, error) {
for _, req := range(cur.Requirements) {
if req.ID == id {
return true, nil
}
is_requirement := false
UpdateStates(context, cur, NewLockMap(NewLockInfo(c, []string{"requirements"})), func(context *StateContext) error {
is_requirement = checkIfRequirement(context, cur, c)
return nil
})
if is_requirement {
return true
req_ext, err := GetExt[*LockableExt](req)
if err != nil {
return false, err
}
var is_req bool
err = UpdateStates(context, req, NewACLInfo(req, []string{"requirements"}), func(context *StateContext) error {
is_req, err = checkIfRequirement(context, id, req_ext)
return err
})
if err != nil {
return false, err
}
if is_req == true {
return true, nil
}
}
return false
return false, nil
}
// Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error
// Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call
func LockLockables(context *StateContext, to_lock map[NodeID]LockableNode, new_owner_node LockableNode) error {
func LockLockables(context *StateContext, to_lock NodeMap, new_owner *Node) error {
if to_lock == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: no list provided")
return fmt.Errorf("LOCKABLE_LOCK_ERR: no map provided")
}
req_exts := map[NodeID]*LockableExt{}
for _, l := range(to_lock) {
var err error
if l == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: Can not lock nil")
}
req_exts[l.ID], err = GetExt[*LockableExt](l)
if err != nil {
return err
}
}
if new_owner_node == nil {
if new_owner == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks")
}
new_owner := new_owner_node.LockableHandle()
new_owner_ext, err := GetExt[*LockableExt](new_owner)
if err != nil {
return err
}
// Called with no requirements to lock, success
if len(to_lock) == 0 {
return nil
}
return UpdateStates(context, new_owner, NewLockMap(
LockListM(to_lock, []string{"lock"}),
NewLockInfo(new_owner, nil),
return UpdateStates(context, new_owner, NewACLMap(
ACLListM(to_lock, []string{"lock"}),
NewACLInfo(new_owner, nil),
), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements
for _, req_node := range(to_lock) {
req := req_node.LockableHandle()
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), new_owner.ID())
for _, req := range(to_lock) {
req_ext := req_exts[req.ID]
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID, new_owner.ID)
// If req is alreay locked, check that we can pass the lock
if req.Owner != nil {
owner := req.Owner
if owner.ID() == new_owner.ID() {
if req_ext.Owner != nil {
owner := req_ext.Owner
if owner.ID == new_owner.ID {
continue
} else {
err := UpdateStates(context, new_owner, NewLockInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){
return LockLockables(context, req.Requirements, req)
err := UpdateStates(context, new_owner, NewACLInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){
return LockLockables(context, req_ext.Requirements, req)
})
if err != nil {
return err
}
}
} else {
err := LockLockables(context, req.Requirements, req)
err := LockLockables(context, req_ext.Requirements, req)
if err != nil {
return err
}
@ -364,22 +373,22 @@ func LockLockables(context *StateContext, to_lock map[NodeID]LockableNode, new_o
}
// At this point state modification will be started, so no errors can be returned
for _, req_node := range(to_lock) {
req := req_node.LockableHandle()
old_owner := req.Owner
for _, req := range(to_lock) {
req_ext := req_exts[req.ID]
old_owner := req_ext.Owner
// If the lockable was previously unowned, update the state
if old_owner == nil {
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID(), req.ID())
req.Owner = new_owner_node
new_owner.RecordLock(req, old_owner)
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID, req.ID)
req_ext.Owner = new_owner
new_owner_ext.RecordLock(req, old_owner)
// Otherwise if the new owner already owns it, no need to update state
} else if old_owner.ID() == new_owner.ID() {
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID(), req.ID())
} else if old_owner.ID == new_owner.ID {
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID, req.ID)
// Otherwise update the state
} else {
req.Owner = new_owner
new_owner.RecordLock(req, old_owner)
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID(), req.ID(), old_owner.ID())
req_ext.Owner = new_owner
new_owner_ext.RecordLock(req, old_owner)
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID, req.ID, old_owner.ID)
}
}
return nil
@ -387,61 +396,72 @@ func LockLockables(context *StateContext, to_lock map[NodeID]LockableNode, new_o
}
func UnlockLockables(context *StateContext, to_unlock map[NodeID]LockableNode, old_owner_node LockableNode) error {
func UnlockLockables(context *StateContext, to_unlock NodeMap, old_owner *Node) error {
if to_unlock == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided")
}
req_exts := map[NodeID]*LockableExt{}
for _, l := range(to_unlock) {
if l == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: Can not unlock nil")
}
var err error
req_exts[l.ID], err = GetExt[*LockableExt](l)
if err != nil {
return err
}
}
if old_owner_node == nil {
if old_owner == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks")
}
old_owner := old_owner_node.LockableHandle()
old_owner_ext, err := GetExt[*LockableExt](old_owner)
if err != nil {
return err
}
// Called with no requirements to unlock, success
if len(to_unlock) == 0 {
return nil
}
return UpdateStates(context, old_owner, NewLockMap(
LockListM(to_unlock, []string{"lock"}),
NewLockInfo(old_owner, nil),
return UpdateStates(context, old_owner, NewACLMap(
ACLListM(to_unlock, []string{"lock"}),
NewACLInfo(old_owner, nil),
), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements
for _, req_node := range(to_unlock) {
req := req_node.LockableHandle()
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), old_owner.ID())
for _, req := range(to_unlock) {
req_ext := req_exts[req.ID]
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID, old_owner.ID)
// Check if the owner is correct
if req.Owner != nil {
if req.Owner.ID() != old_owner.ID() {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID(), old_owner.ID())
if req_ext.Owner != nil {
if req_ext.Owner.ID != old_owner.ID {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID, old_owner.ID)
}
} else {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID())
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID)
}
err := UnlockLockables(context, req.Requirements, req)
err := UnlockLockables(context, req_ext.Requirements, req)
if err != nil {
return err
}
}
// At this point state modification will be started, so no errors can be returned
for _, req_node := range(to_unlock) {
req := req_node.LockableHandle()
new_owner := old_owner.RecordUnlock(req)
req.Owner = new_owner
for _, req := range(to_unlock) {
req_ext := req_exts[req.ID]
new_owner := old_owner_ext.RecordUnlock(req)
req_ext.Owner = new_owner
if new_owner == nil {
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID(), req.ID())
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID, req.ID)
} else {
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID(), req.ID(), new_owner.ID())
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID, req.ID, new_owner.ID)
}
}
@ -449,103 +469,55 @@ func UnlockLockables(context *StateContext, to_unlock map[NodeID]LockableNode, o
})
}
var LoadLockable = LoadJSONNode(func(id NodeID, j LockableJSON) (Node, error) {
lockable := NewLockable(id, j.Name)
return &lockable, nil
}, RestoreLockable)
func NewLockable(id NodeID, name string) Lockable {
return Lockable{
SimpleNode: NewSimpleNode(id),
Name: name,
Owner: nil,
Requirements: map[NodeID]LockableNode{},
Dependencies: map[NodeID]LockableNode{},
LocksHeld: map[NodeID]LockableNode{},
}
}
// Helper function to load links when loading a struct that embeds Lockable
func RestoreLockable(ctx * Context, lockable LockableNode, j LockableJSON, nodes NodeMap) error {
lockable_ptr := lockable.LockableHandle()
if j.Owner != "" {
owner_id, err := ParseID(j.Owner)
func RestoreNode(ctx *Context, id_str string) (*Node, error) {
id, err := ParseID(id_str)
if err != nil {
return err
}
owner_node, err := LoadNodeRecurse(ctx, owner_id, nodes)
if err != nil {
return err
}
owner, ok := owner_node.(LockableNode)
if ok == false {
return fmt.Errorf("%s is not a Lockable", j.Owner)
}
lockable_ptr.Owner = owner
return nil, err
}
for _, dep_str := range(j.Dependencies) {
dep_id, err := ParseID(dep_str)
if err != nil {
return err
return LoadNode(ctx, id)
}
dep_node, err := LoadNodeRecurse(ctx, dep_id, nodes)
func RestoreNodeMap(ctx *Context, ids map[string]string) (NodeMap, error) {
nodes := NodeMap{}
for id_str_1, id_str_2 := range(ids) {
id_1, err := ParseID(id_str_1)
if err != nil {
return err
}
dep, ok := dep_node.(LockableNode)
if ok == false {
return fmt.Errorf("%+v is not a Lockable as expected", dep_node)
}
ctx.Log.Logf("db", "LOCKABLE_LOAD_DEPENDENCY: %s - %s - %+v", lockable.ID(), dep_id, reflect.TypeOf(dep))
lockable_ptr.Dependencies[dep_id] = dep
return nil, err
}
for _, req_str := range(j.Requirements) {
req_id, err := ParseID(req_str)
id_2, err := ParseID(id_str_2)
if err != nil {
return err
return nil, err
}
req_node, err := LoadNodeRecurse(ctx, req_id, nodes)
node_1, err := LoadNode(ctx, id_1)
if err != nil {
return err
}
req, ok := req_node.(LockableNode)
if ok == false {
return fmt.Errorf("%+v is not a Lockable as expected", req_node)
}
lockable_ptr.Requirements[req_id] = req
return nil, err
}
for l_id_str, h_str := range(j.LocksHeld) {
l_id, err := ParseID(l_id_str)
l, err := LoadNodeRecurse(ctx, l_id, nodes)
node_2, err := LoadNode(ctx, id_2)
if err != nil {
return err
return nil, err
}
l_l, ok := l.(LockableNode)
if ok == false {
return fmt.Errorf("%s is not a Lockable", l.ID())
nodes[node_1.ID] = node_2
}
var h_l LockableNode
if h_str != "" {
h_id, err := ParseID(h_str)
if err != nil {
return err
return nodes, nil
}
h_node, err := LoadNodeRecurse(ctx, h_id, nodes)
func RestoreNodeList(ctx *Context, ids []string) (NodeMap, error) {
nodes := NodeMap{}
for _, id_str := range(ids) {
node, err := RestoreNode(ctx, id_str)
if err != nil {
return err
return nil, err
}
h, ok := h_node.(LockableNode)
if ok == false {
return err
}
h_l = h
}
lockable_ptr.RecordLock(l_l, h_l)
nodes[node.ID] = node
}
return RestoreSimpleNode(ctx, lockable, j.SimpleNodeJSON, nodes)
return nodes, nil
}

@ -2,6 +2,7 @@ package graphvent
import (
"sync"
"reflect"
"github.com/google/uuid"
badger "github.com/dgraph-io/badger/v3"
"fmt"
@ -31,6 +32,12 @@ func (id NodeID) String() string {
return (uuid.UUID)(id).String()
}
// Ignore the error since we're enforcing 16 byte length at compile time
func IDFromBytes(bytes [16]byte) NodeID {
id, _ := uuid.FromBytes(bytes[:])
return NodeID(id)
}
func ParseID(str string) (NodeID, error) {
id_uuid, err := uuid.Parse(str)
if err != nil {
@ -45,230 +52,291 @@ func KeyID(pub *ecdsa.PublicKey) NodeID {
return NodeID(str)
}
// Types are how nodes are associated with structs at runtime(and from the DB)
type NodeType string
func (node_type NodeType) Hash() uint64 {
hash := sha512.Sum512([]byte(fmt.Sprintf("NODE: %s", node_type)))
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
}
// Generate a random NodeID
func RandID() NodeID {
return NodeID(uuid.New())
}
type Node interface {
ID() NodeID
Type() NodeType
type Serializable[I comparable] interface {
Type() I
Serialize() ([]byte, error)
LockState(write bool)
UnlockState(write bool)
Process(context *StateContext, signal GraphSignal) error
Policies() []Policy
NodeHandle() *SimpleNode
}
type SimpleNode struct {
Id NodeID
state_mutex sync.RWMutex
PolicyMap map[NodeID]Policy
// NodeExtensions are additional data that can be attached to nodes, and used in node functions
type Extension interface {
Serializable[ExtType]
// Send a signal to this extension to process,
// this typically triggers signals to be sent to nodes linked in the extension
Process(context *StateContext, node *Node, signal GraphSignal) error
}
func (node *SimpleNode) NodeHandle() *SimpleNode {
return node
// Nodes represent an addressible group of extensions
type Node struct {
ID NodeID
Lock sync.RWMutex
ExtensionMap map[ExtType]Extension
}
func NewSimpleNode(id NodeID) SimpleNode {
return SimpleNode{
Id: id,
PolicyMap: map[NodeID]Policy{},
}
func GetExt[T Extension](node *Node) (T, error) {
var zero T
ext_type := zero.Type()
ext, exists := node.ExtensionMap[ext_type]
if exists == false {
return zero, fmt.Errorf("%s does not have %s extension", node.ID, ext_type)
}
type SimpleNodeJSON struct {
Policies []string `json:"policies"`
ret, ok := ext.(T)
if ok == false {
return zero, fmt.Errorf("%s in %s is wrong type(%+v), expecting %+v", ext_type, node.ID, reflect.TypeOf(ext), reflect.TypeOf(zero))
}
func (node *SimpleNode) Process(context *StateContext, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "SIMPLE_NODE_SIGNAL: %s - %s", node.Id, signal)
return nil
return ret, nil
}
func (node *SimpleNode) ID() NodeID {
return node.Id
// The ACL extension stores a map of nodes to delegate ACL to, and a list of policies
type ACLExtension struct {
Delegations NodeMap
}
func (node *SimpleNode) Type() NodeType {
return NodeType("simple_node")
func (ext ACLExtension) Process(context *StateContext, node *Node, signal GraphSignal) error {
return nil
}
func (node *SimpleNode) Serialize() ([]byte, error) {
j := NewSimpleNodeJSON(node)
return json.MarshalIndent(&j, "", " ")
func LoadACLExtension(ctx *Context, data []byte) (Extension, error) {
var j struct {
Delegations []string `json:"delegation"`
}
func (node *SimpleNode) LockState(write bool) {
if write == true {
node.state_mutex.Lock()
} else {
node.state_mutex.RLock()
}
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
func (node *SimpleNode) UnlockState(write bool) {
if write == true {
node.state_mutex.Unlock()
} else {
node.state_mutex.RUnlock()
}
delegations := NodeMap{}
for _, str := range(j.Delegations) {
id, err := ParseID(str)
if err != nil {
return nil, err
}
func NewSimpleNodeJSON(node *SimpleNode) SimpleNodeJSON {
policy_ids := make([]string, len(node.PolicyMap))
i := 0
for id, _ := range(node.PolicyMap) {
policy_ids[i] = id.String()
i += 1
node, err := LoadNode(ctx, id)
if err != nil {
return nil, err
}
return SimpleNodeJSON{
Policies: policy_ids,
}
delegations[id] = node
}
func RestoreSimpleNode(ctx *Context, node Node, j SimpleNodeJSON, nodes NodeMap) error {
node_ptr := node.NodeHandle()
for _, policy_str := range(j.Policies) {
policy_id, err := ParseID(policy_str)
if err != nil {
return err
return ACLExtension{
Delegations: delegations,
}, nil
}
policy_ptr, err := LoadNodeRecurse(ctx, policy_id, nodes)
if err != nil {
return err
func (ext ACLExtension) Serialize() ([]byte, error) {
delegations := make([]string, len(ext.Delegations))
i := 0
for id, _ := range(ext.Delegations) {
delegations[i] = id.String()
i += 1
}
policy, ok := policy_ptr.(Policy)
if ok == false {
return fmt.Errorf("%s is not a Policy", policy_id)
}
node_ptr.PolicyMap[policy_id] = policy
return json.MarshalIndent(&struct{
Delegations []string `json:"delegations"`
}{
Delegations: delegations,
}, "", " ")
}
return nil
const ACLExtType = ExtType("ACL")
func (extension ACLExtension) Type() ExtType {
return ACLExtType
}
func LoadJSONNode[J any, N Node](init_func func(NodeID, J)(Node, error), restore_func func(*Context, N, J, NodeMap)error)func(*Context, NodeID, []byte, NodeMap)(Node, error) {
return func(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) {
var j J
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
func (node *Node) Serialize() ([]byte, error) {
extensions := make([]ExtensionDB, len(node.ExtensionMap))
node_db := NodeDB{
Header: NodeDBHeader{
Magic: NODE_DB_MAGIC,
NumExtensions: uint32(len(extensions)),
},
Extensions: extensions,
}
node, err := init_func(id, j)
if err != nil {
return nil, err
}
nodes[id] = node
err = restore_func(ctx, node.(N), j, nodes)
i := 0
for ext_type, info := range(node.ExtensionMap) {
ser, err := info.Serialize()
if err != nil {
return nil, err
}
return node, nil
node_db.Extensions[i] = ExtensionDB{
Header: ExtensionDBHeader{
TypeHash: ext_type.Hash(),
Length: uint64(len(ser)),
},
Data: ser,
}
i += 1
}
var LoadSimpleNode = LoadJSONNode(func(id NodeID, j SimpleNodeJSON) (Node, error) {
node := NewSimpleNode(id)
return &node, nil
}, RestoreSimpleNode)
func (node *SimpleNode) Policies() []Policy {
ret := make([]Policy, len(node.PolicyMap))
i := 0
for _, policy := range(node.PolicyMap) {
ret[i] = policy
i += 1
return node_db.Serialize(), nil
}
return ret
func NewNode(id NodeID) Node {
return Node{
ID: id,
ExtensionMap: map[ExtType]Extension{},
}
}
func Allowed(context *StateContext, policies []Policy, node Node, resource string, action string, princ Node) error {
if princ == nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
func Allowed(context *StateContext, principal *Node, action string, node *Node) error {
if principal == nil {
context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s", principal.ID, node.ID, action)
return fmt.Errorf("nil is not allowed to perform any actions")
}
if node.ID() == princ.ID() {
return nil
ext, exists := node.ExtensionMap[ACLExtType]
if exists == false {
return fmt.Errorf("%s does not have ACL extension, other nodes cannot perform actions on it", node.ID)
}
acl_ext := ext.(ACLExtension)
for _, policy_node := range(acl_ext.Delegations) {
ext, exists := policy_node.ExtensionMap[ACLPolicyExtType]
if exists == false {
context.Graph.Log.Logf("policy", "WARNING: %s has dependency %s which doesn't have ACLPolicyExtension")
continue
}
for _, policy := range(policies) {
if policy.Allows(context, node, resource, action, princ) == true {
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
policy_ext := ext.(ACLPolicyExtension)
if policy_ext.Allows(context, principal, action, node) == true {
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s", principal.ID, node.ID, action)
return nil
}
}
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", princ.ID(), node.ID(), resource, action)
return fmt.Errorf("%s is not allowed to perform %s.%s on %s", princ.ID(), resource, action, node.ID())
context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", principal.ID, node.ID, action)
return fmt.Errorf("%s is not allowed to perform %s on %s", principal.ID, action, node.ID)
}
// Check that princ is allowed to signal this action,
// then send the signal to all the extensions of the node
func Signal(context *StateContext, node *Node, princ *Node, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID, signal.String())
// Propagate the signal to registered listeners, if a listener isn't ready to receive the update
// send it a notification that it was closed and then close it
func Signal(context *StateContext, node Node, princ Node, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID(), signal.String())
err := UseStates(context, princ, NewLockInfo(node, []string{}), func(context *StateContext) error {
return Allowed(context, node.Policies(), node, "signal", signal.Type(), princ)
err := UseStates(context, princ, NewACLInfo(node, []string{}), func(context *StateContext) error {
return Allowed(context, princ, fmt.Sprintf("signal.%s", signal.Type()), node)
})
for _, ext := range(node.ExtensionMap) {
err = ext.Process(context, node, signal)
if err != nil {
return nil
}
return node.Process(context, signal)
}
func AttachPolicies(ctx *Context, node Node, policies ...Policy) error {
context := NewWriteContext(ctx)
return UpdateStates(context, node, NewLockMap(NewLockInfo(node, []string{"policies"}), LockList(policies, nil)), func(context *StateContext) error {
for _, policy := range(policies) {
node.NodeHandle().PolicyMap[policy.ID()] = policy
}
return nil
})
}
// Magic first four bytes of serialized DB content, stored big endian
const NODE_DB_MAGIC = 0x2491df14
// Total length of the node database header, has magic to verify and type_hash to map to load function
const NODE_DB_HEADER_LEN = 12
const NODE_DB_HEADER_LEN = 8
// A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node
type DBHeader struct {
type NodeDBHeader struct {
Magic uint32
TypeHash uint64
NumExtensions uint32
}
type NodeDB struct {
Header NodeDBHeader
Extensions []ExtensionDB
}
//TODO: add size safety checks
func NewNodeDB(data []byte) (NodeDB, error) {
var zero NodeDB
ptr := 0
magic := binary.BigEndian.Uint32(data[0:4])
num_extensions := binary.BigEndian.Uint32(data[4:8])
ptr += NODE_DB_HEADER_LEN
if magic != NODE_DB_MAGIC {
return zero, fmt.Errorf("header has incorrect magic 0x%x", magic)
}
extensions := make([]ExtensionDB, num_extensions)
for i, _ := range(extensions) {
cur := data[ptr:]
type_hash := binary.BigEndian.Uint64(cur[0:8])
length := binary.BigEndian.Uint64(cur[8:16])
data_start := uint64(EXTENSION_DB_HEADER_LEN)
data_end := data_start + length
ext_data := cur[data_start:data_end]
extensions[i] = ExtensionDB{
Header: ExtensionDBHeader{
TypeHash: type_hash,
Length: length,
},
Data: ext_data,
}
func (header DBHeader) Serialize() []byte {
ptr += int(EXTENSION_DB_HEADER_LEN + length)
}
return NodeDB{
Header: NodeDBHeader{
Magic: magic,
NumExtensions: num_extensions,
},
Extensions: extensions,
}, nil
}
func (header NodeDBHeader) Serialize() []byte {
if header.Magic != NODE_DB_MAGIC {
panic(fmt.Sprintf("Serializing header with invalid magic %0x", header.Magic))
}
ret := make([]byte, NODE_DB_HEADER_LEN)
binary.BigEndian.PutUint32(ret[0:4], header.Magic)
binary.BigEndian.PutUint64(ret[4:12], header.TypeHash)
binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions)
return ret
}
func NewDBHeader(node_type NodeType) DBHeader {
return DBHeader{
Magic: NODE_DB_MAGIC,
TypeHash: node_type.Hash(),
func (node NodeDB) Serialize() []byte {
ser := node.Header.Serialize()
for _, extension := range(node.Extensions) {
ser = append(ser, extension.Serialize()...)
}
return ser
}
func (header ExtensionDBHeader) Serialize() []byte {
ret := make([]byte, EXTENSION_DB_HEADER_LEN)
binary.BigEndian.PutUint64(ret[0:8], header.TypeHash)
binary.BigEndian.PutUint64(ret[8:16], header.Length)
return ret
}
func (extension ExtensionDB) Serialize() []byte {
header_bytes := extension.Header.Serialize()
return append(header_bytes, extension.Data...)
}
const EXTENSION_DB_HEADER_LEN = 16
type ExtensionDBHeader struct {
TypeHash uint64
Length uint64
}
type ExtensionDB struct {
Header ExtensionDBHeader
Data []byte
}
// Write multiple nodes to the database in a single transaction
@ -283,27 +351,21 @@ func WriteNodes(context *StateContext) error {
serialized_bytes := make([][]byte, len(context.Locked))
serialized_ids := make([][]byte, len(context.Locked))
i := 0
for _, node := range(context.Locked) {
// TODO, just write states from the context, and store the current states in the context
for id, _ := range(context.Locked) {
node, _ := context.Graph.Nodes[id]
if node == nil {
return fmt.Errorf("DB_SERIALIZE_ERROR: cannot serialize nil node")
return fmt.Errorf("DB_SERIALIZE_ERROR: cannot serialize nil node, maybe node isn't in the context")
}
ser, err := node.Serialize()
if err != nil {
return fmt.Errorf("DB_SERIALIZE_ERROR: %s", err)
}
header := NewDBHeader(node.Type())
db_data := append(header.Serialize(), ser...)
context.Graph.Log.Logf("db", "DB_WRITING_TYPE: %s - %+v %+v: %+v", node.ID(), node.Type(), header, node)
if err != nil {
return err
}
id_ser := node.ID().Serialize()
id_ser := node.ID.Serialize()
serialized_bytes[i] = db_data
serialized_bytes[i] = ser
serialized_ids[i] = id_ser
i++
@ -320,8 +382,13 @@ func WriteNodes(context *StateContext) error {
})
}
// Get the bytes associates with `id` from the database after unwrapping the header, or error
func readNodeBytes(ctx * Context, id NodeID) (uint64, []byte, error) {
// Recursively load a node from the database.
func LoadNode(ctx * Context, id NodeID) (*Node, error) {
node, exists := ctx.Nodes[id]
if exists == true {
return node,nil
}
var bytes []byte
err := ctx.DB.View(func(txn *badger.Txn) error {
item, err := txn.Get(id.Serialize())
@ -334,80 +401,51 @@ func readNodeBytes(ctx * Context, id NodeID) (uint64, []byte, error) {
return nil
})
})
if err != nil {
ctx.Log.Logf("db", "DB_READ_ERR: %s - %e", id, err)
return 0, nil, err
}
if len(bytes) < NODE_DB_HEADER_LEN {
return 0, nil, fmt.Errorf("header for %s is %d/%d bytes", id, len(bytes), NODE_DB_HEADER_LEN)
}
header := DBHeader{}
header.Magic = binary.BigEndian.Uint32(bytes[0:4])
header.TypeHash = binary.BigEndian.Uint64(bytes[4:12])
if header.Magic != NODE_DB_MAGIC {
return 0, nil, fmt.Errorf("header for %s, invalid magic 0x%x", id, header.Magic)
}
node_bytes := make([]byte, len(bytes) - NODE_DB_HEADER_LEN)
copy(node_bytes, bytes[NODE_DB_HEADER_LEN:])
ctx.Log.Logf("db", "DB_READ: %s %+v - %s", id, header, string(bytes))
return header.TypeHash, node_bytes, nil
}
// Load a Node from the database by ID
func LoadNode(ctx * Context, id NodeID) (Node, error) {
nodes := NodeMap{}
return LoadNodeRecurse(ctx, id, nodes)
return nil, err
}
// Recursively load a node from the database.
// It's expected that node_type.Load adds the newly loaded node to nodes before calling LoadNodeRecurse again.
func LoadNodeRecurse(ctx * Context, id NodeID, nodes NodeMap) (Node, error) {
node, exists := nodes[id]
if exists == false {
type_hash, bytes, err := readNodeBytes(ctx, id)
// Parse the bytes from the DB
node_db, err := NewNodeDB(bytes)
if err != nil {
return nil, err
}
node_type, exists := ctx.Types[type_hash]
ctx.Log.Logf("db", "DB_LOADING_TYPE: %s - %+v", id, node_type)
if exists == false {
return nil, fmt.Errorf("0x%x is not a known node type: %+s", type_hash, bytes)
}
// Create the blank node with the ID, and add it to the context
new_node := NewNode(id)
node = &new_node
ctx.Nodes[id] = node
if node_type.Load == nil {
return nil, fmt.Errorf("0x%x is an invalid node type, nil Load", type_hash)
// Parse each of the extensions from the db
for _, ext_db := range(node_db.Extensions) {
type_hash := ext_db.Header.TypeHash
def, known := ctx.Extensions[type_hash]
if known == false {
return nil, fmt.Errorf("%s tried to load extension 0x%x, which is not a known extension type", id, type_hash)
}
node, err = node_type.Load(ctx, id, bytes, nodes)
extension, err := def.Load(ctx, ext_db.Data)
if err != nil {
return nil, err
}
node.ExtensionMap[def.Type] = extension
ctx.Log.Logf("db", "DB_EXTENSION_LOADED: %s - 0x%x", id, type_hash)
}
ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id)
}
return node, nil
}
func NewLockInfo(node Node, resources []string) LockMap {
return LockMap{
node.ID(): LockInfo{
func NewACLInfo(node *Node, resources []string) ACLMap {
return ACLMap{
node.ID: ACLInfo{
Node: node,
Resources: resources,
},
}
}
func NewLockMap(requests ...LockMap) LockMap {
reqs := LockMap{}
func NewACLMap(requests ...ACLMap) ACLMap {
reqs := ACLMap{}
for _, req := range(requests) {
for id, info := range(req) {
reqs[id] = info
@ -416,10 +454,10 @@ func NewLockMap(requests ...LockMap) LockMap {
return reqs
}
func LockListM[K Node](m map[NodeID]K, resources[]string) LockMap {
reqs := LockMap{}
func ACLListM(m map[NodeID]*Node, resources[]string) ACLMap {
reqs := ACLMap{}
for _, node := range(m) {
reqs[node.ID()] = LockInfo{
reqs[node.ID] = ACLInfo{
Node: node,
Resources: resources,
}
@ -427,10 +465,10 @@ func LockListM[K Node](m map[NodeID]K, resources[]string) LockMap {
return reqs
}
func LockList[K Node](list []K, resources []string) LockMap {
reqs := LockMap{}
func ACLList(list []*Node, resources []string) ACLMap {
reqs := ACLMap{}
for _, node := range(list) {
reqs[node.ID()] = LockInfo{
reqs[node.ID] = ACLInfo{
Node: node,
Resources: resources,
}
@ -438,21 +476,40 @@ func LockList[K Node](list []K, resources []string) LockMap {
return reqs
}
type PolicyType string
func (policy PolicyType) Hash() uint64 {
hash := sha512.Sum512([]byte(fmt.Sprintf("POLICY: %s", string(policy))))
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
}
type ExtType string
func (ext ExtType) Hash() uint64 {
hash := sha512.Sum512([]byte(fmt.Sprintf("EXTENSION: %s", string(ext))))
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
}
type NodeMap map[NodeID]Node
type NodeMap map[NodeID]*Node
type LockInfo struct {
Node Node
type ACLInfo struct {
Node *Node
Resources []string
}
type LockMap map[NodeID]LockInfo
type ACLMap map[NodeID]ACLInfo
type ExtMap map[uint64]Extension
// Context of running state usage(read/write)
type StateContext struct {
// Type of the state context
Type string
// The wrapped graph context
Graph *Context
Permissions map[NodeID]LockMap
Locked NodeMap
// Granted permissions in the context
Permissions map[NodeID]ACLMap
// Locked extensions in the context
Locked map[NodeID]*Node
// Context state for validation
Started bool
Finished bool
}
@ -477,8 +534,8 @@ func NewReadContext(ctx *Context) *StateContext {
return &StateContext{
Type: "read",
Graph: ctx,
Permissions: map[NodeID]LockMap{},
Locked: NodeMap{},
Permissions: map[NodeID]ACLMap{},
Locked: map[NodeID]*Node{},
Started: false,
Finished: false,
}
@ -488,8 +545,8 @@ func NewWriteContext(ctx *Context) *StateContext {
return &StateContext{
Type: "write",
Graph: ctx,
Permissions: map[NodeID]LockMap{},
Locked: NodeMap{},
Permissions: map[NodeID]ACLMap{},
Locked: map[NodeID]*Node{},
Started: false,
Finished: false,
}
@ -515,8 +572,8 @@ func del[K comparable](list []K, val K) []K {
// Add nodes to an existing read context and call nodes_fn with new_nodes locked for read
// Check that the node has read permissions for the nodes, then add them to the read context and call nodes_fn with the nodes locked for read
func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn StateFn) error {
if princ == nil || new_nodes == nil || state_fn == nil {
func UseStates(context *StateContext, principal *Node, new_nodes ACLMap, state_fn StateFn) error {
if principal == nil || new_nodes == nil || state_fn == nil {
return fmt.Errorf("nil passed to UseStates")
}
@ -529,16 +586,16 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
context.Started = true
}
new_locks := []Node{}
_, princ_locked := context.Locked[princ.ID()]
new_locks := []*Node{}
_, princ_locked := context.Locked[principal.ID]
if princ_locked == false {
new_locks = append(new_locks, princ)
context.Graph.Log.Logf("mutex", "RLOCKING_PRINC %s", princ.ID().String())
princ.LockState(false)
new_locks = append(new_locks, principal)
context.Graph.Log.Logf("mutex", "RLOCKING_PRINC %s", principal.ID.String())
principal.Lock.RLock()
}
princ_permissions, princ_exists := context.Permissions[princ.ID()]
new_permissions := LockMap{}
princ_permissions, princ_exists := context.Permissions[principal.ID]
new_permissions := ACLMap{}
if princ_exists == true {
for id, info := range(princ_permissions) {
new_permissions[id] = info
@ -550,20 +607,20 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
if node == nil {
return fmt.Errorf("node in request list is nil")
}
id := node.ID()
id := node.ID
if id != princ.ID() {
if id != principal.ID {
_, locked := context.Locked[id]
if locked == false {
new_locks = append(new_locks, node)
context.Graph.Log.Logf("mutex", "RLOCKING %s", id.String())
node.LockState(false)
node.Lock.RLock()
}
}
node_permissions, node_exists := new_permissions[id]
if node_exists == false {
node_permissions = LockInfo{Node: node, Resources: []string{}}
node_permissions = ACLInfo{Node: node, Resources: []string{}}
}
for _, resource := range(request.Resources) {
@ -575,11 +632,11 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
}
if already_granted == false {
err := Allowed(context, node.Policies(), node, resource, "read", princ)
err := Allowed(context, principal, fmt.Sprintf("%s.read", resource), node)
if err != nil {
for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String())
n.UnlockState(false)
n.Lock.RUnlock()
}
return err
}
@ -589,19 +646,19 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
}
for _, node := range(new_locks) {
context.Locked[node.ID()] = node
context.Locked[node.ID] = node
}
context.Permissions[princ.ID()] = new_permissions
context.Permissions[principal.ID] = new_permissions
err = state_fn(context)
context.Permissions[princ.ID()] = princ_permissions
context.Permissions[principal.ID] = princ_permissions
for _, node := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING %s", node.ID().String())
delete(context.Locked, node.ID())
node.UnlockState(false)
context.Graph.Log.Logf("mutex", "RUNLOCKING %s", node.ID.String())
delete(context.Locked, node.ID)
node.Lock.RUnlock()
}
return err
@ -609,8 +666,8 @@ func UseStates(context *StateContext, princ Node, new_nodes LockMap, state_fn St
// Add nodes to an existing write context and call nodes_fn with nodes locked for read
// If context is nil
func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn StateFn) error {
if princ == nil || new_nodes == nil || state_fn == nil {
func UpdateStates(context *StateContext, principal *Node, new_nodes ACLMap, state_fn StateFn) error {
if principal == nil || new_nodes == nil || state_fn == nil {
return fmt.Errorf("nil passed to UpdateStates")
}
@ -625,16 +682,16 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
final = true
}
new_locks := []Node{}
_, princ_locked := context.Locked[princ.ID()]
new_locks := []*Node{}
_, princ_locked := context.Locked[principal.ID]
if princ_locked == false {
new_locks = append(new_locks, princ)
context.Graph.Log.Logf("mutex", "LOCKING_PRINC %s", princ.ID().String())
princ.LockState(true)
new_locks = append(new_locks, principal)
context.Graph.Log.Logf("mutex", "LOCKING_PRINC %s", principal.ID.String())
principal.Lock.Lock()
}
princ_permissions, princ_exists := context.Permissions[princ.ID()]
new_permissions := LockMap{}
princ_permissions, princ_exists := context.Permissions[principal.ID]
new_permissions := ACLMap{}
if princ_exists == true {
for id, info := range(princ_permissions) {
new_permissions[id] = info
@ -646,20 +703,20 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
if node == nil {
return fmt.Errorf("node in request list is nil")
}
id := node.ID()
id := node.ID
if id != princ.ID() {
if id != principal.ID {
_, locked := context.Locked[id]
if locked == false {
new_locks = append(new_locks, node)
context.Graph.Log.Logf("mutex", "LOCKING %s", id.String())
node.LockState(true)
node.Lock.Lock()
}
}
node_permissions, node_exists := new_permissions[id]
if node_exists == false {
node_permissions = LockInfo{Node: node, Resources: []string{}}
node_permissions = ACLInfo{Node: node, Resources: []string{}}
}
for _, resource := range(request.Resources) {
@ -671,11 +728,11 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
}
if already_granted == false {
err := Allowed(context, node.Policies(), node, resource, "write", princ)
err := Allowed(context, principal, fmt.Sprintf("%s.write", resource), node)
if err != nil {
for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String())
n.UnlockState(true)
n.Lock.Unlock()
}
return err
}
@ -685,10 +742,10 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
}
for _, node := range(new_locks) {
context.Locked[node.ID()] = node
context.Locked[node.ID] = node
}
context.Permissions[princ.ID()] = new_permissions
context.Permissions[principal.ID] = new_permissions
err = state_fn(context)
@ -699,7 +756,7 @@ func UpdateStates(context *StateContext, princ Node, new_nodes LockMap, state_fn
}
for id, node := range(context.Locked) {
context.Graph.Log.Logf("mutex", "UNLOCKING %s", id.String())
node.UnlockState(true)
node.Lock.Unlock()
}
}

@ -5,356 +5,119 @@ import (
"fmt"
)
// A policy represents a set of rules attached to a Node that allow principals to perform actions on it
type Policy interface {
Node
// Returns true if the principal is allowed to perform the action on the resource
Allows(context *StateContext, node Node, resource string, action string, principal Node) bool
Serialize() ([]byte, error)
Allows(context *StateContext, principal *Node, action string, node *Node) bool
}
type NodeActions map[string][]string
func (actions NodeActions) Allows(resource string, action string) bool {
for _, a := range(actions[""]) {
if a == action || a == "*" {
return true
}
}
resource_actions, exists := actions[resource]
if exists == true {
for _, a := range(resource_actions) {
if a == action || a == "*" {
return true
}
}
}
return false
}
func NewNodeActions(resource_actions NodeActions, wildcard_actions []string) NodeActions {
if resource_actions == nil {
resource_actions = NodeActions{}
}
// Wildcard actions, all actions in "" will be allowed on all resources
if wildcard_actions == nil {
wildcard_actions = []string{}
func LoadAllNodesPolicy(ctx *Context, data []byte) (Policy, error) {
var policy AllNodesPolicy
err := json.Unmarshal(data, &policy)
if err != nil {
return policy, err
}
resource_actions[""] = wildcard_actions
return resource_actions
return policy, nil
}
type PerNodePolicy struct {
SimpleNode
Actions map[NodeID]NodeActions
type AllNodesPolicy struct {
Actions []string `json:"actions"`
}
type PerNodePolicyJSON struct {
SimpleNodeJSON
Actions map[string]map[string][]string `json:"actions"`
func (policy AllNodesPolicy) Type() PolicyType {
return PolicyType("simple_policy")
}
func (policy *PerNodePolicy) Type() NodeType {
return NodeType("per_node_policy")
func (policy AllNodesPolicy) Serialize() ([]byte, error) {
return json.MarshalIndent(&policy, "", " ")
}
func (policy *PerNodePolicy) Serialize() ([]byte, error) {
actions := map[string]map[string][]string{}
for principal, resource_actions := range(policy.Actions) {
actions[principal.String()] = resource_actions
// Extension to allow a node to hold ACL policies
type ACLPolicyExtension struct {
Policies map[PolicyType]Policy
}
return json.MarshalIndent(&PerNodePolicyJSON{
SimpleNodeJSON: NewSimpleNodeJSON(&policy.SimpleNode),
Actions: actions,
}, "", " ")
}
func NewPerNodePolicy(id NodeID, actions map[NodeID]NodeActions) PerNodePolicy {
if actions == nil {
actions = map[NodeID]NodeActions{}
type PolicyLoadFunc func(*Context, []byte) (Policy, error)
type PolicyInfo struct {
Load PolicyLoadFunc
Type PolicyType
}
return PerNodePolicy{
SimpleNode: NewSimpleNode(id),
Actions: actions,
}
type ACLPolicyExtensionContext struct {
Types map[PolicyType]PolicyInfo
}
var LoadPerNodePolicy = LoadJSONNode(func(id NodeID, j PerNodePolicyJSON) (Node, error) {
actions := map[NodeID]NodeActions{}
for principal_str, node_actions := range(j.Actions) {
principal_id, err := ParseID(principal_str)
func (ext ACLPolicyExtension) Serialize() ([]byte, error) {
policies := map[string][]byte{}
for name, policy := range(ext.Policies) {
ser, err := policy.Serialize()
if err != nil {
return nil, err
}
actions[principal_id] = node_actions
}
policy := NewPerNodePolicy(id, actions)
return &policy, nil
}, func(ctx *Context, node Node, j PerNodePolicyJSON, nodes NodeMap) error {
return RestoreSimpleNode(ctx, node.NodeHandle(), j.SimpleNodeJSON, nodes)
})
func (policy *PerNodePolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool {
node_actions, exists := policy.Actions[principal.ID()]
if exists == false {
return false
}
if node_actions.Allows(resource, action) == true {
return true
}
return false
}
type SimplePolicy struct {
SimpleNode
Actions NodeActions
}
type SimplePolicyJSON struct {
SimpleNodeJSON
Actions map[string][]string `json:"actions"`
}
func (policy *SimplePolicy) Type() NodeType {
return NodeType("simple_policy")
}
func NewSimplePolicyJSON(policy *SimplePolicy) SimplePolicyJSON {
return SimplePolicyJSON{
SimpleNodeJSON: NewSimpleNodeJSON(&policy.SimpleNode),
Actions: policy.Actions,
}
}
func (policy *SimplePolicy) Serialize() ([]byte, error) {
j := NewSimplePolicyJSON(policy)
return json.MarshalIndent(&j, "", " ")
}
func NewSimplePolicy(id NodeID, actions NodeActions) SimplePolicy {
if actions == nil {
actions = NodeActions{}
}
return SimplePolicy{
SimpleNode: NewSimpleNode(id),
Actions: actions,
}
}
var LoadSimplePolicy = LoadJSONNode(func(id NodeID, j SimplePolicyJSON) (Node, error) {
policy := NewSimplePolicy(id, j.Actions)
return &policy, nil
}, func(ctx *Context, node Node, j SimplePolicyJSON, nodes NodeMap) error {
return RestoreSimpleNode(ctx, node.NodeHandle(), j.SimpleNodeJSON, nodes)
})
func (policy *SimplePolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool {
return policy.Actions.Allows(resource, action)
}
type DependencyPolicy struct {
SimplePolicy
}
func (policy *DependencyPolicy) Type() NodeType {
return NodeType("dependency_policy")
}
func NewDependencyPolicy(id NodeID, actions NodeActions) DependencyPolicy {
return DependencyPolicy{
SimplePolicy: NewSimplePolicy(id, actions),
}
}
func (policy *DependencyPolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool {
lockable, ok := node.(LockableNode)
if ok == false {
return false
}
for _, dep := range(lockable.LockableHandle().Dependencies) {
if dep.ID() == principal.ID() {
return policy.Actions.Allows(resource, action)
}
}
return false
}
type RequirementPolicy struct {
SimplePolicy
}
func (policy *RequirementPolicy) Type() NodeType {
return NodeType("dependency_policy")
}
func NewRequirementPolicy(id NodeID, actions NodeActions) RequirementPolicy {
return RequirementPolicy{
SimplePolicy: NewSimplePolicy(id, actions),
}
}
func (policy *RequirementPolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool {
lockable_node, ok := node.(LockableNode)
if ok == false {
return false
}
lockable := lockable_node.LockableHandle()
for _, req := range(lockable.Requirements) {
if req.ID() == principal.ID() {
return policy.Actions.Allows(resource, action)
}
}
return false
policies[string(name)] = ser
}
type ParentPolicy struct {
SimplePolicy
return json.MarshalIndent(&struct{
Policies map[string][]byte `json:"policies"`
}{
Policies: policies,
}, "", " ")
}
func (policy *ParentPolicy) Type() NodeType {
return NodeType("parent_policy")
func (ext ACLPolicyExtension) Process(context *StateContext, node *Node, signal GraphSignal) error {
return nil
}
func NewParentPolicy(id NodeID, actions NodeActions) ParentPolicy {
return ParentPolicy{
SimplePolicy: NewSimplePolicy(id, actions),
func LoadACLPolicyExtension(ctx *Context, data []byte) (Extension, error) {
var j struct {
Policies map[string][]byte `json:"policies"`
}
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
func (policy *ParentPolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool {
thread_node, ok := node.(ThreadNode)
if ok == false {
return false
}
thread := thread_node.ThreadHandle()
if thread.Owner != nil {
if thread.Owner.ID() == principal.ID() {
return policy.Actions.Allows(resource, action)
policies := map[PolicyType]Policy{}
acl_ctx := ctx.ExtByType(ACLPolicyExtType).Data.(ACLPolicyExtensionContext)
for name, ser := range(j.Policies) {
policy_def, exists := acl_ctx.Types[PolicyType(name)]
if exists == false {
return nil, fmt.Errorf("%s is not a known policy type", name)
}
policy, err := policy_def.Load(ctx, ser)
if err != nil {
return nil, err
}
return false
policies[PolicyType(name)] = policy
}
type ChildrenPolicy struct {
SimplePolicy
return ACLPolicyExtension{
Policies: policies,
}, nil
}
func (policy *ChildrenPolicy) Type() NodeType {
return NodeType("children_policy")
const ACLPolicyExtType = ExtType("ACL_POLICIES")
func (ext ACLPolicyExtension) Type() ExtType {
return ACLPolicyExtType
}
func NewChildrenPolicy(id NodeID, actions NodeActions) ChildrenPolicy {
return ChildrenPolicy{
SimplePolicy: NewSimplePolicy(id, actions),
// Check if the extension allows the principal to perform action on node
func (ext ACLPolicyExtension) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
for _, policy := range(ext.Policies) {
if policy.Allows(context, principal, action, node) == true {
return true
}
}
func (policy *ChildrenPolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool {
thread_node, ok := node.(ThreadNode)
if ok == false {
return false
}
thread := thread_node.ThreadHandle()
for _, info := range(thread.Children) {
if info.Child.ID() == principal.ID() {
return policy.Actions.Allows(resource, action)
func (policy AllNodesPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
for _, a := range(policy.Actions) {
if a == action {
return true
}
}
return false
}
type UserOfPolicy struct {
SimplePolicy
Target GroupNode
}
type UserOfPolicyJSON struct {
SimplePolicyJSON
Target string `json:"target"`
}
func (policy *UserOfPolicy) Type() NodeType {
return NodeType("user_of_policy")
}
func (policy *UserOfPolicy) Serialize() ([]byte, error) {
target := ""
if policy.Target != nil {
target = policy.Target.ID().String()
}
return json.MarshalIndent(&UserOfPolicyJSON{
SimplePolicyJSON: NewSimplePolicyJSON(&policy.SimplePolicy),
Target: target,
}, "", " ")
}
func NewUserOfPolicy(id NodeID, actions NodeActions) UserOfPolicy {
return UserOfPolicy{
SimplePolicy: NewSimplePolicy(id, actions),
Target: nil,
}
}
var LoadUserOfPolicy = LoadJSONNode(func(id NodeID, j UserOfPolicyJSON) (Node, error) {
policy := NewUserOfPolicy(id, j.Actions)
return &policy, nil
}, func(ctx *Context, policy *UserOfPolicy, j UserOfPolicyJSON, nodes NodeMap) error {
if j.Target != "" {
target_id, err := ParseID(j.Target)
if err != nil {
return err
}
target_node, err := LoadNodeRecurse(ctx, target_id, nodes)
if err != nil {
return err
}
target, ok := target_node.(GroupNode)
if ok == false {
return fmt.Errorf("%s is not a GroupNode", target_node.ID())
}
policy.Target = target
return nil
}
return RestoreSimpleNode(ctx, policy, j.SimpleNodeJSON, nodes)
})
func (policy *UserOfPolicy) Allows(context *StateContext, node Node, resource string, action string, principal Node) bool {
if policy.Target != nil {
allowed := false
err := UseStates(context, policy.Target, NewLockInfo(policy.Target, []string{"users"}), func(context *StateContext) error {
for _, user := range(policy.Target.Users()) {
if user.ID() == principal.ID() {
allowed = policy.Actions.Allows(resource, action)
return nil
}
}
return nil
})
if err != nil {
return false
}
return allowed
}
return false
}

@ -8,24 +8,64 @@ import (
"encoding/json"
)
type ThreadExt struct {
Actions ThreadActions
Handlers ThreadHandlers
SignalChan chan GraphSignal
TimeoutChan <-chan time.Time
ChildWaits sync.WaitGroup
ActiveLock sync.Mutex
Active bool
StateName string
Parent *Node
Children map[NodeID]ChildInfo
ActionQueue []QueuedAction
NextAction *QueuedAction
}
func (ext *ThreadExt) Serialize() ([]byte, error) {
return nil, fmt.Errorf("NOT_IMPLEMENTED")
}
const ThreadExtType = ExtType("THREAD")
func (ext *ThreadExt) Type() ExtType {
return ThreadExtType
}
func (ext *ThreadExt) ChildList() []*Node {
ret := make([]*Node, len(ext.Children))
i := 0
for _, info := range(ext.Children) {
ret[i] = info.Child
i += 1
}
return ret
}
// Assumed that thread is already locked for signal
func (thread *Thread) Process(context *StateContext, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "THREAD_PROCESS: %s", thread.ID())
func (ext *ThreadExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "THREAD_PROCESS: %s", node.ID)
var err error
switch signal.Direction() {
case Up:
err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error {
if thread.Parent != nil {
return Signal(context, thread.Parent, thread, signal)
err = UseStates(context, node, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error {
if ext.Parent != nil {
return Signal(context, ext.Parent, node, signal)
} else {
return nil
}
})
case Down:
err = UseStates(context, thread, NewLockInfo(thread, []string{"children"}), func(context *StateContext) error {
for _, info := range(thread.Children) {
err := Signal(context, info.Child, thread, signal)
err = UseStates(context, node, NewACLInfo(node, []string{"children"}), func(context *StateContext) error {
for _, info := range(ext.Children) {
err := Signal(context, info.Child, node, signal)
if err != nil {
return err
}
@ -37,91 +77,121 @@ func (thread *Thread) Process(context *StateContext, signal GraphSignal) error {
default:
return fmt.Errorf("Invalid signal direction %d", signal.Direction())
}
ext.SignalChan <- signal
return err
}
func UnlinkThreads(context *StateContext, principal *Node, thread *Node, child *Node) error {
thread_ext, err := GetExt[*ThreadExt](thread)
if err != nil {
return err
}
thread.Chan <- signal
return thread.Lockable.Process(context, signal)
child_ext, err := GetExt[*ThreadExt](child)
if err != nil {
return err
}
// Requires thread and childs thread to be locked for write
func UnlinkThreads(ctx * Context, node ThreadNode, child_node ThreadNode) error {
thread := node.ThreadHandle()
child := child_node.ThreadHandle()
_, is_child := thread.Children[child_node.ID()]
return UpdateStates(context, principal, ACLMap{
thread.ID: ACLInfo{thread, []string{"children"}},
child.ID: ACLInfo{child, []string{"parent"}},
}, func(context *StateContext) error {
_, is_child := thread_ext.Children[child.ID]
if is_child == false {
return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID(), thread.ID())
return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID, thread.ID)
}
child.Parent = nil
delete(thread.Children, child.ID())
delete(thread_ext.Children, child.ID)
child_ext.Parent = nil
return nil
})
}
func checkIfChild(context *StateContext, target ThreadNode, cur ThreadNode) bool {
for _, info := range(cur.ThreadHandle().Children) {
if info.Child.ID() == target.ID() {
return true
func checkIfChild(context *StateContext, id NodeID, cur *ThreadExt) (bool, error) {
for _, info := range(cur.Children) {
child := info.Child
if child.ID == id {
return true, nil
}
is_child := false
UpdateStates(context, cur, NewLockMap(
NewLockInfo(info.Child, []string{"children"}),
), func(context *StateContext) error {
is_child = checkIfChild(context, target, info.Child)
return nil
child_ext, err := GetExt[*ThreadExt](child)
if err != nil {
return false, err
}
var is_child bool
err = UpdateStates(context, child, NewACLInfo(child, []string{"children"}), func(context *StateContext) error {
is_child, err = checkIfChild(context, id, child_ext)
return err
})
if err != nil {
return false, err
}
if is_child {
return true
return true, nil
}
}
return false
return false, nil
}
// Links child to parent with info as the associated info
// Continues the write context with princ, getting children for thread and parent for child
func LinkThreads(context *StateContext, princ Node, thread_node ThreadNode, info ChildInfo) error {
if context == nil || thread_node == nil || info.Child == nil {
func LinkThreads(context *StateContext, principal *Node, thread *Node, info ChildInfo) error {
if context == nil || principal == nil || thread == nil || info.Child == nil {
return fmt.Errorf("invalid input")
}
thread := thread_node.ThreadHandle()
child := info.Child.ThreadHandle()
child_node := info.Child
if thread.ID() == child.ID() {
return fmt.Errorf("Will not link %s as a child of itself", thread.ID())
child := info.Child
if thread.ID == child.ID {
return fmt.Errorf("Will not link %s as a child of itself", thread.ID)
}
thread_ext, err := GetExt[*ThreadExt](thread)
if err != nil {
return err
}
return UpdateStates(context, princ, LockMap{
child.ID(): LockInfo{Node: child_node, Resources: []string{"parent"}},
thread.ID(): LockInfo{Node: thread_node, Resources: []string{"children"}},
child_ext, err := GetExt[*ThreadExt](thread)
if err != nil {
return err
}
return UpdateStates(context, principal, ACLMap{
child.ID: ACLInfo{Node: child, Resources: []string{"parent"}},
thread.ID: ACLInfo{Node: thread, Resources: []string{"children"}},
}, func(context *StateContext) error {
if child.Parent != nil {
return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID())
if child_ext.Parent != nil {
return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID)
}
if checkIfChild(context, thread, child) == true {
return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID(), child.ID())
is_child, err := checkIfChild(context, thread.ID, child_ext)
if err != nil {
return err
} else if is_child == true {
return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID, child.ID)
}
if checkIfChild(context, child, thread) == true {
return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID())
is_child, err = checkIfChild(context, child.ID, thread_ext)
if err != nil {
return err
} else if is_child == true {
return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID, child.ID)
}
// TODO check for info types
thread.Children[child.ID()] = info
child.Parent = thread_node
thread_ext.Children[child.ID] = info
child_ext.Parent = thread
return nil
})
}
type ThreadAction func(*Context, ThreadNode)(string, error)
type ThreadAction func(*Context, *Node, *ThreadExt)(string, error)
type ThreadActions map[string]ThreadAction
type ThreadHandler func(*Context, ThreadNode, GraphSignal)(string, error)
type ThreadHandler func(*Context, *Node, *ThreadExt, GraphSignal)(string, error)
type ThreadHandlers map[string]ThreadHandler
type InfoType string
@ -129,6 +199,10 @@ func (t InfoType) String() string {
return string(t)
}
type ThreadInfo interface {
Serializable[InfoType]
}
// Data required by a parent thread to restore it's children
type ParentThreadInfo struct {
Start bool `json:"start"`
@ -136,22 +210,23 @@ type ParentThreadInfo struct {
RestoreAction string `json:"restore_action"`
}
func NewParentThreadInfo(start bool, start_action string, restore_action string) ParentThreadInfo {
return ParentThreadInfo{
Start: start,
StartAction: start_action,
RestoreAction: restore_action,
const ParentThreadInfoType = InfoType("PARENT")
func (info *ParentThreadInfo) Type() InfoType {
return ParentThreadInfoType
}
func (info *ParentThreadInfo) Serialize() ([]byte, error) {
return json.MarshalIndent(info, "", " ")
}
type ChildInfo struct {
Child ThreadNode
Infos map[InfoType]interface{}
Child *Node
Infos map[InfoType]ThreadInfo
}
func NewChildInfo(child ThreadNode, infos map[InfoType]interface{}) ChildInfo {
func NewChildInfo(child *Node, infos map[InfoType]ThreadInfo) ChildInfo {
if infos == nil {
infos = map[InfoType]interface{}{}
infos = map[InfoType]ThreadInfo{}
}
return ChildInfo{
@ -165,110 +240,21 @@ type QueuedAction struct {
Action string `json:"action"`
}
type ThreadNode interface {
LockableNode
ThreadHandle() *Thread
}
type Thread struct {
Lockable
Actions ThreadActions
Handlers ThreadHandlers
TimeoutChan <-chan time.Time
Chan chan GraphSignal
ChildWaits sync.WaitGroup
Active bool
ActiveLock sync.Mutex
StateName string
Parent ThreadNode
Children map[NodeID]ChildInfo
InfoTypes []InfoType
ActionQueue []QueuedAction
NextAction *QueuedAction
}
func (thread *Thread) QueueAction(end time.Time, action string) {
thread.ActionQueue = append(thread.ActionQueue, QueuedAction{end, action})
thread.NextAction, thread.TimeoutChan = thread.SoonestAction()
}
func (thread *Thread) ClearActionQueue() {
thread.ActionQueue = []QueuedAction{}
thread.NextAction = nil
thread.TimeoutChan = nil
}
func (thread *Thread) ThreadHandle() *Thread {
return thread
}
func (thread *Thread) Type() NodeType {
return NodeType("thread")
func (ext *ThreadExt) QueueAction(end time.Time, action string) {
ext.ActionQueue = append(ext.ActionQueue, QueuedAction{end, action})
ext.NextAction, ext.TimeoutChan = ext.SoonestAction()
}
func (thread *Thread) Serialize() ([]byte, error) {
thread_json := NewThreadJSON(thread)
return json.MarshalIndent(&thread_json, "", " ")
func (ext *ThreadExt) ClearActionQueue() {
ext.ActionQueue = []QueuedAction{}
ext.NextAction = nil
ext.TimeoutChan = nil
}
func (thread *Thread) ChildList() []ThreadNode {
ret := make([]ThreadNode, len(thread.Children))
i := 0
for _, info := range(thread.Children) {
ret[i] = info.Child
i += 1
}
return ret
}
type ThreadJSON struct {
LockableJSON
Parent string `json:"parent"`
Children map[string]map[string]interface{} `json:"children"`
ActionQueue []QueuedAction `json:"action_queue"`
StateName string `json:"state_name"`
InfoTypes []InfoType `json:"info_types"`
}
func NewThreadJSON(thread *Thread) ThreadJSON {
children := map[string]map[string]interface{}{}
for id, info := range(thread.Children) {
tmp := map[string]interface{}{}
for name, i := range(info.Infos) {
tmp[name.String()] = i
}
children[id.String()] = tmp
}
parent_id := ""
if thread.Parent != nil {
parent_id = thread.Parent.ID().String()
}
lockable_json := NewLockableJSON(&thread.Lockable)
return ThreadJSON{
Parent: parent_id,
Children: children,
ActionQueue: thread.ActionQueue,
StateName: thread.StateName,
LockableJSON: lockable_json,
InfoTypes: thread.InfoTypes,
}
}
var LoadThread = LoadJSONNode(func(id NodeID, j ThreadJSON) (Node, error) {
thread := NewThread(id, j.Name, j.StateName, j.InfoTypes, BaseThreadActions, BaseThreadHandlers)
return &thread, nil
}, RestoreThread)
func (thread *Thread) SoonestAction() (*QueuedAction, <-chan time.Time) {
func (ext *ThreadExt) SoonestAction() (*QueuedAction, <-chan time.Time) {
var soonest_action *QueuedAction
var soonest_time time.Time
for _, action := range(thread.ActionQueue) {
for _, action := range(ext.ActionQueue) {
if action.Timeout.Compare(soonest_time) == -1 || soonest_action == nil {
soonest_action = &action
soonest_time = action.Timeout
@ -281,55 +267,6 @@ func (thread *Thread) SoonestAction() (*QueuedAction, <-chan time.Time) {
}
}
func RestoreThread(ctx *Context, thread ThreadNode, j ThreadJSON, nodes NodeMap) error {
thread_ptr := thread.ThreadHandle()
thread_ptr.ActionQueue = j.ActionQueue
thread_ptr.NextAction, thread_ptr.TimeoutChan = thread_ptr.SoonestAction()
if j.Parent != "" {
parent_id, err := ParseID(j.Parent)
if err != nil {
return err
}
p, err := LoadNodeRecurse(ctx, parent_id, nodes)
if err != nil {
return err
}
p_t, ok := p.(ThreadNode)
if ok == false {
return err
}
thread_ptr.Parent = p_t
}
for id_str, info_raw := range(j.Children) {
id, err := ParseID(id_str)
if err != nil {
return err
}
child_node, err := LoadNodeRecurse(ctx, id, nodes)
if err != nil {
return err
}
child_t, ok := child_node.(ThreadNode)
if ok == false {
return fmt.Errorf("%+v is not a Thread as expected", child_node)
}
parsed_info, err := DeserializeChildInfo(ctx, info_raw)
if err != nil {
return err
}
thread_ptr.Children[id] = ChildInfo{child_t, parsed_info}
}
return RestoreLockable(ctx, thread, j.LockableJSON, nodes)
}
var deserializers = map[InfoType]func(interface{})(interface{}, error) {
"parent": func(raw interface{})(interface{}, error) {
m, ok := raw.(map[string]interface{})
@ -357,141 +294,133 @@ var deserializers = map[InfoType]func(interface{})(interface{}, error) {
},
}
func DeserializeChildInfo(ctx *Context, infos_raw map[string]interface{}) (map[InfoType]interface{}, error) {
ret := map[InfoType]interface{}{}
for type_str, info_raw := range(infos_raw) {
info_type := InfoType(type_str)
deserializer, exists := deserializers[info_type]
if exists == false {
return nil, fmt.Errorf("No deserializer for %s", info_type)
}
var err error
ret[info_type], err = deserializer(info_raw)
if err != nil {
return nil, err
}
}
return ret, nil
}
const THREAD_SIGNAL_BUFFER_SIZE = 128
func NewThread(id NodeID, name string, state_name string, info_types []InfoType, actions ThreadActions, handlers ThreadHandlers) Thread {
return Thread{
Lockable: NewLockable(id, name),
InfoTypes: info_types,
StateName: state_name,
Chan: make(chan GraphSignal, THREAD_SIGNAL_BUFFER_SIZE),
Children: map[NodeID]ChildInfo{},
func NewThreadExt(buffer int, name string, state_name string, actions ThreadActions, handlers ThreadHandlers) ThreadExt {
return ThreadExt{
Actions: actions,
Handlers: handlers,
SignalChan: make(chan GraphSignal, buffer),
TimeoutChan: nil,
Active: false,
StateName: state_name,
Parent: nil,
Children: map[NodeID]ChildInfo{},
ActionQueue: []QueuedAction{},
NextAction: nil,
}
}
func (thread *Thread) SetActive(active bool) error {
thread.ActiveLock.Lock()
defer thread.ActiveLock.Unlock()
if thread.Active == true && active == true {
return fmt.Errorf("%s is active, cannot set active", thread.ID())
} else if thread.Active == false && active == false {
return fmt.Errorf("%s is already inactive, canot set inactive", thread.ID())
func (ext *ThreadExt) SetActive(active bool) error {
ext.ActiveLock.Lock()
defer ext.ActiveLock.Unlock()
if ext.Active == true && active == true {
return fmt.Errorf("alreday active, cannot set active")
} else if ext.Active == false && active == false {
return fmt.Errorf("already inactive, canot set inactive")
}
thread.Active = active
ext.Active = active
return nil
}
func (thread *Thread) SetState(state string) error {
thread.StateName = state
func (ext *ThreadExt) SetState(state string) error {
ext.StateName = state
return nil
}
// Requires the read permission of threads children
func FindChild(context *StateContext, princ Node, node ThreadNode, id NodeID) ThreadNode {
if node == nil {
func FindChild(context *StateContext, principal *Node, thread *Node, id NodeID) (*Node, error) {
if thread == nil {
panic("cannot recurse through nil")
}
thread := node.ThreadHandle()
if id == thread.ID() {
return thread
if id == thread.ID {
return thread, nil
}
for _, info := range thread.Children {
var result ThreadNode
UseStates(context, princ, NewLockInfo(info.Child, []string{"children"}), func(context *StateContext) error {
result = FindChild(context, princ, info.Child, id)
thread_ext, err := GetExt[*ThreadExt](thread)
if err != nil {
return nil, err
}
var found *Node = nil
err = UseStates(context, principal, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
for _, info := range(thread_ext.Children) {
found, err = FindChild(context, principal, info.Child, id)
if err != nil {
return err
}
if found != nil {
return nil
})
if result != nil {
return result
}
}
return nil
})
return found, err
}
func ChildGo(ctx * Context, thread *Thread, child ThreadNode, first_action string) {
thread.ChildWaits.Add(1)
go func(child ThreadNode) {
ctx.Log.Logf("thread", "THREAD_START_CHILD: %s from %s", thread.ID(), child.ID())
defer thread.ChildWaits.Done()
func ChildGo(ctx * Context, thread_ext *ThreadExt, child *Node, first_action string) {
thread_ext.ChildWaits.Add(1)
go func(child *Node) {
defer thread_ext.ChildWaits.Done()
err := ThreadLoop(ctx, child, first_action)
if err != nil {
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %s", child.ID(), err)
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %s", child.ID, err)
} else {
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_DONE: %s", child.ID())
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_DONE: %s", child.ID)
}
}(child)
}
// Main Loop for Threads, starts a write context, so cannot be called from a write or read context
func ThreadLoop(ctx * Context, node ThreadNode, first_action string) error {
// Start the thread, error if double-started
thread := node.ThreadHandle()
ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID(), first_action)
err := thread.SetActive(true)
func ThreadLoop(ctx * Context, thread *Node, first_action string) error {
thread_ext, err := GetExt[*ThreadExt](thread)
if err != nil {
return err
}
ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID, first_action)
err = thread_ext.SetActive(true)
if err != nil {
ctx.Log.Logf("thread", "THREAD_LOOP_START_ERR: %e", err)
return err
}
next_action := first_action
for next_action != "" {
action, exists := thread.Actions[next_action]
action, exists := thread_ext.Actions[next_action]
if exists == false {
error_str := fmt.Sprintf("%s is not a valid action", next_action)
return errors.New(error_str)
}
ctx.Log.Logf("thread", "THREAD_ACTION: %s - %s", thread.ID(), next_action)
next_action, err = action(ctx, node)
ctx.Log.Logf("thread", "THREAD_ACTION: %s - %s", thread.ID, next_action)
next_action, err = action(ctx, thread, thread_ext)
if err != nil {
return err
}
}
err = thread.SetActive(false)
err = thread_ext.SetActive(false)
if err != nil {
ctx.Log.Logf("thread", "THREAD_LOOP_STOP_ERR: %e", err)
return err
}
ctx.Log.Logf("thread", "THREAD_LOOP_DONE: %s", thread.ID())
ctx.Log.Logf("thread", "THREAD_LOOP_DONE: %s", thread.ID)
return nil
}
func ThreadChildLinked(ctx *Context, node ThreadNode, signal GraphSignal) (string, error) {
thread := node.ThreadHandle()
ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %+v", signal)
func ThreadChildLinked(ctx *Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) {
ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %s - %+v", thread.ID, signal)
context := NewWriteContext(ctx)
err := UpdateStates(context, node, NewLockMap(
NewLockInfo(node, []string{"children"}),
), func(context *StateContext) error {
err := UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
sig, ok := signal.(IDSignal)
if ok == false {
ctx.Log.Logf("thread", "THREAD_NODE_LINKED_BAD_CAST")
return nil
}
info, exists := thread.Children[sig.ID]
info, exists := thread_ext.Children[sig.ID]
if exists == false {
ctx.Log.Logf("thread", "THREAD_NODE_LINKED: %s is not a child of %s", sig.ID)
return nil
@ -502,7 +431,7 @@ func ThreadChildLinked(ctx *Context, node ThreadNode, signal GraphSignal) (strin
}
if parent_info.Start == true {
ChildGo(ctx, thread, info.Child, parent_info.StartAction)
ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction)
}
return nil
})
@ -517,28 +446,26 @@ func ThreadChildLinked(ctx *Context, node ThreadNode, signal GraphSignal) (strin
// Helper function to start a child from a thread during a signal handler
// Starts a write context, so cannot be called from either a write or read context
func ThreadStartChild(ctx *Context, node ThreadNode, signal GraphSignal) (string, error) {
func ThreadStartChild(ctx *Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) {
sig, ok := signal.(StartChildSignal)
if ok == false {
return "wait", nil
}
thread := node.ThreadHandle()
context := NewWriteContext(ctx)
return "wait", UpdateStates(context, node, NewLockInfo(node, []string{"children"}), func(context *StateContext) error {
info, exists:= thread.Children[sig.ID]
return "wait", UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
info, exists:= thread_ext.Children[sig.ID]
if exists == false {
return fmt.Errorf("%s is not a child of %s", sig.ID, thread.ID())
return fmt.Errorf("%s is not a child of %s", sig.ID, thread.ID)
}
return UpdateStates(context, node, NewLockInfo(info.Child, []string{"start"}), func(context *StateContext) error {
return UpdateStates(context, thread, NewACLInfo(info.Child, []string{"start"}), func(context *StateContext) error {
parent_info, exists := info.Infos["parent"].(*ParentThreadInfo)
if exists == false {
return fmt.Errorf("Called ThreadStartChild from a thread that doesn't require parent child info")
}
parent_info.Start = true
ChildGo(ctx, thread, info.Child, sig.Action)
ChildGo(ctx, thread_ext, info.Child, sig.Action)
return nil
})
@ -547,19 +474,23 @@ func ThreadStartChild(ctx *Context, node ThreadNode, signal GraphSignal) (string
// Helper function to restore threads that should be running from a parents restore action
// Starts a write context, so cannot be called from either a write or read context
func ThreadRestore(ctx * Context, node ThreadNode, start bool) error {
thread := node.ThreadHandle()
func ThreadRestore(ctx * Context, thread *Node, thread_ext *ThreadExt, start bool) error {
context := NewWriteContext(ctx)
return UpdateStates(context, node, NewLockInfo(node, []string{"children"}), func(context *StateContext) error {
return UpdateStates(context, node, LockList(thread.ChildList(), []string{"start"}), func(context *StateContext) error {
for _, info := range(thread.Children) {
return UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
return UpdateStates(context, thread, ACLList(thread_ext.ChildList(), []string{"start", "state"}), func(context *StateContext) error {
for _, info := range(thread_ext.Children) {
child_ext, err := GetExt[*ThreadExt](info.Child)
if err != nil {
return err
}
parent_info := info.Infos["parent"].(*ParentThreadInfo)
if parent_info.Start == true && info.Child.ThreadHandle().StateName != "finished" {
ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID(), info.Child.ID())
if parent_info.Start == true && child_ext.StateName != "finished" {
ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID, info.Child.ID)
if start == true {
ChildGo(ctx, thread, info.Child, parent_info.StartAction)
ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction)
} else {
ChildGo(ctx, thread, info.Child, parent_info.RestoreAction)
ChildGo(ctx, thread_ext, info.Child, parent_info.RestoreAction)
}
}
}
@ -571,73 +502,70 @@ func ThreadRestore(ctx * Context, node ThreadNode, start bool) error {
// Helper function to be called during a threads start action, sets the thread state to started
// Starts a write context, so cannot be called from either a write or read context
// Returns "wait", nil on success, so the first return value can be ignored safely
func ThreadStart(ctx * Context, node ThreadNode) (string, error) {
thread := node.ThreadHandle()
func ThreadStart(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
context := NewWriteContext(ctx)
err := UpdateStates(context, node, NewLockInfo(node, []string{"state"}), func(context *StateContext) error {
err := LockLockables(context, map[NodeID]LockableNode{node.ID(): node}, node)
err := UpdateStates(context, thread, NewACLInfo(thread, []string{"state"}), func(context *StateContext) error {
err := LockLockables(context, map[NodeID]*Node{thread.ID: thread}, thread)
if err != nil {
return err
}
return thread.SetState("started")
return thread_ext.SetState("started")
})
if err != nil {
return "", err
}
context = NewReadContext(ctx)
return "wait", Signal(context, node, node, NewStatusSignal("started", node.ID()))
return "wait", Signal(context, thread, thread, NewStatusSignal("started", thread.ID))
}
func ThreadWait(ctx * Context, node ThreadNode) (string, error) {
thread := node.ThreadHandle()
ctx.Log.Logf("thread", "THREAD_WAIT: %s - %+v", thread.ID(), thread.ActionQueue)
func ThreadWait(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
ctx.Log.Logf("thread", "THREAD_WAIT: %s - %+v", thread.ID, thread_ext.ActionQueue)
for {
select {
case signal := <- thread.Chan:
ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal)
signal_fn, exists := thread.Handlers[signal.Type()]
case signal := <- thread_ext.SignalChan:
ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID, signal)
signal_fn, exists := thread_ext.Handlers[signal.Type()]
if exists == true {
ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID(), signal.Type())
return signal_fn(ctx, node, signal)
ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID, signal.Type())
return signal_fn(ctx, thread, thread_ext, signal)
} else {
ctx.Log.Logf("thread", "THREAD_NOHANDLER: %s - %s", thread.ID(), signal.Type())
ctx.Log.Logf("thread", "THREAD_NOHANDLER: %s - %s", thread.ID, signal.Type())
}
case <- thread.TimeoutChan:
case <- thread_ext.TimeoutChan:
timeout_action := ""
context := NewWriteContext(ctx)
err := UpdateStates(context, node, NewLockMap(NewLockInfo(node, []string{"timeout"})), func(context *StateContext) error {
timeout_action = thread.NextAction.Action
thread.NextAction, thread.TimeoutChan = thread.SoonestAction()
err := UpdateStates(context, thread, NewACLMap(NewACLInfo(thread, []string{"timeout"})), func(context *StateContext) error {
timeout_action = thread_ext.NextAction.Action
thread_ext.NextAction, thread_ext.TimeoutChan = thread_ext.SoonestAction()
return nil
})
if err != nil {
ctx.Log.Logf("thread", "THREAD_TIMEOUT_ERR: %s - %e", thread.ID(), err)
ctx.Log.Logf("thread", "THREAD_TIMEOUT_ERR: %s - %e", thread.ID, err)
}
ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID(), timeout_action)
ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID, timeout_action)
return timeout_action, nil
}
}
}
func ThreadFinish(ctx *Context, node ThreadNode) (string, error) {
thread := node.ThreadHandle()
func ThreadFinish(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) {
context := NewWriteContext(ctx)
return "", UpdateStates(context, node, NewLockInfo(node, []string{"state"}), func(context *StateContext) error {
err := thread.SetState("finished")
return "", UpdateStates(context, thread, NewACLInfo(thread, []string{"state"}), func(context *StateContext) error {
err := thread_ext.SetState("finished")
if err != nil {
return err
}
return UnlockLockables(context, map[NodeID]LockableNode{node.ID(): node}, node)
return UnlockLockables(context, map[NodeID]*Node{thread.ID: thread}, thread)
})
}
var ThreadAbortedError = errors.New("Thread aborted by signal")
// Default thread action function for "abort", sends a signal and returns a ThreadAbortedError
func ThreadAbort(ctx * Context, node ThreadNode, signal GraphSignal) (string, error) {
func ThreadAbort(ctx * Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) {
context := NewReadContext(ctx)
err := Signal(context, node, node, NewStatusSignal("aborted", node.ID()))
err := Signal(context, thread, thread, NewStatusSignal("aborted", thread.ID))
if err != nil {
return "", err
}
@ -645,9 +573,9 @@ func ThreadAbort(ctx * Context, node ThreadNode, signal GraphSignal) (string, er
}
// Default thread action for "stop", sends a signal and returns no error
func ThreadStop(ctx * Context, node ThreadNode, signal GraphSignal) (string, error) {
func ThreadStop(ctx * Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) {
context := NewReadContext(ctx)
err := Signal(context, node, node, NewStatusSignal("stopped", node.ID()))
err := Signal(context, thread, thread, NewStatusSignal("stopped", thread.ID))
return "finish", err
}