Switched from thread being the callback engine to node being the callback engine

gql_cataclysm
noah metz 2023-07-27 15:27:14 -06:00
parent 7a7a9c95a3
commit 3ad969a5ca
16 changed files with 316 additions and 1171 deletions

@ -135,18 +135,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err return nil, err
} }
thread_ctx := NewThreadExtContext() err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType})
err = ctx.RegisterExtension(ThreadExtType, LoadThreadExt, thread_ctx)
if err != nil {
return nil, err
}
err = thread_ctx.RegisterThreadType(GQLThreadType, gql_actions, gql_handlers)
if err != nil {
return nil, err
}
err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType, ThreadExtType})
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -30,7 +30,6 @@ import (
"encoding/pem" "encoding/pem"
) )
const GQLThreadType = ThreadType("GQL")
const GQLNodeType = NodeType("GQL") const GQLNodeType = NodeType("GQL")
type AuthReqJSON struct { type AuthReqJSON struct {
@ -760,7 +759,7 @@ func NewGQLExtContext() *GQLExtContext {
Fields: graphql.Fields{}, Fields: graphql.Fields{},
}) })
mutation.AddFieldConfig("abort", GQLMutationAbort) mutation.AddFieldConfig("stop", GQLMutationStop)
mutation.AddFieldConfig("startChild", GQLMutationStartChild) mutation.AddFieldConfig("startChild", GQLMutationStartChild)
subscription := graphql.NewObject(graphql.ObjectConfig{ subscription := graphql.NewObject(graphql.ObjectConfig{
@ -790,10 +789,6 @@ func NewGQLExtContext() *GQLExtContext {
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = context.AddInterface(GQLInterfaceThread)
if err != nil {
panic(err)
}
schema, err := BuildSchema(&context) schema, err := BuildSchema(&context)
if err != nil { if err != nil {
@ -829,7 +824,7 @@ func (ext *GQLExt) NewSubscriptionChannel(buffer int) chan Signal {
return new_listener return new_listener
} }
func (ext *GQLExt) Process(context *StateContext, node *Node, signal Signal) error { func (ext *GQLExt) Process(context *Context, princ_id NodeID, node *Node, signal Signal) {
ext.SubscribeLock.Lock() ext.SubscribeLock.Lock()
defer ext.SubscribeLock.Unlock() defer ext.SubscribeLock.Unlock()
@ -846,7 +841,7 @@ func (ext *GQLExt) Process(context *StateContext, node *Node, signal Signal) err
} }
} }
ext.SubscribeListeners = active_listeners ext.SubscribeListeners = active_listeners
return nil return
} }
const GQLExtType = ExtType("gql_thread") const GQLExtType = ExtType("gql_thread")
@ -963,28 +958,11 @@ func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_
} }
} }
var gql_actions ThreadActions = ThreadActions{ func StartGQLServer(ctx *Context, node *Node, gql_ext *GQLExt) error {
"wait": ThreadWait,
"restore": func(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) {
return "start_server", ThreadRestore(ctx, thread, thread_ext, false)
},
"start": func(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
_, err := ThreadStart(ctx, thread, thread_ext)
if err != nil {
return "", err
}
return "start_server", ThreadRestore(ctx, thread, thread_ext, true)
},
"start_server": func(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
gql_ext, err := GetExt[*GQLExt](thread)
if err != nil {
return "", err
}
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/auth", AuthHandler(ctx, thread, gql_ext)) mux.HandleFunc("/auth", AuthHandler(ctx, node, gql_ext))
mux.HandleFunc("/gql", GQLHandler(ctx, thread, gql_ext)) mux.HandleFunc("/gql", GQLHandler(ctx, node, gql_ext))
mux.HandleFunc("/gqlws", GQLWSHandler(ctx, thread, gql_ext)) mux.HandleFunc("/gqlws", GQLWSHandler(ctx, node, gql_ext))
// Server a graphiql interface(TODO make configurable whether to start this) // Server a graphiql interface(TODO make configurable whether to start this)
mux.HandleFunc("/graphiql", GraphiQLHandler()) mux.HandleFunc("/graphiql", GraphiQLHandler())
@ -1000,12 +978,12 @@ var gql_actions ThreadActions = ThreadActions{
l, err := net.Listen("tcp", http_server.Addr) l, err := net.Listen("tcp", http_server.Addr)
if err != nil { if err != nil {
return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr) return fmt.Errorf("Failed to start listener for server on %s", http_server.Addr)
} }
cert, err := tls.X509KeyPair(gql_ext.tls_cert, gql_ext.tls_key) cert, err := tls.X509KeyPair(gql_ext.tls_cert, gql_ext.tls_key)
if err != nil { if err != nil {
return "", err return err
} }
config := tls.Config{ config := tls.Config{
@ -1026,41 +1004,12 @@ var gql_actions ThreadActions = ThreadActions{
}(gql_ext) }(gql_ext)
context := NewWriteContext(ctx)
err = UpdateStates(context, thread, NewACLInfo(thread, []string{"http_server"}), func(context *StateContext) error {
gql_ext.tcp_listener = listener gql_ext.tcp_listener = listener
gql_ext.http_server = http_server gql_ext.http_server = http_server
return nil return nil
})
if err != nil {
return "", err
}
context = NewReadContext(ctx)
err = thread.Process(context, thread.ID, NewStatusSignal("server_started", thread.ID))
if err != nil {
return "", err
}
return "wait", nil
},
"finish": func(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) {
gql_ext, err := GetExt[*GQLExt](thread)
if err != nil {
return "", err
} }
func StopGQLServer(gql_ext *GQLExt) {
gql_ext.http_server.Shutdown(context.TODO()) gql_ext.http_server.Shutdown(context.TODO())
gql_ext.http_done.Wait() gql_ext.http_done.Wait()
return ThreadFinish(ctx, thread, thread_ext)
},
} }
var gql_handlers ThreadHandlers = ThreadHandlers{
"child_linked": ThreadChildLinked,
"start_child": ThreadStartChild,
"abort": ThreadAbort,
"stop": ThreadStop,
}

@ -55,22 +55,6 @@ func addLockableInterfaceFields(gql *GQLInterface, gql_lockable *GQLInterface) {
}) })
} }
func AddThreadInterfaceFields(gql *GQLInterface) {
addThreadInterfaceFields(gql, GQLInterfaceThread)
}
func addThreadInterfaceFields(gql *GQLInterface, gql_thread *GQLInterface) {
AddNodeInterfaceFields(gql)
gql.Interface.AddFieldConfig("Children", &graphql.Field{
Type: gql_thread.List,
})
gql.Interface.AddFieldConfig("Parent", &graphql.Field{
Type: gql_thread.Interface,
})
}
func NodeHasExtensions(node *Node, extensions []ExtType) bool { func NodeHasExtensions(node *Node, extensions []ExtType) bool {
if node == nil { if node == nil {
return false return false
@ -136,8 +120,3 @@ var GQLInterfaceLockable = NewGQLInterface("Lockable", "DefaultLockable", []*gra
addLockableFields(gql.Default, gql.Interface, gql.List) addLockableFields(gql.Default, gql.Interface, gql.List)
}) })
var GQLInterfaceThread = NewGQLInterface("Thread", "DefaultThread", []*graphql.Interface{GQLInterfaceNode.Interface, }, []ExtType{ThreadExtType, LockableExtType}, func(gql *GQLInterface){
addThreadInterfaceFields(gql, gql)
}, func(gql *GQLInterface) {
addThreadFields(gql.Default, gql.Interface, gql.List)
})

@ -1,11 +1,10 @@
package graphvent package graphvent
import ( import (
"fmt"
"github.com/graphql-go/graphql" "github.com/graphql-go/graphql"
) )
var GQLMutationAbort = NewField(func()*graphql.Field { var GQLMutationStop = NewField(func()*graphql.Field {
gql_mutation_abort := &graphql.Field{ gql_mutation_stop := &graphql.Field{
Type: GQLTypeSignal.Type, Type: GQLTypeSignal.Type,
Args: graphql.FieldConfigArgument{ Args: graphql.FieldConfigArgument{
"id": &graphql.ArgumentConfig{ "id": &graphql.ArgumentConfig{
@ -13,39 +12,11 @@ var GQLMutationAbort = NewField(func()*graphql.Field {
}, },
}, },
Resolve: func(p graphql.ResolveParams) (interface{}, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) {
_, ctx, err := PrepResolve(p) return StopSignal, nil
if err != nil {
return nil, err
}
id, err := ExtractID(p, "id")
if err != nil {
return nil, err
}
var node *Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLMap(
NewACLInfo(ctx.Server, []string{"children"}),
), func(context *StateContext) (error){
node, err = FindChild(context, ctx.User, ctx.Server, id)
if err != nil {
return err
}
if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
}
return node.Process(context, ctx.User.ID, AbortSignal)
})
if err != nil {
return nil, err
}
return AbortSignal, nil
}, },
} }
return gql_mutation_abort return gql_mutation_stop
}) })
var GQLMutationStartChild = NewField(func()*graphql.Field{ var GQLMutationStartChild = NewField(func()*graphql.Field{
@ -64,7 +35,7 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{
}, },
}, },
Resolve: func(p graphql.ResolveParams) (interface{}, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) {
_, ctx, err := PrepResolve(p) /*_, ctx, err := PrepResolve(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -102,10 +73,10 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }*/
// TODO: wait for the result of the signal to send back instead of just the signal // TODO: wait for the result of the signal to send back instead of just the signal
return signal, nil return nil, nil
}, },
} }

@ -4,7 +4,7 @@ import (
) )
var GQLQuerySelf = &graphql.Field{ var GQLQuerySelf = &graphql.Field{
Type: GQLInterfaceThread.Default, Type: GQLInterfaceNode.Default,
Resolve: func(p graphql.ResolveParams) (interface{}, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) {
_, ctx, err := PrepResolve(p) _, ctx, err := PrepResolve(p)
if err != nil { if err != nil {

@ -89,218 +89,36 @@ func GQLNodeTypeHash(p graphql.ResolveParams) (interface{}, error) {
} }
func GQLNodeListen(p graphql.ResolveParams) (interface{}, error) { func GQLNodeListen(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p) // TODO figure out how nodes can read eachother
if err != nil { return "", nil
return nil, err
}
gql_ext, err := GetExt[*GQLExt](node)
if err != nil {
return nil, err
}
listen := ""
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"listen"}), func(context *StateContext) error {
listen = gql_ext.Listen
return nil
})
if err != nil {
return nil, err
}
return listen, nil
} }
func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) { func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p) return nil, nil
if err != nil {
return nil, err
}
thread_ext, err := GetExt[*ThreadExt](node)
if err != nil {
return nil, err
}
var parent *Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error {
parent = thread_ext.Parent
return nil
})
if err != nil {
return nil, err
}
return parent, nil
} }
func GQLThreadState(p graphql.ResolveParams) (interface{}, error) { func GQLThreadState(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p) return "", nil
if err != nil {
return nil, err
}
thread_ext, err := GetExt[*ThreadExt](node)
if err != nil {
return nil, err
}
var state string
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"state"}), func(context *StateContext) error {
state = thread_ext.State
return nil
})
if err != nil {
return nil, err
}
return state, nil
} }
func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) { func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p) return nil, nil
if err != nil {
return nil, err
}
thread_ext, err := GetExt[*ThreadExt](node)
if err != nil {
return nil, err
}
var children []*Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"children"}), func(context *StateContext) error {
children = thread_ext.ChildList()
return nil
})
if err != nil {
return nil, err
}
return children, nil
} }
func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) { func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p) return nil, nil
if err != nil {
return nil, err
}
lockable_ext, err := GetExt[*LockableExt](node)
if err != nil {
return nil, err
}
var requirements []*Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error {
requirements = make([]*Node, len(lockable_ext.Requirements))
i := 0
for _, req := range(lockable_ext.Requirements) {
requirements[i] = req
i += 1
}
return nil
})
if err != nil {
return nil, err
}
return requirements, nil
} }
func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) { func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p) return nil, nil
if err != nil {
return nil, err
}
lockable_ext, err := GetExt[*LockableExt](node)
if err != nil {
return nil, err
}
var dependencies []*Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"dependencies"}), func(context *StateContext) error {
dependencies = make([]*Node, len(lockable_ext.Dependencies))
i := 0
for _, dep := range(lockable_ext.Dependencies) {
dependencies[i] = dep
i += 1
}
return nil
})
if err != nil {
return nil, err
}
return dependencies, nil
} }
func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) { func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p) return nil, nil
if err != nil {
return nil, err
}
lockable_ext, err := GetExt[*LockableExt](node)
if err != nil {
return nil, err
}
var owner *Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"owner"}), func(context *StateContext) error {
owner = lockable_ext.Owner
return nil
})
if err != nil {
return nil, err
}
return owner, nil
} }
func GQLGroupMembers(p graphql.ResolveParams) (interface{}, error) { func GQLGroupMembers(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p) return nil, nil
if err != nil {
return nil, err
}
group_ext, err := GetExt[*GroupExt](node)
if err != nil {
return nil, err
}
var members []*Node
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"users"}), func(context *StateContext) error {
members = make([]*Node, len(group_ext.Members))
i := 0
for _, member := range(group_ext.Members) {
members[i] = member
i += 1
}
return nil
})
if err != nil {
return nil, err
}
return members, nil
} }
func GQLSignalFn(p graphql.ResolveParams, fn func(Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { func GQLSignalFn(p graphql.ResolveParams, fn func(Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) {

@ -46,7 +46,7 @@ func GQLSubscribeFn(p graphql.ResolveParams, send_nil bool, fn func(*Context, *N
var GQLSubscriptionSelf = NewField(func()*graphql.Field{ var GQLSubscriptionSelf = NewField(func()*graphql.Field{
gql_subscription_self := &graphql.Field{ gql_subscription_self := &graphql.Field{
Type: GQLInterfaceThread.Default, Type: GQLInterfaceNode.Default,
Resolve: func(p graphql.ResolveParams) (interface{}, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) {
return p.Source, nil return p.Source, nil
}, },

@ -3,114 +3,51 @@ package graphvent
import ( import (
"testing" "testing"
"time" "time"
"errors"
"crypto/rand" "crypto/rand"
"crypto/ecdh" "crypto/ecdh"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
) )
func TestGQL(t *testing.T) {
}
func TestGQLDB(t * testing.T) { func TestGQLDB(t * testing.T) {
ctx := logTestContext(t, []string{"thread", "test", "signal", "policy", "db"}) ctx := logTestContext(t, []string{"loop", "node", "thread", "test", "signal", "policy", "db"})
TestUserNodeType := NodeType("TEST_USER") TestUserNodeType := NodeType("TEST_USER")
err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{})
fatalErr(t, err) fatalErr(t, err)
u1 := NewNode(ctx, RandID(), TestUserNodeType) u1 := NewNode(ctx, RandID(), TestUserNodeType, nil)
ctx.Log.Logf("test", "U1_ID: %s", u1.ID) ctx.Log.Logf("test", "U1_ID: %s", u1.ID)
TestThreadNodeType := NodeType("TEST_THREAD")
err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLExtType, ThreadExtType})
fatalErr(t, err)
t1_p1 := NewParentOfPolicy(Actions{"signal.abort", "signal.stop", "state.write"})
t1_p2 := NewPerNodePolicy(NodeActions{
u1.ID: Actions{"parent.write"},
})
t1_thread, err := NewThreadExt(ctx, BaseThreadType, nil,nil, "init", nil)
fatalErr(t, err)
t1 := NewNode(ctx,
RandID(),
TestThreadNodeType,
NewACLExt(&t1_p1, &t1_p2),
t1_thread)
ctx.Log.Logf("test", "T1_ID: %s", t1.ID)
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err) fatalErr(t, err)
gql_p1 := NewChildOfPolicy(Actions{"signal.status"})
gql_p2 := NewPerNodePolicy(NodeActions{
u1.ID: Actions{"children.write", "dependencies.write"},
})
gql_thread, err := NewThreadExt(ctx, GQLThreadType, nil, nil, "init", nil)
fatalErr(t, err)
gql_ext := NewGQLExt(":0", ecdh.P256(), key, nil, nil) gql_ext := NewGQLExt(":0", ecdh.P256(), key, nil, nil)
listener_ext := NewListenerExt(10) listener_ext := NewListenerExt(10)
gql := NewNode(ctx, RandID(), GQLNodeType, gql := NewNode(ctx, RandID(), GQLNodeType, nil,
gql_thread,
gql_ext, gql_ext,
listener_ext, listener_ext,
NewACLExt(&gql_p1, &gql_p2), NewACLExt(),
NewGroupExt(nil)) NewGroupExt(nil))
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
info := ParentInfo{true, "start", "start"} err = gql.Signal(ctx, gql.ID, StopSignal)
context := NewWriteContext(ctx)
err = UpdateStates(context, u1, ACLMap{}, func(context *StateContext) error {
return LinkThreads(context, u1, gql, ChildInfo{t1, map[InfoType]Info{
ParentInfoType: &info,
}})
})
fatalErr(t, err) fatalErr(t, err)
context = NewReadContext(ctx) (*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on listener")
err = gql.Process(context, gql.ID, NewStatusSignal("child_linked", t1.ID))
fatalErr(t, err)
context = NewReadContext(ctx)
err = gql.Process(context, gql.ID, AbortSignal)
fatalErr(t, err)
err = ThreadLoop(ctx, gql, "start")
if errors.Is(err, ThreadAbortedError) == false {
fatalErr(t, err)
}
(*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "aborted", 100*time.Millisecond, "Didn't receive aborted on listener")
context = NewReadContext(ctx)
err = UseStates(context, gql, ACLList([]*Node{gql, u1}, nil), func(context *StateContext) error {
ser1, err := gql.Serialize() ser1, err := gql.Serialize()
ser2, err := u1.Serialize() ser2, err := u1.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser1) ctx.Log.Logf("test", "\n%s\n\n", ser1)
ctx.Log.Logf("test", "\n%s\n\n", ser2) ctx.Log.Logf("test", "\n%s\n\n", ser2)
return err
})
// Clear all loaded nodes from the context so it loads them from the database // Clear all loaded nodes from the context so it loads them from the database
ctx.Nodes = NodeMap{} ctx.Nodes = NodeMap{}
gql_loaded, err := LoadNode(ctx, gql.ID) gql_loaded, err := LoadNode(ctx, gql.ID)
fatalErr(t, err) fatalErr(t, err)
context = NewReadContext(ctx)
err = UseStates(context, gql_loaded, NewACLInfo(gql_loaded, []string{"users", "children", "requirements"}), func(context *StateContext) error {
var err error
listener_ext, err = GetExt[*ListenerExt](gql_loaded) listener_ext, err = GetExt[*ListenerExt](gql_loaded)
if err != nil {
return err
}
return gql_loaded.Process(context, gql_loaded.ID, StopSignal)
})
fatalErr(t, err) fatalErr(t, err)
err = gql_loaded.Signal(ctx, gql_loaded.ID, StopSignal)
err = ThreadLoop(ctx, gql_loaded, "restore")
fatalErr(t, err) fatalErr(t, err)
(*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2") (*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2")

@ -38,35 +38,11 @@ func addLockableFields(object *graphql.Object, lockable_interface *graphql.Inter
}) })
} }
func AddThreadFields(object *graphql.Object) {
addThreadFields(object, GQLInterfaceThread.Interface, GQLInterfaceThread.List)
}
func addThreadFields(object *graphql.Object, thread_interface *graphql.Interface, thread_list *graphql.List) {
AddNodeFields(object)
object.AddFieldConfig("State", &graphql.Field{
Type: graphql.String,
Resolve: GQLThreadState,
})
object.AddFieldConfig("Children", &graphql.Field{
Type: thread_list,
Resolve: GQLThreadChildren,
})
object.AddFieldConfig("Parent", &graphql.Field{
Type: thread_interface,
Resolve: GQLThreadParent,
})
}
var GQLNodeInterfaces = []*graphql.Interface{GQLInterfaceNode.Interface} var GQLNodeInterfaces = []*graphql.Interface{GQLInterfaceNode.Interface}
var GQLLockableInterfaces = append(GQLNodeInterfaces, GQLInterfaceLockable.Interface) var GQLLockableInterfaces = append(GQLNodeInterfaces, GQLInterfaceLockable.Interface)
var GQLThreadInterfaces = append(GQLNodeInterfaces, GQLInterfaceThread.Interface)
var GQLTypeGQLNode = NewGQLNodeType(GQLNodeType, GQLThreadInterfaces, func(gql *GQLType) { var GQLTypeGQLNode = NewGQLNodeType(GQLNodeType, GQLNodeInterfaces, func(gql *GQLType) {
AddThreadFields(gql.Type) AddNodeFields(gql.Type)
gql.Type.AddFieldConfig("Listen", &graphql.Field{ gql.Type.AddFieldConfig("Listen", &graphql.Field{
Type: graphql.String, Type: graphql.String,

@ -60,6 +60,7 @@ func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) {
listener := NewNode(ctx, listener := NewNode(ctx,
RandID(), RandID(),
SimpleListenerNodeType, SimpleListenerNodeType,
nil,
listener_extension, listener_extension,
NewACLExt(&policy), NewACLExt(&policy),
NewLockableExt(nil, nil, nil, nil)) NewLockableExt(nil, nil, nil, nil))

@ -1,7 +1,6 @@
package graphvent package graphvent
import ( import (
"fmt"
"encoding/json" "encoding/json"
) )
@ -32,14 +31,14 @@ func (listener *ListenerExt) Type() ExtType {
return ListenerExtType return ListenerExtType
} }
func (ext *ListenerExt) Process(context *StateContext, node *Node, signal Signal) error { func (ext *ListenerExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) {
context.Graph.Log.Logf("signal", "LISTENER_PROCESS: %s - %+v", node.ID, signal) ctx.Log.Logf("signal", "LISTENER_PROCESS: %s - %+v", node.ID, signal)
select { select {
case ext.Chan <- signal: case ext.Chan <- signal:
default: default:
return fmt.Errorf("LISTENER_OVERFLOW - %+v", signal) ctx.Log.Logf("listener", "LISTENER_OVERFLOW: %s", node.ID)
} }
return nil return
} }
func (ext *ListenerExt) Serialize() ([]byte, error) { func (ext *ListenerExt) Serialize() ([]byte, error) {
@ -47,10 +46,10 @@ func (ext *ListenerExt) Serialize() ([]byte, error) {
} }
type LockableExt struct { type LockableExt struct {
Owner *Node Owner *NodeID `json:"owner"`
Requirements map[NodeID]*Node Requirements []NodeID `json:"requirements"`
Dependencies map[NodeID]*Node Dependencies []NodeID `json:"dependencies"`
LocksHeld map[NodeID]*Node LocksHeld map[NodeID]*NodeID `json:"locks_held"`
} }
const LockableExtType = ExtType("LOCKABLE") const LockableExtType = ExtType("LOCKABLE")
@ -58,33 +57,13 @@ func (ext *LockableExt) Type() ExtType {
return LockableExtType return LockableExtType
} }
type LockableExtJSON struct {
Owner string `json:"owner"`
Requirements []string `json:"requirements"`
Dependencies []string `json:"dependencies"`
LocksHeld map[string]string `json:"locks_held"`
}
func (ext *LockableExt) Serialize() ([]byte, error) { func (ext *LockableExt) Serialize() ([]byte, error) {
return json.MarshalIndent(&LockableExtJSON{ return json.MarshalIndent(ext, "", " ")
Owner: SaveNode(ext.Owner),
Requirements: SaveNodeList(ext.Requirements),
Dependencies: SaveNodeList(ext.Dependencies),
LocksHeld: SaveNodeMap(ext.LocksHeld),
}, "", " ")
}
func NewLockableExt(owner *Node, requirements NodeMap, dependencies NodeMap, locks_held NodeMap) *LockableExt {
if requirements == nil {
requirements = NodeMap{}
}
if dependencies == nil {
dependencies = NodeMap{}
} }
func NewLockableExt(owner *NodeID, requirements []NodeID, dependencies []NodeID, locks_held map[NodeID]*NodeID) *LockableExt {
if locks_held == nil { if locks_held == nil {
locks_held = NodeMap{} locks_held = map[NodeID]*NodeID{}
} }
return &LockableExt{ return &LockableExt{
@ -96,412 +75,71 @@ func NewLockableExt(owner *Node, requirements NodeMap, dependencies NodeMap, loc
} }
func LoadLockableExt(ctx *Context, data []byte) (Extension, error) { func LoadLockableExt(ctx *Context, data []byte) (Extension, error) {
var j LockableExtJSON var ext LockableExt
err := json.Unmarshal(data, &j) err := json.Unmarshal(data, &ext)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctx.Log.Logf("db", "DB_LOADING_LOCKABLE_EXT_JSON: %+v", j) ctx.Log.Logf("db", "DB_LOADING_LOCKABLE_EXT_JSON: %+v", ext)
owner, err := RestoreNode(ctx, j.Owner)
if err != nil {
return nil, err
}
requirements, err := RestoreNodeList(ctx, j.Requirements) return &ext, nil
if err != nil {
return nil, err
} }
dependencies, err := RestoreNodeList(ctx, j.Dependencies) func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) {
if err != nil { ctx.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID)
return nil, err
}
locks_held, err := RestoreNodeMap(ctx, j.LocksHeld)
if err != nil {
return nil, err
}
return NewLockableExt(owner, requirements, dependencies, locks_held), nil
}
func (ext *LockableExt) Process(context *StateContext, node *Node, signal Signal) error {
context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID)
switch signal.Direction() { switch signal.Direction() {
case Up: case Up:
owner_sent := false owner_sent := false
for _, dependency := range(ext.Dependencies) { for _, dependency := range(ext.Dependencies) {
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", node.ID, dependency.ID) err := node.Signal(ctx, dependency, signal)
err := dependency.Process(context, node.ID, signal)
if err != nil { if err != nil {
return err ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, dependency, err)
} }
if ext.Owner != nil { if ext.Owner != nil {
if dependency.ID == ext.Owner.ID { if dependency == *ext.Owner {
owner_sent = true owner_sent = true
} }
} }
} }
if ext.Owner != nil && owner_sent == false { if ext.Owner != nil && owner_sent == false {
if ext.Owner.ID != node.ID { if *ext.Owner != node.ID {
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", node.ID, ext.Owner.ID) err := node.Signal(ctx, *ext.Owner, signal)
err := ext.Owner.Process(context, node.ID, signal)
if err != nil { if err != nil {
return err ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, *ext.Owner, err)
} }
} }
} }
case Down: case Down:
for _, requirement := range(ext.Requirements) { for _, requirement := range(ext.Requirements) {
err := requirement.Process(context, node.ID, signal) err := node.Signal(ctx, requirement, signal)
if err != nil { if err != nil {
return err ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, requirement, err)
} }
} }
case Direct: case Direct:
default: default:
return fmt.Errorf("invalid signal direction %d", signal.Direction())
} }
return nil
} }
func (ext *LockableExt) RecordUnlock(node *Node) *Node { func (ext *LockableExt) RecordUnlock(node NodeID) *NodeID {
last_owner, exists := ext.LocksHeld[node.ID] last_owner, exists := ext.LocksHeld[node]
if exists == false { if exists == false {
panic("Attempted to take a get the original lock holder of a lockable we don't own") panic("Attempted to take a get the original lock holder of a lockable we don't own")
} }
delete(ext.LocksHeld, node.ID) delete(ext.LocksHeld, node)
return last_owner return last_owner
} }
func (ext *LockableExt) RecordLock(node *Node, last_owner *Node) { func (ext *LockableExt) RecordLock(node NodeID, last_owner *NodeID) {
_, exists := ext.LocksHeld[node.ID] _, exists := ext.LocksHeld[node]
if exists == true { if exists == true {
panic("Attempted to lock a lockable we're already holding(lock cycle)") panic("Attempted to lock a lockable we're already holding(lock cycle)")
} }
ext.LocksHeld[node.ID] = last_owner ext.LocksHeld[node] = last_owner
}
// Removes requirement as a requirement from lockable
func UnlinkLockables(context *StateContext, princ *Node, lockable *Node, requirement *Node) error {
lockable_ext, err := GetExt[*LockableExt](lockable)
if err != nil {
return err
}
requirement_ext, err := GetExt[*LockableExt](requirement)
if err != nil {
return err
}
return UpdateStates(context, princ, ACLMap{
lockable.ID: ACLInfo{Node: lockable, Resources: []string{"requirements"}},
requirement.ID: ACLInfo{Node: requirement, Resources: []string{"dependencies"}},
}, func(context *StateContext) error {
var found *Node = nil
for _, req := range(lockable_ext.Requirements) {
if requirement.ID == req.ID {
found = req
break
}
}
if found == nil {
return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID, lockable.ID)
}
delete(requirement_ext.Dependencies, lockable.ID)
delete(lockable_ext.Requirements, requirement.ID)
return nil
})
}
// Link requirements as requirements to lockable
func LinkLockables(context *StateContext, princ *Node, lockable *Node, requirements []*Node) error {
if lockable == nil {
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements")
}
if len(requirements) == 0 {
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link no lockables in call")
}
lockable_ext, err := GetExt[*LockableExt](lockable)
if err != nil {
return err
}
req_exts := map[NodeID]*LockableExt{}
for _, requirement := range(requirements) {
if requirement == nil {
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link nil to a Lockable as a requirement")
}
if lockable.ID == requirement.ID {
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s to itself", lockable.ID)
}
_, exists := req_exts[requirement.ID]
if exists == true {
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s twice", requirement.ID)
}
ext, err := GetExt[*LockableExt](requirement)
if err != nil {
return err
}
req_exts[requirement.ID] = ext
}
return UpdateStates(context, princ, NewACLMap(
NewACLInfo(lockable, []string{"requirements"}),
ACLList(requirements, []string{"dependencies"}),
), func(context *StateContext) error {
// Check that all the requirements can be added
// If the lockable is already locked, need to lock this resource as well before we can add it
for _, requirement := range(requirements) {
requirement_ext := req_exts[requirement.ID]
for _, req := range(requirements) {
if req.ID == requirement.ID {
continue
}
is_req, err := checkIfRequirement(context, req.ID, requirement_ext)
if err != nil {
return err
} else if is_req {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot add the same dependency", req.ID, requirement.ID)
}
}
is_req, err := checkIfRequirement(context, lockable.ID, requirement_ext)
if err != nil {
return err
} else if is_req {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID, lockable.ID)
}
is_req, err = checkIfRequirement(context, requirement.ID, lockable_ext)
if err != nil {
return err
} else if is_req {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID, requirement.ID)
}
if lockable_ext.Owner == nil {
// If the new owner isn't locked, we can add the requirement
} else if requirement_ext.Owner == nil {
// if the new requirement isn't already locked but the owner is, the requirement needs to be locked first
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID, requirement.ID)
} else {
// If the new requirement is already locked and the owner is already locked, their owners need to match
if requirement_ext.Owner.ID != lockable_ext.Owner.ID {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID, lockable.ID)
}
}
}
// Update the states of the requirements
for _, requirement := range(requirements) {
requirement_ext := req_exts[requirement.ID]
requirement_ext.Dependencies[lockable.ID] = lockable
lockable_ext.Requirements[lockable.ID] = requirement
context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID, lockable.ID)
}
// Return no error
return nil
})
}
func checkIfRequirement(context *StateContext, id NodeID, cur *LockableExt) (bool, error) {
for _, req := range(cur.Requirements) {
if req.ID == id {
return true, nil
}
req_ext, err := GetExt[*LockableExt](req)
if err != nil {
return false, err
}
var is_req bool
err = UpdateStates(context, req, NewACLInfo(req, []string{"requirements"}), func(context *StateContext) error {
is_req, err = checkIfRequirement(context, id, req_ext)
return err
})
if err != nil {
return false, err
}
if is_req == true {
return true, nil
}
}
return false, nil
}
// Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error
// Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call
func LockLockables(context *StateContext, to_lock NodeMap, new_owner *Node) error {
if to_lock == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: no map provided")
}
req_exts := map[NodeID]*LockableExt{}
for _, l := range(to_lock) {
var err error
if l == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: Can not lock nil")
}
req_exts[l.ID], err = GetExt[*LockableExt](l)
if err != nil {
return err
}
}
if new_owner == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks")
}
new_owner_ext, err := GetExt[*LockableExt](new_owner)
if err != nil {
return err
}
// Called with no requirements to lock, success
if len(to_lock) == 0 {
return nil
}
return UpdateStates(context, new_owner, NewACLMap(
ACLListM(to_lock, []string{"lock"}),
NewACLInfo(new_owner, nil),
), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements
for _, req := range(to_lock) {
req_ext := req_exts[req.ID]
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID, new_owner.ID)
// If req is alreay locked, check that we can pass the lock
if req_ext.Owner != nil {
owner := req_ext.Owner
if owner.ID == new_owner.ID {
continue
} else {
err := UpdateStates(context, new_owner, NewACLInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){
return LockLockables(context, req_ext.Requirements, req)
})
if err != nil {
return err
}
}
} else {
err := LockLockables(context, req_ext.Requirements, req)
if err != nil {
return err
}
}
}
// At this point state modification will be started, so no errors can be returned
for _, req := range(to_lock) {
req_ext := req_exts[req.ID]
old_owner := req_ext.Owner
// If the lockable was previously unowned, update the state
if old_owner == nil {
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID, req.ID)
req_ext.Owner = new_owner
new_owner_ext.RecordLock(req, old_owner)
// Otherwise if the new owner already owns it, no need to update state
} else if old_owner.ID == new_owner.ID {
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID, req.ID)
// Otherwise update the state
} else {
req_ext.Owner = new_owner
new_owner_ext.RecordLock(req, old_owner)
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID, req.ID, old_owner.ID)
}
}
return nil
})
}
func UnlockLockables(context *StateContext, to_unlock NodeMap, old_owner *Node) error {
if to_unlock == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided")
}
req_exts := map[NodeID]*LockableExt{}
for _, l := range(to_unlock) {
if l == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: Can not unlock nil")
}
var err error
req_exts[l.ID], err = GetExt[*LockableExt](l)
if err != nil {
return err
}
}
if old_owner == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks")
}
old_owner_ext, err := GetExt[*LockableExt](old_owner)
if err != nil {
return err
}
// Called with no requirements to unlock, success
if len(to_unlock) == 0 {
return nil
}
return UpdateStates(context, old_owner, NewACLMap(
ACLListM(to_unlock, []string{"lock"}),
NewACLInfo(old_owner, nil),
), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements
for _, req := range(to_unlock) {
req_ext := req_exts[req.ID]
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID, old_owner.ID)
// Check if the owner is correct
if req_ext.Owner != nil {
if req_ext.Owner.ID != old_owner.ID {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID, old_owner.ID)
}
} else {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID)
}
err := UnlockLockables(context, req_ext.Requirements, req)
if err != nil {
return err
}
}
// At this point state modification will be started, so no errors can be returned
for _, req := range(to_unlock) {
req_ext := req_exts[req.ID]
new_owner := old_owner_ext.RecordUnlock(req)
req_ext.Owner = new_owner
if new_owner == nil {
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID, req.ID)
} else {
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID, req.ID, new_owner.ID)
}
}
return nil
})
} }
func SaveNode(node *Node) string { func SaveNode(node *Node) string {

@ -2,9 +2,11 @@ package graphvent
import ( import (
"sync" "sync"
"time"
"reflect" "reflect"
"github.com/google/uuid" "github.com/google/uuid"
badger "github.com/dgraph-io/badger/v3" badger "github.com/dgraph-io/badger/v3"
"runtime"
"fmt" "fmt"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
@ -20,6 +22,17 @@ func (id NodeID) MarshalJSON() ([]byte, error) {
return json.Marshal(&str) return json.Marshal(&str)
} }
func (id *NodeID) UnmarshalJSON(bytes []byte) error {
var id_str string
err := json.Unmarshal(bytes, &id_str)
if err != nil {
return err
}
*id, err = ParseID(id_str)
return err
}
var ZeroUUID = uuid.UUID{} var ZeroUUID = uuid.UUID{}
var ZeroID = NodeID(ZeroUUID) var ZeroID = NodeID(ZeroUUID)
@ -62,20 +75,132 @@ type Serializable[I comparable] interface {
Serialize() ([]byte, error) Serialize() ([]byte, error)
} }
// NodeExtensions are additional data that can be attached to nodes, and used in node functions
type Extension interface { type Extension interface {
Serializable[ExtType] Serializable[ExtType]
// Send a signal to this extension to process, Process(context *Context, source NodeID, node *Node, signal Signal)
// this typically triggers signals to be sent to nodes linked in the extension }
Process(context *StateContext, node *Node, signal Signal) error
type QueuedSignal struct {
Signal Signal
Time time.Time
} }
const NODE_MSG_CHAN_DEFAULT = 1024
// Nodes represent an addressible group of extensions // Nodes represent an addressible group of extensions
type Node struct { type Node struct {
ID NodeID ID NodeID
Type NodeType Type NodeType
Lock sync.RWMutex Lock sync.RWMutex
Extensions map[ExtType]Extension Extensions map[ExtType]Extension
MsgChan chan Msg
TimeoutChan <-chan time.Time
LoopLock sync.Mutex
Active bool
SignalQueue []QueuedSignal
NextSignal *QueuedSignal
}
func (node *Node) QueueSignal(time time.Time, signal Signal) {
node.SignalQueue = append(node.SignalQueue, QueuedSignal{signal, time})
node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue)
}
func (node *Node) ClearSignalQueue() {
node.SignalQueue = []QueuedSignal{}
node.NextSignal = nil
node.TimeoutChan = nil
}
func SoonestSignal(signals []QueuedSignal) (*QueuedSignal, <-chan time.Time) {
var soonest_signal *QueuedSignal
var soonest_time time.Time
for _, signal := range(signals) {
if signal.Time.Compare(soonest_time) == -1 || soonest_signal == nil {
soonest_signal = &signal
soonest_time = signal.Time
}
}
if soonest_signal != nil {
return soonest_signal, time.After(time.Until(soonest_time))
} else {
return nil, nil
}
}
func RunNode(ctx *Context, node *Node) {
ctx.Log.Logf("node", "RUN_START: %s", node.ID)
err := NodeLoop(ctx, node)
if err != nil {
panic(err)
}
ctx.Log.Logf("node", "RUN_STOP: %s", node.ID)
}
type Msg struct {
Source NodeID
Signal Signal
}
// Main Loop for Threads, starts a write context, so cannot be called from a write or read context
func NodeLoop(ctx *Context, node *Node) error {
node.LoopLock.Lock()
defer node.LoopLock.Unlock()
node.Active = true
for true {
var signal Signal
var source NodeID
select {
case msg := <- node.MsgChan:
signal = msg.Signal
source = msg.Source
err := Allowed(ctx, msg.Source, string(signal.Type()), node)
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s", err)
continue
}
case <-node.TimeoutChan:
signal = node.NextSignal.Signal
source = node.ID
node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue)
ctx.Log.Logf("node", "NODE_TIMEOUT %s - NEXT_SIGNAL: %s", node.ID, signal)
}
// Handle special signal types
if signal.Type() == StopSignalType {
node.Process(ctx, node.ID, NewStatusSignal("stopped", node.ID))
break
}
node.Process(ctx, source, signal)
}
return nil
}
func (node *Node) Process(ctx *Context, source NodeID, signal Signal) {
for ext_type, ext := range(node.Extensions) {
ctx.Log.Logf("signal", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type)
ext.Process(ctx, source, node, signal)
}
}
func (node *Node) Signal(ctx *Context, dest NodeID, signal Signal) error {
target, exists := ctx.Nodes[dest]
if exists == false {
return fmt.Errorf("%s does not exist, cannot signal it", dest)
}
select {
case target.MsgChan <- Msg{node.ID, signal}:
default:
buf := make([]byte, 4096)
n := runtime.Stack(buf, false)
stack_str := string(buf[:n])
return fmt.Errorf("SIGNAL_OVERFLOW: %s - %s", dest, stack_str)
}
return nil
} }
func GetCtx[T Extension, C any](ctx *Context) (C, error) { func GetCtx[T Extension, C any](ctx *Context) (C, error) {
@ -118,8 +243,10 @@ func (node *Node) Serialize() ([]byte, error) {
Magic: NODE_DB_MAGIC, Magic: NODE_DB_MAGIC,
TypeHash: node.Type.Hash(), TypeHash: node.Type.Hash(),
NumExtensions: uint32(len(extensions)), NumExtensions: uint32(len(extensions)),
NumQueuedSignals: uint32(len(node.SignalQueue)),
}, },
Extensions: extensions, Extensions: extensions,
QueuedSignals: node.SignalQueue,
} }
i := 0 i := 0
@ -141,7 +268,8 @@ func (node *Node) Serialize() ([]byte, error) {
return node_db.Serialize(), nil return node_db.Serialize(), nil
} }
func NewNode(ctx *Context, id NodeID, node_type NodeType, extensions ...Extension) *Node { // Create a new node in memory and start it's event loop
func NewNode(ctx *Context, id NodeID, node_type NodeType, queued_signals []QueuedSignal, extensions ...Extension) *Node {
_, exists := ctx.Nodes[id] _, exists := ctx.Nodes[id]
if exists == true { if exists == true {
panic("Attempted to create an existing node") panic("Attempted to create an existing node")
@ -168,18 +296,31 @@ func NewNode(ctx *Context, id NodeID, node_type NodeType, extensions ...Extensio
} }
} }
if queued_signals == nil {
queued_signals = []QueuedSignal{}
}
next_signal, timeout_chan := SoonestSignal(queued_signals)
node := &Node{ node := &Node{
ID: id, ID: id,
Type: node_type, Type: node_type,
Extensions: ext_map, Extensions: ext_map,
MsgChan: make(chan Msg, NODE_MSG_CHAN_DEFAULT),
TimeoutChan: timeout_chan,
SignalQueue: queued_signals,
NextSignal: next_signal,
} }
ctx.Nodes[id] = node ctx.Nodes[id] = node
WriteNode(ctx, node)
go RunNode(ctx, node)
return node return node
} }
func Allowed(context *StateContext, principal_id NodeID, action string, node *Node) error { func Allowed(ctx *Context, principal_id NodeID, action string, node *Node) error {
context.Graph.Log.Logf("policy", "POLICY_CHECK: %s %s.%s", principal_id, node.ID, action) ctx.Log.Logf("policy", "POLICY_CHECK: %s %s.%s", principal_id, node.ID, action)
// Nodes are allowed to perform all actions on themselves regardless of whether or not they have an ACL extension // Nodes are allowed to perform all actions on themselves regardless of whether or not they have an ACL extension
if principal_id == node.ID { if principal_id == node.ID {
return nil return nil
@ -191,43 +332,24 @@ func Allowed(context *StateContext, principal_id NodeID, action string, node *No
return err return err
} }
return policy_ext.Allows(context, principal_id, action, node) return policy_ext.Allows(ctx, principal_id, action, node)
}
// Check that princ is allowed to signal this action,
// then send the signal to all the extensions of the node
func (node *Node) Process(context *StateContext, princ_id NodeID, signal Signal) error {
ser, _ := signal.Serialize()
context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID, string(ser))
err := Allowed(context, princ_id, fmt.Sprintf("signal.%s", signal.Type()), node)
if err != nil {
return err
}
for ext_type, ext := range(node.Extensions) {
err = ext.Process(context, node, signal)
if err != nil {
context.Graph.Log.Logf("signal", "EXTENSION_SIGNAL_ERR: %s/%s - %s", node.ID, ext_type, err)
}
}
return nil
} }
// Magic first four bytes of serialized DB content, stored big endian // Magic first four bytes of serialized DB content, stored big endian
const NODE_DB_MAGIC = 0x2491df14 const NODE_DB_MAGIC = 0x2491df14
// Total length of the node database header, has magic to verify and type_hash to map to load function // Total length of the node database header, has magic to verify and type_hash to map to load function
const NODE_DB_HEADER_LEN = 16 const NODE_DB_HEADER_LEN = 20
// A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node // A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node
type NodeDBHeader struct { type NodeDBHeader struct {
Magic uint32 Magic uint32
NumExtensions uint32 NumExtensions uint32
NumQueuedSignals uint32
TypeHash uint64 TypeHash uint64
} }
type NodeDB struct { type NodeDB struct {
Header NodeDBHeader Header NodeDBHeader
QueuedSignals []QueuedSignal
Extensions []ExtensionDB Extensions []ExtensionDB
} }
@ -239,7 +361,8 @@ func NewNodeDB(data []byte) (NodeDB, error) {
magic := binary.BigEndian.Uint32(data[0:4]) magic := binary.BigEndian.Uint32(data[0:4])
num_extensions := binary.BigEndian.Uint32(data[4:8]) num_extensions := binary.BigEndian.Uint32(data[4:8])
node_type_hash := binary.BigEndian.Uint64(data[8:16]) num_queued_signals := binary.BigEndian.Uint32(data[8:12])
node_type_hash := binary.BigEndian.Uint64(data[12:20])
ptr += NODE_DB_HEADER_LEN ptr += NODE_DB_HEADER_LEN
@ -269,13 +392,20 @@ func NewNodeDB(data []byte) (NodeDB, error) {
ptr += int(EXTENSION_DB_HEADER_LEN + length) ptr += int(EXTENSION_DB_HEADER_LEN + length)
} }
queued_signals := make([]QueuedSignal, num_queued_signals)
for i, _ := range(queued_signals) {
queued_signals[i] = QueuedSignal{}
}
return NodeDB{ return NodeDB{
Header: NodeDBHeader{ Header: NodeDBHeader{
Magic: magic, Magic: magic,
TypeHash: node_type_hash, TypeHash: node_type_hash,
NumExtensions: num_extensions, NumExtensions: num_extensions,
NumQueuedSignals: num_queued_signals,
}, },
Extensions: extensions, Extensions: extensions,
QueuedSignals: queued_signals,
}, nil }, nil
} }
@ -287,7 +417,8 @@ func (header NodeDBHeader) Serialize() []byte {
ret := make([]byte, NODE_DB_HEADER_LEN) ret := make([]byte, NODE_DB_HEADER_LEN)
binary.BigEndian.PutUint32(ret[0:4], header.Magic) binary.BigEndian.PutUint32(ret[0:4], header.Magic)
binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions) binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions)
binary.BigEndian.PutUint64(ret[8:16], header.TypeHash) binary.BigEndian.PutUint32(ret[8:12], header.NumQueuedSignals)
binary.BigEndian.PutUint64(ret[12:20], header.TypeHash)
return ret return ret
} }
@ -324,6 +455,20 @@ type ExtensionDB struct {
} }
// Write multiple nodes to the database in a single transaction // Write multiple nodes to the database in a single transaction
func WriteNode(ctx *Context, node *Node) error {
ctx.Log.Logf("db", "DB_WRITE: %s", node.ID)
bytes, err := node.Serialize()
if err != nil {
return err
}
id_bytes := node.ID.Serialize()
return ctx.DB.Update(func(txn *badger.Txn) error {
return txn.Set(id_bytes, bytes)
})
}
func WriteNodes(context *StateContext) error { func WriteNodes(context *StateContext) error {
err := ValidateStateContext(context, "write", true) err := ValidateStateContext(context, "write", true)
if err != nil { if err != nil {
@ -368,10 +513,13 @@ func WriteNodes(context *StateContext) error {
// Recursively load a node from the database. // Recursively load a node from the database.
func LoadNode(ctx * Context, id NodeID) (*Node, error) { func LoadNode(ctx * Context, id NodeID) (*Node, error) {
ctx.Log.Logf("db", "LOOKING_FOR_NODE: %s", id)
node, exists := ctx.Nodes[id] node, exists := ctx.Nodes[id]
if exists == true { if exists == true {
ctx.Log.Logf("db", "NODE_ALREADY_LOADED: %s", id)
return node,nil return node,nil
} }
ctx.Log.Logf("db", "LOADING_NODE: %s", id)
var bytes []byte var bytes []byte
err := ctx.DB.View(func(txn *badger.Txn) error { err := ctx.DB.View(func(txn *badger.Txn) error {
@ -400,10 +548,15 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
return nil, fmt.Errorf("Tried to load node %s of type 0x%x, which is not a known node type", id, node_db.Header.TypeHash) return nil, fmt.Errorf("Tried to load node %s of type 0x%x, which is not a known node type", id, node_db.Header.TypeHash)
} }
next_signal, timeout_chan := SoonestSignal(node_db.QueuedSignals)
node = &Node{ node = &Node{
ID: id, ID: id,
Type: node_type.Type, Type: node_type.Type,
Extensions: map[ExtType]Extension{}, Extensions: map[ExtType]Extension{},
MsgChan: make(chan Msg, NODE_MSG_CHAN_DEFAULT),
TimeoutChan: timeout_chan,
SignalQueue: node_db.QueuedSignals,
NextSignal: next_signal,
} }
ctx.Nodes[id] = node ctx.Nodes[id] = node
@ -462,6 +615,9 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
} }
ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id) ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id)
go RunNode(ctx, node)
return node, nil return node, nil
} }
@ -605,197 +761,3 @@ func del[K comparable](list []K, val K) []K {
list[idx] = list[len(list)-1] list[idx] = list[len(list)-1]
return list[:len(list)-1] return list[:len(list)-1]
} }
// Add nodes to an existing read context and call nodes_fn with new_nodes locked for read
// Check that the node has read permissions for the nodes, then add them to the read context and call nodes_fn with the nodes locked for read
func UseStates(context *StateContext, principal *Node, new_nodes ACLMap, state_fn StateFn) error {
if principal == nil || new_nodes == nil || state_fn == nil {
return fmt.Errorf("nil passed to UseStates")
}
err := ValidateStateContext(context, "read", false)
if err != nil {
return err
}
if context.Started == false {
context.Started = true
}
new_locks := []*Node{}
_, princ_locked := context.Locked[principal.ID]
if princ_locked == false {
new_locks = append(new_locks, principal)
context.Graph.Log.Logf("mutex", "RLOCKING_PRINC %s", principal.ID.String())
principal.Lock.RLock()
}
princ_permissions, princ_exists := context.Permissions[principal.ID]
new_permissions := ACLMap{}
if princ_exists == true {
for id, info := range(princ_permissions) {
new_permissions[id] = info
}
}
for _, request := range(new_nodes) {
node := request.Node
if node == nil {
return fmt.Errorf("node in request list is nil")
}
id := node.ID
if id != principal.ID {
_, locked := context.Locked[id]
if locked == false {
new_locks = append(new_locks, node)
context.Graph.Log.Logf("mutex", "RLOCKING %s", id.String())
node.Lock.RLock()
}
}
node_permissions, node_exists := new_permissions[id]
if node_exists == false {
node_permissions = ACLInfo{Node: node, Resources: []string{}}
}
for _, resource := range(request.Resources) {
already_granted := false
for _, granted := range(node_permissions.Resources) {
if resource == granted {
already_granted = true
}
}
if already_granted == false {
err := Allowed(context, principal.ID, fmt.Sprintf("%s.read", resource), node)
if err != nil {
for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String())
n.Lock.RUnlock()
}
return err
}
}
}
new_permissions[id] = node_permissions
}
for _, node := range(new_locks) {
context.Locked[node.ID] = node
}
context.Permissions[principal.ID] = new_permissions
err = state_fn(context)
context.Permissions[principal.ID] = princ_permissions
for _, node := range(new_locks) {
context.Graph.Log.Logf("mutex", "RUNLOCKING %s", node.ID.String())
delete(context.Locked, node.ID)
node.Lock.RUnlock()
}
return err
}
// Add nodes to an existing write context and call nodes_fn with nodes locked for read
// If context is nil
func UpdateStates(context *StateContext, principal *Node, new_nodes ACLMap, state_fn StateFn) error {
if principal == nil || new_nodes == nil || state_fn == nil {
return fmt.Errorf("nil passed to UpdateStates")
}
err := ValidateStateContext(context, "write", false)
if err != nil {
return err
}
final := false
if context.Started == false {
context.Started = true
final = true
}
new_locks := []*Node{}
_, princ_locked := context.Locked[principal.ID]
if princ_locked == false {
new_locks = append(new_locks, principal)
context.Graph.Log.Logf("mutex", "LOCKING_PRINC %s", principal.ID.String())
principal.Lock.Lock()
}
princ_permissions, princ_exists := context.Permissions[principal.ID]
new_permissions := ACLMap{}
if princ_exists == true {
for id, info := range(princ_permissions) {
new_permissions[id] = info
}
}
for _, request := range(new_nodes) {
node := request.Node
if node == nil {
return fmt.Errorf("node in request list is nil")
}
id := node.ID
if id != principal.ID {
_, locked := context.Locked[id]
if locked == false {
new_locks = append(new_locks, node)
context.Graph.Log.Logf("mutex", "LOCKING %s", id.String())
node.Lock.Lock()
}
}
node_permissions, node_exists := new_permissions[id]
if node_exists == false {
node_permissions = ACLInfo{Node: node, Resources: []string{}}
}
for _, resource := range(request.Resources) {
already_granted := false
for _, granted := range(node_permissions.Resources) {
if resource == granted {
already_granted = true
}
}
if already_granted == false {
err := Allowed(context, principal.ID, fmt.Sprintf("%s.write", resource), node)
if err != nil {
for _, n := range(new_locks) {
context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String())
n.Lock.Unlock()
}
return err
}
}
}
new_permissions[id] = node_permissions
}
for _, node := range(new_locks) {
context.Locked[node.ID] = node
}
context.Permissions[principal.ID] = new_permissions
err = state_fn(context)
if final == true {
context.Finished = true
if err == nil {
err = WriteNodes(context)
}
for id, node := range(context.Locked) {
context.Graph.Log.Logf("mutex", "UNLOCKING %s", id.String())
node.Lock.Unlock()
}
}
return err
}

@ -10,14 +10,10 @@ func TestNodeDB(t *testing.T) {
err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType})
fatalErr(t, err) fatalErr(t, err)
node := NewNode(ctx, RandID(), node_type, NewGroupExt(nil)) node := NewNode(ctx, RandID(), node_type, nil, NewGroupExt(nil))
context := NewWriteContext(ctx)
err = UpdateStates(context, node, NewACLInfo(node, []string{"test"}), func(context *StateContext) error {
ser, err := node.Serialize() ser, err := node.Serialize()
ctx.Log.Logf("test", "NODE_SER: %+v", ser) ctx.Log.Logf("test", "NODE_SER: %+v", ser)
return err
})
fatalErr(t, err) fatalErr(t, err)
ctx.Nodes = NodeMap{} ctx.Nodes = NodeMap{}

@ -7,15 +7,15 @@ import (
type Policy interface { type Policy interface {
Serializable[PolicyType] Serializable[PolicyType]
Allows(context *StateContext, principal_id NodeID, action string, node *Node) error Allows(principal_id NodeID, action string, node *Node) error
} }
//TODO: Update with change from principal *Node to principal_id so sane policies can still be made //TODO: Update with change from principal *Node to principal_id so sane policies can still be made
func (policy *AllNodesPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { func (policy *AllNodesPolicy) Allows(principal_id NodeID, action string, node *Node) error {
return policy.Actions.Allows(action) return policy.Actions.Allows(action)
} }
func (policy *PerNodePolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { func (policy *PerNodePolicy) Allows(principal_id NodeID, action string, node *Node) error {
for id, actions := range(policy.NodeActions) { for id, actions := range(policy.NodeActions) {
if id != principal_id { if id != principal_id {
continue continue
@ -29,13 +29,13 @@ func (policy *PerNodePolicy) Allows(context *StateContext, principal_id NodeID,
return fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID) return fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID)
} }
func (policy *RequirementOfPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { func (policy *RequirementOfPolicy) Allows(principal_id NodeID, action string, node *Node) error {
lockable_ext, err := GetExt[*LockableExt](node) lockable_ext, err := GetExt[*LockableExt](node)
if err != nil { if err != nil {
return err return err
} }
for id, _ := range(lockable_ext.Requirements) { for _, id := range(lockable_ext.Requirements) {
if id == principal_id { if id == principal_id {
return policy.Actions.Allows(action) return policy.Actions.Allows(action)
} }
@ -44,36 +44,6 @@ func (policy *RequirementOfPolicy) Allows(context *StateContext, principal_id No
return fmt.Errorf("%s is not a requirement of %s", principal_id, node.ID) return fmt.Errorf("%s is not a requirement of %s", principal_id, node.ID)
} }
func (policy *ParentOfPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error {
thread_ext, err := GetExt[*ThreadExt](node)
if err != nil {
return err
}
if thread_ext.Parent != nil {
if thread_ext.Parent.ID == principal_id {
return policy.Actions.Allows(action)
}
}
return fmt.Errorf("%s is not a parent of %s", principal_id, node.ID)
}
func (policy *ChildOfPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error {
thread_ext, err := GetExt[*ThreadExt](node)
if err != nil {
return err
}
for id, _ := range(thread_ext.Children) {
if id == principal_id {
return policy.Actions.Allows(action)
}
}
return fmt.Errorf("%s is not a child of %s", principal_id, node.ID)
}
const RequirementOfPolicyType = PolicyType("REQUIREMENT_OF") const RequirementOfPolicyType = PolicyType("REQUIREMENT_OF")
type RequirementOfPolicy struct { type RequirementOfPolicy struct {
AllNodesPolicy AllNodesPolicy
@ -88,14 +58,6 @@ func NewRequirementOfPolicy(actions Actions) RequirementOfPolicy {
} }
} }
const ChildOfPolicyType = PolicyType("CHILD_OF")
type ChildOfPolicy struct {
AllNodesPolicy
}
func (policy *ChildOfPolicy) Type() PolicyType {
return ChildOfPolicyType
}
type Actions []string type Actions []string
func (actions Actions) Allows(action string) error { func (actions Actions) Allows(action string) error {
@ -153,26 +115,6 @@ func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context,
} }
} }
func NewChildOfPolicy(actions Actions) ChildOfPolicy {
return ChildOfPolicy{
AllNodesPolicy: NewAllNodesPolicy(actions),
}
}
const ParentOfPolicyType = PolicyType("PARENT_OF")
type ParentOfPolicy struct {
AllNodesPolicy
}
func (policy *ParentOfPolicy) Type() PolicyType {
return ParentOfPolicyType
}
func NewParentOfPolicy(actions Actions) ParentOfPolicy {
return ParentOfPolicy{
AllNodesPolicy: NewAllNodesPolicy(actions),
}
}
func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy { func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy {
if node_actions == nil { if node_actions == nil {
node_actions = NodeActions{} node_actions = NodeActions{}
@ -268,18 +210,6 @@ func NewACLExtContext() *ACLExtContext {
return &policy, nil return &policy, nil
}), }),
}, },
ParentOfPolicyType: PolicyInfo{
Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){
policy := NewParentOfPolicy(actions)
return &policy, nil
}),
},
ChildOfPolicyType: PolicyInfo{
Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){
policy := NewChildOfPolicy(actions)
return &policy, nil
}),
},
RequirementOfPolicyType: PolicyInfo{ RequirementOfPolicyType: PolicyInfo{
Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){ Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){
policy := NewRequirementOfPolicy(actions) policy := NewRequirementOfPolicy(actions)
@ -307,8 +237,7 @@ func (ext *ACLExt) Serialize() ([]byte, error) {
}, "", " ") }, "", " ")
} }
func (ext *ACLExt) Process(context *StateContext, node *Node, signal Signal) error { func (ext *ACLExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) {
return nil
} }
func NewACLExt(policies ...Policy) *ACLExt { func NewACLExt(policies ...Policy) *ACLExt {
@ -362,11 +291,11 @@ func (ext *ACLExt) Type() ExtType {
} }
// Check if the extension allows the principal to perform action on node // Check if the extension allows the principal to perform action on node
func (ext *ACLExt) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { func (ext *ACLExt) Allows(ctx *Context, principal_id NodeID, action string, node *Node) error {
context.Graph.Log.Logf("policy", "POLICY_EXT_ALLOWED: %+v", ext) ctx.Log.Logf("policy", "POLICY_EXT_ALLOWED: %+v", ext)
errs := []error{} errs := []error{}
for _, policy := range(ext.Policies) { for _, policy := range(ext.Policies) {
err := policy.Allows(context, principal_id, action, node) err := policy.Allows(principal_id, action, node)
if err == nil { if err == nil {
return nil return nil
} }

@ -55,8 +55,8 @@ func NewDirectSignal(signal_type SignalType) BaseSignal {
return NewBaseSignal(signal_type, Direct) return NewBaseSignal(signal_type, Direct)
} }
var AbortSignal = NewBaseSignal("abort", Down) const StopSignalType = SignalType("STOP")
var StopSignal = NewBaseSignal("stop", Down) var StopSignal = NewDownSignal(StopSignalType)
type IDSignal struct { type IDSignal struct {
BaseSignal BaseSignal

@ -20,8 +20,8 @@ type ECDHExtJSON struct {
Shared []byte `json:"shared"` Shared []byte `json:"shared"`
} }
func (ext *ECDHExt) Process(context *StateContext, node *Node, signal Signal) error { func (ext *ECDHExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) {
return nil return
} }
const ECDHExtType = ExtType("ECDH") const ECDHExtType = ExtType("ECDH")
@ -115,6 +115,6 @@ func LoadGroupExt(ctx *Context, data []byte) (Extension, error) {
return NewGroupExt(members), nil return NewGroupExt(members), nil
} }
func (ext *GroupExt) Process(context *StateContext, node *Node, signal Signal) error { func (ext *GroupExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) {
return nil return
} }