Compare commits
No commits in common. "master" and "graph-rework-2" have entirely different histories.
master
...
graph-rewo
@ -1,3 +0,0 @@
|
||||
[submodule "graphql"]
|
||||
path = graphql
|
||||
url = https://github.com/graphql-go/graphql
|
@ -1,48 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
badger "github.com/dgraph-io/badger/v3"
|
||||
gv "github.com/mekkanized/graphvent"
|
||||
)
|
||||
|
||||
func check(err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true))
|
||||
check(err)
|
||||
|
||||
ctx, err := gv.NewContext(&gv.BadgerDB{
|
||||
DB: db,
|
||||
}, gv.NewConsoleLogger([]string{"test"}))
|
||||
check(err)
|
||||
|
||||
gql_ext, err := gv.NewGQLExt(ctx, ":8080", nil, nil)
|
||||
check(err)
|
||||
|
||||
listener_ext := gv.NewListenerExt(1000)
|
||||
|
||||
n1, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt(nil))
|
||||
check(err)
|
||||
|
||||
n2, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt([]gv.NodeID{n1.ID}))
|
||||
check(err)
|
||||
|
||||
n3, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt(nil))
|
||||
check(err)
|
||||
|
||||
_, err = gv.NewNode(ctx, nil, "LockableNode", 1000, gql_ext, listener_ext, gv.NewLockableExt([]gv.NodeID{n2.ID, n3.ID}))
|
||||
check(err)
|
||||
|
||||
for true {
|
||||
select {
|
||||
case message := <- listener_ext.Chan:
|
||||
fmt.Printf("Listener Message: %+v\n", message)
|
||||
}
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,284 +0,0 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
badger "github.com/dgraph-io/badger/v3"
|
||||
)
|
||||
|
||||
type Database interface {
|
||||
WriteNodeInit(*Context, *Node) error
|
||||
WriteNodeChanges(*Context, *Node, map[ExtType]Changes) error
|
||||
LoadNode(*Context, NodeID) (*Node, error)
|
||||
}
|
||||
|
||||
const WRITE_BUFFER_SIZE = 1000000
|
||||
type BadgerDB struct {
|
||||
*badger.DB
|
||||
sync.Mutex
|
||||
buffer [WRITE_BUFFER_SIZE]byte
|
||||
}
|
||||
|
||||
func (db *BadgerDB) WriteNodeInit(ctx *Context, node *Node) error {
|
||||
if node == nil {
|
||||
return fmt.Errorf("Cannot serialize nil *Node")
|
||||
}
|
||||
|
||||
return db.Update(func(tx *badger.Txn) error {
|
||||
db.Lock()
|
||||
defer db.Unlock()
|
||||
|
||||
// Get the base key bytes
|
||||
id_ser, err := node.ID.MarshalBinary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cur := 0
|
||||
|
||||
// Write Node value
|
||||
written, err := Serialize(ctx, node, db.buffer[cur:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Set(id_ser, db.buffer[cur:cur+written])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cur += written
|
||||
|
||||
// Write empty signal queue
|
||||
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
|
||||
written, err = Serialize(ctx, node.SignalQueue, db.buffer[cur:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Set(sigqueue_id, db.buffer[cur:cur+written])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cur += written
|
||||
|
||||
// Write node extension list
|
||||
ext_list := []ExtType{}
|
||||
for ext_type := range(node.Extensions) {
|
||||
ext_list = append(ext_list, ext_type)
|
||||
}
|
||||
written, err = Serialize(ctx, ext_list, db.buffer[cur:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
|
||||
err = tx.Set(ext_list_id, db.buffer[cur:cur+written])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cur += written
|
||||
|
||||
// For each extension:
|
||||
for ext_type, ext := range(node.Extensions) {
|
||||
ext_info, exists := ctx.Extensions[ext_type]
|
||||
if exists == false {
|
||||
return fmt.Errorf("Cannot serialize node with unknown extension %s", reflect.TypeOf(ext))
|
||||
}
|
||||
|
||||
ext_value := reflect.ValueOf(ext).Elem()
|
||||
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
|
||||
|
||||
// Write each field to a seperate key
|
||||
for field_tag, field_info := range(ext_info.Fields) {
|
||||
field_value := ext_value.FieldByIndex(field_info.Index)
|
||||
|
||||
field_id := make([]byte, len(ext_id) + 8)
|
||||
tmp := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(field_tag))))
|
||||
copy(field_id, tmp)
|
||||
|
||||
written, err := SerializeValue(ctx, field_value, db.buffer[cur:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err)
|
||||
}
|
||||
|
||||
err = tx.Set(field_id, db.buffer[cur:cur+written])
|
||||
if err != nil {
|
||||
return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err)
|
||||
}
|
||||
cur += written
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (db *BadgerDB) WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error {
|
||||
return db.Update(func(tx *badger.Txn) error {
|
||||
db.Lock()
|
||||
defer db.Unlock()
|
||||
|
||||
// Get the base key bytes
|
||||
id_bytes := ([16]byte)(node.ID)
|
||||
|
||||
cur := 0
|
||||
|
||||
// Write the signal queue if it needs to be written
|
||||
if node.writeSignalQueue {
|
||||
node.writeSignalQueue = false
|
||||
|
||||
sigqueue_id := append(id_bytes[:], []byte(" - SIGQUEUE")...)
|
||||
written, err := Serialize(ctx, node.SignalQueue, db.buffer[cur:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("SignalQueue Serialize Error: %+v, %w", node.SignalQueue, err)
|
||||
}
|
||||
err = tx.Set(sigqueue_id, db.buffer[cur:cur+written])
|
||||
if err != nil {
|
||||
return fmt.Errorf("SignalQueue set error: %+v, %w", node.SignalQueue, err)
|
||||
}
|
||||
cur += written
|
||||
}
|
||||
|
||||
// For each ext in changes
|
||||
for ext_type, changes := range(changes) {
|
||||
ext_info, exists := ctx.Extensions[ext_type]
|
||||
if exists == false {
|
||||
return fmt.Errorf("%s is not an extension in ctx", ext_type)
|
||||
}
|
||||
|
||||
ext, exists := node.Extensions[ext_type]
|
||||
if exists == false {
|
||||
return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID)
|
||||
}
|
||||
ext_id := binary.BigEndian.AppendUint64(id_bytes[:], uint64(ext_type))
|
||||
ext_value := reflect.ValueOf(ext)
|
||||
|
||||
// Write each field
|
||||
for _, tag := range(changes) {
|
||||
field_info, exists := ext_info.Fields[tag]
|
||||
if exists == false {
|
||||
return fmt.Errorf("Cannot serialize field %s of extension %s, does not exist", tag, ext_type)
|
||||
}
|
||||
|
||||
field_value := ext_value.FieldByIndex(field_info.Index)
|
||||
|
||||
field_id := make([]byte, len(ext_id) + 8)
|
||||
tmp := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(tag))))
|
||||
copy(field_id, tmp)
|
||||
|
||||
written, err := SerializeValue(ctx, field_value, db.buffer[cur:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err)
|
||||
}
|
||||
|
||||
err = tx.Set(field_id, db.buffer[cur:cur+written])
|
||||
if err != nil {
|
||||
return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err)
|
||||
}
|
||||
cur += written
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (db *BadgerDB) LoadNode(ctx *Context, id NodeID) (*Node, error) {
|
||||
var node *Node = nil
|
||||
|
||||
err := db.View(func(tx *badger.Txn) error {
|
||||
// Get the base key bytes
|
||||
id_ser, err := id.MarshalBinary()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to serialize node_id: %w", err)
|
||||
}
|
||||
|
||||
// Get the node value
|
||||
node_item, err := tx.Get(id_ser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to get node_item: %w", NodeNotFoundError)
|
||||
}
|
||||
|
||||
err = node_item.Value(func(val []byte) error {
|
||||
ctx.Log.Logf("db", "DESERIALIZE_NODE(%d bytes): %+v", len(val), val)
|
||||
node, err = Deserialize[*Node](ctx, val)
|
||||
return err
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to deserialize Node %s - %w", id, err)
|
||||
}
|
||||
|
||||
// Get the signal queue
|
||||
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
|
||||
sigqueue_item, err := tx.Get(sigqueue_id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to get sigqueue_id: %w", err)
|
||||
}
|
||||
err = sigqueue_item.Value(func(val []byte) error {
|
||||
node.SignalQueue, err = Deserialize[[]QueuedSignal](ctx, val)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to deserialize []QueuedSignal for %s: %w", id, err)
|
||||
}
|
||||
|
||||
// Get the extension list
|
||||
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
|
||||
ext_list_item, err := tx.Get(ext_list_id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ext_list []ExtType
|
||||
ext_list_item.Value(func(val []byte) error {
|
||||
ext_list, err = Deserialize[[]ExtType](ctx, val)
|
||||
return err
|
||||
})
|
||||
|
||||
// Get the extensions
|
||||
for _, ext_type := range(ext_list) {
|
||||
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
|
||||
ext_info, exists := ctx.Extensions[ext_type]
|
||||
if exists == false {
|
||||
return fmt.Errorf("Extension %s not in context", ext_type)
|
||||
}
|
||||
|
||||
ext := reflect.New(ext_info.Type)
|
||||
for field_tag, field_info := range(ext_info.Fields) {
|
||||
field_id := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(field_tag))))
|
||||
field_item, err := tx.Get(field_id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to find key for %s:%s(%x) - %w", ext_type, field_tag, field_id, err)
|
||||
}
|
||||
err = field_item.Value(func(val []byte) error {
|
||||
value, _, err := DeserializeValue(ctx, val, field_info.Type)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ext.Elem().FieldByIndex(field_info.Index).Set(value)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
node.Extensions[ext_type] = ext.Interface().(Extension)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if node == nil {
|
||||
return nil, fmt.Errorf("Tried to return nil *Node from BadgerDB.LoadNode without error")
|
||||
}
|
||||
|
||||
return node, nil
|
||||
}
|
@ -1,16 +0,0 @@
|
||||
package graphvent
|
||||
|
||||
type Tag string
|
||||
type Changes []Tag
|
||||
|
||||
// Extensions are data attached to nodes that process signals
|
||||
type Extension interface {
|
||||
// Called to process incoming signals, returning changes and messages to send
|
||||
Process(*Context, *Node, NodeID, Signal) ([]Message, Changes)
|
||||
|
||||
// Called when the node is loaded into a context(creation or move), so extension data can be initialized
|
||||
Load(*Context, *Node) error
|
||||
|
||||
// Called when the node is unloaded from a context(deletion or move), so extension data can be cleaned up
|
||||
Unload(*Context, *Node)
|
||||
}
|
@ -0,0 +1,172 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"github.com/graphql-go/graphql"
|
||||
"reflect"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func NewField(init func()*graphql.Field) *graphql.Field {
|
||||
return init()
|
||||
}
|
||||
|
||||
type Singleton[K graphql.Type] struct {
|
||||
Type K
|
||||
List *graphql.List
|
||||
}
|
||||
|
||||
func NewSingleton[K graphql.Type](init func() K, post_init func(K, *graphql.List)) *Singleton[K] {
|
||||
val := init()
|
||||
list := graphql.NewList(val)
|
||||
if post_init != nil {
|
||||
post_init(val, list)
|
||||
}
|
||||
return &Singleton[K]{
|
||||
Type: val,
|
||||
List: list,
|
||||
}
|
||||
}
|
||||
|
||||
func AddNodeInterfaceFields(i *graphql.Interface) {
|
||||
i.AddFieldConfig("ID", &graphql.Field{
|
||||
Type: graphql.String,
|
||||
})
|
||||
|
||||
i.AddFieldConfig("TypeHash", &graphql.Field{
|
||||
Type: graphql.String,
|
||||
})
|
||||
}
|
||||
|
||||
func PrepTypeResolve(p graphql.ResolveTypeParams) (*ResolveContext, error) {
|
||||
resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
|
||||
if ok == false {
|
||||
return nil, fmt.Errorf("Bad resolve in params context")
|
||||
}
|
||||
return resolve_context, nil
|
||||
}
|
||||
|
||||
var GQLInterfaceNode = NewSingleton(func() *graphql.Interface {
|
||||
i := graphql.NewInterface(graphql.InterfaceConfig{
|
||||
Name: "Node",
|
||||
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
|
||||
ctx, err := PrepTypeResolve(p)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
valid_nodes := ctx.GQLContext.ValidNodes
|
||||
p_type := reflect.TypeOf(p.Value)
|
||||
|
||||
for key, value := range(valid_nodes) {
|
||||
if p_type == key {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
_, ok := p.Value.(Node)
|
||||
if ok == true {
|
||||
return ctx.GQLContext.BaseNodeType
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Fields: graphql.Fields{},
|
||||
})
|
||||
|
||||
AddNodeInterfaceFields(i)
|
||||
|
||||
return i
|
||||
}, nil)
|
||||
|
||||
var GQLInterfaceLockable = NewSingleton(func() *graphql.Interface {
|
||||
gql_interface_lockable := graphql.NewInterface(graphql.InterfaceConfig{
|
||||
Name: "Lockable",
|
||||
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
|
||||
ctx, err := PrepTypeResolve(p)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
valid_lockables := ctx.GQLContext.ValidLockables
|
||||
p_type := reflect.TypeOf(p.Value)
|
||||
|
||||
for key, value := range(valid_lockables) {
|
||||
if p_type == key {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
_, ok := p.Value.(*Node)
|
||||
if ok == false {
|
||||
return ctx.GQLContext.BaseLockableType
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Fields: graphql.Fields{},
|
||||
})
|
||||
|
||||
return gql_interface_lockable
|
||||
}, func(lockable *graphql.Interface, lockable_list *graphql.List) {
|
||||
lockable.AddFieldConfig("Requirements", &graphql.Field{
|
||||
Type: lockable_list,
|
||||
})
|
||||
|
||||
lockable.AddFieldConfig("Dependencies", &graphql.Field{
|
||||
Type: lockable_list,
|
||||
})
|
||||
|
||||
lockable.AddFieldConfig("Owner", &graphql.Field{
|
||||
Type: lockable,
|
||||
})
|
||||
AddNodeInterfaceFields(lockable)
|
||||
})
|
||||
|
||||
var GQLInterfaceThread = NewSingleton(func() *graphql.Interface {
|
||||
gql_interface_thread := graphql.NewInterface(graphql.InterfaceConfig{
|
||||
Name: "Thread",
|
||||
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
|
||||
ctx, err := PrepTypeResolve(p)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
valid_threads := ctx.GQLContext.ValidThreads
|
||||
p_type := reflect.TypeOf(p.Value)
|
||||
|
||||
for key, value := range(valid_threads) {
|
||||
if p_type == key {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
node, ok := p.Value.(*Node)
|
||||
if ok == false {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = GetExt[*ThreadExt](node)
|
||||
if err == nil {
|
||||
return ctx.GQLContext.BaseThreadType
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Fields: graphql.Fields{},
|
||||
})
|
||||
|
||||
return gql_interface_thread
|
||||
}, func(thread *graphql.Interface, thread_list *graphql.List) {
|
||||
thread.AddFieldConfig("Children", &graphql.Field{
|
||||
Type: thread_list,
|
||||
})
|
||||
|
||||
thread.AddFieldConfig("Parent", &graphql.Field{
|
||||
Type: thread,
|
||||
})
|
||||
|
||||
thread.AddFieldConfig("State", &graphql.Field{
|
||||
Type: graphql.String,
|
||||
})
|
||||
|
||||
AddNodeInterfaceFields(thread)
|
||||
})
|
@ -0,0 +1,114 @@
|
||||
package graphvent
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/graphql-go/graphql"
|
||||
)
|
||||
|
||||
var GQLMutationAbort = NewField(func()*graphql.Field {
|
||||
gql_mutation_abort := &graphql.Field{
|
||||
Type: GQLTypeSignal.Type,
|
||||
Args: graphql.FieldConfigArgument{
|
||||
"id": &graphql.ArgumentConfig{
|
||||
Type: graphql.String,
|
||||
},
|
||||
},
|
||||
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
||||
_, ctx, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := ExtractID(p, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var node *Node = nil
|
||||
context := NewReadContext(ctx.Context)
|
||||
err = UseStates(context, ctx.User, NewACLMap(
|
||||
NewACLInfo(ctx.Server, []string{"children"}),
|
||||
), func(context *StateContext) (error){
|
||||
node, err = FindChild(context, ctx.User, ctx.Server, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if node == nil {
|
||||
return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
|
||||
}
|
||||
return SendSignal(context, node, ctx.User, AbortSignal)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return AbortSignal, nil
|
||||
},
|
||||
}
|
||||
|
||||
return gql_mutation_abort
|
||||
})
|
||||
|
||||
var GQLMutationStartChild = NewField(func()*graphql.Field{
|
||||
gql_mutation_start_child := &graphql.Field{
|
||||
Type: GQLTypeSignal.Type,
|
||||
Args: graphql.FieldConfigArgument{
|
||||
"parent_id": &graphql.ArgumentConfig{
|
||||
Type: graphql.String,
|
||||
},
|
||||
"child_id": &graphql.ArgumentConfig{
|
||||
Type: graphql.String,
|
||||
},
|
||||
"action": &graphql.ArgumentConfig{
|
||||
Type: graphql.String,
|
||||
DefaultValue: "start",
|
||||
},
|
||||
},
|
||||
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
||||
_, ctx, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parent_id, err := ExtractID(p, "parent_id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
child_id, err := ExtractID(p, "child_id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
action, err := ExtractParam[string](p, "action")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var signal Signal
|
||||
context := NewWriteContext(ctx.Context)
|
||||
err = UseStates(context, ctx.User, NewACLMap(
|
||||
NewACLInfo(ctx.Server, []string{"children"}),
|
||||
), func(context *StateContext) error {
|
||||
parent, err := FindChild(context, ctx.User, ctx.Server, parent_id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if parent == nil {
|
||||
return fmt.Errorf("%s is not a child of %s", parent_id, ctx.Server.ID)
|
||||
}
|
||||
|
||||
signal = NewStartChildSignal(child_id, action)
|
||||
return SendSignal(context, ctx.User, parent, signal)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: wait for the result of the signal to send back instead of just the signal
|
||||
return signal, nil
|
||||
},
|
||||
}
|
||||
|
||||
return gql_mutation_start_child
|
||||
})
|
||||
|
@ -1,147 +0,0 @@
|
||||
package graphvent
|
||||
import (
|
||||
"reflect"
|
||||
"fmt"
|
||||
"time"
|
||||
"github.com/graphql-go/graphql"
|
||||
"github.com/graphql-go/graphql/language/ast"
|
||||
)
|
||||
|
||||
func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) {
|
||||
node, ok := p.Source.(NodeResult)
|
||||
if ok == false {
|
||||
return nil, fmt.Errorf("Can't get NodeID from %+v", reflect.TypeOf(p.Source))
|
||||
}
|
||||
|
||||
return node.NodeID, nil
|
||||
}
|
||||
|
||||
func ResolveNodeType(p graphql.ResolveParams) (interface{}, error) {
|
||||
node, ok := p.Source.(NodeResult)
|
||||
if ok == false {
|
||||
return nil, fmt.Errorf("Can't get TypeHash from %+v", reflect.TypeOf(p.Source))
|
||||
}
|
||||
|
||||
return uint64(node.NodeType), nil
|
||||
}
|
||||
|
||||
type FieldIndex struct {
|
||||
Extension ExtType
|
||||
Tag string
|
||||
}
|
||||
|
||||
func GetFields(selection_set *ast.SelectionSet) []string {
|
||||
names := []string{}
|
||||
if selection_set == nil {
|
||||
return names
|
||||
}
|
||||
|
||||
for _, sel := range(selection_set.Selections) {
|
||||
switch field := sel.(type) {
|
||||
case *ast.Field:
|
||||
if field.Name.Value == "ID" || field.Name.Value == "Type" {
|
||||
continue
|
||||
}
|
||||
names = append(names, field.Name.Value)
|
||||
case *ast.InlineFragment:
|
||||
names = append(names, GetFields(field.SelectionSet)...)
|
||||
}
|
||||
}
|
||||
|
||||
return names
|
||||
}
|
||||
|
||||
// Returns the fields that need to be resolved
|
||||
func GetResolveFields(p graphql.ResolveParams) []string {
|
||||
fields := []string{}
|
||||
for _, field := range(p.Info.FieldASTs) {
|
||||
fields = append(fields, GetFields(field.SelectionSet)...)
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
func ResolveNode(id NodeID, p graphql.ResolveParams) (NodeResult, error) {
|
||||
ctx, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return NodeResult{}, err
|
||||
}
|
||||
|
||||
switch source := p.Source.(type) {
|
||||
case *StatusSignal:
|
||||
cached_node, cached := ctx.NodeCache[source.Source]
|
||||
if cached {
|
||||
for _, field_name := range(source.Fields) {
|
||||
_, cached := cached_node.Data[field_name]
|
||||
if cached {
|
||||
delete(cached_node.Data, field_name)
|
||||
}
|
||||
}
|
||||
ctx.NodeCache[source.Source] = cached_node
|
||||
}
|
||||
}
|
||||
|
||||
cache, node_cached := ctx.NodeCache[id]
|
||||
fields := GetResolveFields(p)
|
||||
var not_cached []string
|
||||
if node_cached {
|
||||
not_cached = []string{}
|
||||
for _, field := range(fields) {
|
||||
if node_cached {
|
||||
_, field_cached := cache.Data[field]
|
||||
if field_cached {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
not_cached = append(not_cached, field)
|
||||
}
|
||||
} else {
|
||||
not_cached = fields
|
||||
}
|
||||
|
||||
if (len(not_cached) == 0) && (node_cached == true) {
|
||||
ctx.Context.Log.Logf("gql", "No new fields to resolve for %s", id)
|
||||
return cache, nil
|
||||
} else {
|
||||
ctx.Context.Log.Logf("gql", "Resolving fields %+v on node %s", not_cached, id)
|
||||
|
||||
signal := NewReadSignal(not_cached)
|
||||
response_chan := ctx.Ext.GetResponseChannel(signal.ID())
|
||||
// TODO: TIMEOUT DURATION
|
||||
err = ctx.Context.Send(ctx.Server, []Message{{
|
||||
Node: id,
|
||||
Signal: signal,
|
||||
}})
|
||||
if err != nil {
|
||||
ctx.Ext.FreeResponseChannel(signal.ID())
|
||||
return NodeResult{}, err
|
||||
}
|
||||
|
||||
response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID())
|
||||
ctx.Ext.FreeResponseChannel(signal.ID())
|
||||
if err != nil {
|
||||
return NodeResult{}, err
|
||||
}
|
||||
|
||||
switch response := response.(type) {
|
||||
case *ReadResultSignal:
|
||||
if node_cached == false {
|
||||
cache = NodeResult{
|
||||
NodeID: id,
|
||||
NodeType: response.NodeType,
|
||||
Data: response.Fields,
|
||||
}
|
||||
} else {
|
||||
for field_name, field_value := range(response.Fields) {
|
||||
cache.Data[field_name] = field_value
|
||||
}
|
||||
}
|
||||
|
||||
ctx.NodeCache[id] = cache
|
||||
return ctx.NodeCache[id], nil
|
||||
default:
|
||||
return NodeResult{}, fmt.Errorf("Bad read response: %+v", response)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
package graphvent
|
||||
import (
|
||||
"github.com/graphql-go/graphql"
|
||||
)
|
||||
|
||||
var GQLQuerySelf = &graphql.Field{
|
||||
Type: GQLTypeBaseThread.Type,
|
||||
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
||||
_, ctx, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ctx.Server, nil
|
||||
},
|
||||
}
|
||||
|
||||
var GQLQueryUser = &graphql.Field{
|
||||
Type: GQLTypeBaseNode.Type,
|
||||
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
||||
_, ctx, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ctx.User, nil
|
||||
},
|
||||
}
|
@ -0,0 +1,338 @@
|
||||
package graphvent
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"github.com/graphql-go/graphql"
|
||||
)
|
||||
|
||||
func PrepResolve(p graphql.ResolveParams) (*Node, *ResolveContext, error) {
|
||||
resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
|
||||
if ok == false {
|
||||
return nil, nil, fmt.Errorf("Bad resolve in params context")
|
||||
}
|
||||
|
||||
node, ok := p.Source.(*Node)
|
||||
if ok == false {
|
||||
return nil, nil, fmt.Errorf("Source is not a *Node in PrepResolve")
|
||||
}
|
||||
|
||||
return node, resolve_context, nil
|
||||
}
|
||||
|
||||
// TODO: Make composabe by checkinf if K is a slice, then recursing in the same way that ExtractList does
|
||||
func ExtractParam[K interface{}](p graphql.ResolveParams, name string) (K, error) {
|
||||
var zero K
|
||||
arg_if, ok := p.Args[name]
|
||||
if ok == false {
|
||||
return zero, fmt.Errorf("No Arg of name %s", name)
|
||||
}
|
||||
|
||||
arg, ok := arg_if.(K)
|
||||
if ok == false {
|
||||
return zero, fmt.Errorf("Failed to cast arg %s(%+v) to %+v", name, arg_if, reflect.TypeOf(zero))
|
||||
}
|
||||
|
||||
return arg, nil
|
||||
}
|
||||
|
||||
func ExtractList[K interface{}](p graphql.ResolveParams, name string) ([]K, error) {
|
||||
var zero K
|
||||
|
||||
arg_list, err := ExtractParam[[]interface{}](p, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := make([]K, len(arg_list))
|
||||
for i, val := range(arg_list) {
|
||||
val_conv, ok := arg_list[i].(K)
|
||||
if ok == false {
|
||||
return nil, fmt.Errorf("Failed to cast arg %s[%d](%+v) to %+v", name, i, val, reflect.TypeOf(zero))
|
||||
}
|
||||
ret[i] = val_conv
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) {
|
||||
id_str, err := ExtractParam[string](p, name)
|
||||
if err != nil {
|
||||
return ZeroID, err
|
||||
}
|
||||
|
||||
id, err := ParseID(id_str)
|
||||
if err != nil {
|
||||
return ZeroID, err
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// TODO: think about what permissions should be needed to read ID, and if there's ever a case where they haven't already been granted
|
||||
func GQLNodeID(p graphql.ResolveParams) (interface{}, error) {
|
||||
node, _, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.ID, nil
|
||||
}
|
||||
|
||||
func GQLNodeTypeHash(p graphql.ResolveParams) (interface{}, error) {
|
||||
node, _, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return string(node.Type), nil
|
||||
}
|
||||
|
||||
func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) {
|
||||
node, ctx, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gql_ext, err := GetExt[*GQLExt](node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
listen := ""
|
||||
context := NewReadContext(ctx.Context)
|
||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"listen"}), func(context *StateContext) error {
|
||||
listen = gql_ext.Listen
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return listen, nil
|
||||
}
|
||||
|
||||
func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) {
|
||||
node, ctx, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
thread_ext, err := GetExt[*ThreadExt](node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var parent *Node = nil
|
||||
context := NewReadContext(ctx.Context)
|
||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error {
|
||||
parent = thread_ext.Parent
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return parent, nil
|
||||
}
|
||||
|
||||
func GQLThreadState(p graphql.ResolveParams) (interface{}, error) {
|
||||
node, ctx, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
thread_ext, err := GetExt[*ThreadExt](node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var state string
|
||||
context := NewReadContext(ctx.Context)
|
||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"state"}), func(context *StateContext) error {
|
||||
state = thread_ext.State
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) {
|
||||
node, ctx, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
thread_ext, err := GetExt[*ThreadExt](node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var children []*Node = nil
|
||||
context := NewReadContext(ctx.Context)
|
||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"children"}), func(context *StateContext) error {
|
||||
children = thread_ext.ChildList()
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return children, nil
|
||||
}
|
||||
|
||||
func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) {
|
||||
node, ctx, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lockable_ext, err := GetExt[*LockableExt](node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var requirements []*Node = nil
|
||||
context := NewReadContext(ctx.Context)
|
||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error {
|
||||
requirements = make([]*Node, len(lockable_ext.Requirements))
|
||||
i := 0
|
||||
for _, req := range(lockable_ext.Requirements) {
|
||||
requirements[i] = req
|
||||
i += 1
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return requirements, nil
|
||||
}
|
||||
|
||||
func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) {
|
||||
node, ctx, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lockable_ext, err := GetExt[*LockableExt](node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var dependencies []*Node = nil
|
||||
context := NewReadContext(ctx.Context)
|
||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"dependencies"}), func(context *StateContext) error {
|
||||
dependencies = make([]*Node, len(lockable_ext.Dependencies))
|
||||
i := 0
|
||||
for _, dep := range(lockable_ext.Dependencies) {
|
||||
dependencies[i] = dep
|
||||
i += 1
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dependencies, nil
|
||||
}
|
||||
|
||||
func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) {
|
||||
node, ctx, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lockable_ext, err := GetExt[*LockableExt](node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var owner *Node = nil
|
||||
context := NewReadContext(ctx.Context)
|
||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"owner"}), func(context *StateContext) error {
|
||||
owner = lockable_ext.Owner
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return owner, nil
|
||||
}
|
||||
|
||||
func GQLGroupMembers(p graphql.ResolveParams) (interface{}, error) {
|
||||
node, ctx, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
group_ext, err := GetExt[*GroupExt](node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var members []*Node
|
||||
context := NewReadContext(ctx.Context)
|
||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"users"}), func(context *StateContext) error {
|
||||
members = make([]*Node, len(group_ext.Members))
|
||||
i := 0
|
||||
for _, member := range(group_ext.Members) {
|
||||
members[i] = member
|
||||
i += 1
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return members, nil
|
||||
}
|
||||
|
||||
func GQLSignalFn(p graphql.ResolveParams, fn func(Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) {
|
||||
if signal, ok := p.Source.(Signal); ok {
|
||||
return fn(signal, p)
|
||||
}
|
||||
return nil, fmt.Errorf("Failed to cast source to event")
|
||||
}
|
||||
|
||||
func GQLSignalType(p graphql.ResolveParams) (interface{}, error) {
|
||||
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
|
||||
return signal.Type(), nil
|
||||
})
|
||||
}
|
||||
|
||||
func GQLSignalDirection(p graphql.ResolveParams) (interface{}, error) {
|
||||
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
|
||||
direction := signal.Direction()
|
||||
if direction == Up {
|
||||
return "up", nil
|
||||
} else if direction == Down {
|
||||
return "down", nil
|
||||
} else if direction == Direct {
|
||||
return "direct", nil
|
||||
}
|
||||
return nil, fmt.Errorf("Invalid direction: %+v", direction)
|
||||
})
|
||||
}
|
||||
|
||||
func GQLSignalString(p graphql.ResolveParams) (interface{}, error) {
|
||||
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
|
||||
ser, err := signal.Serialize()
|
||||
return string(ser), err
|
||||
})
|
||||
}
|
@ -0,0 +1,69 @@
|
||||
package graphvent
|
||||
import (
|
||||
"github.com/graphql-go/graphql"
|
||||
)
|
||||
|
||||
func GQLSubscribeSignal(p graphql.ResolveParams) (interface{}, error) {
|
||||
return GQLSubscribeFn(p, false, func(ctx *Context, server *Node, ext *GQLExt, signal Signal, p graphql.ResolveParams)(interface{}, error) {
|
||||
return signal, nil
|
||||
})
|
||||
}
|
||||
|
||||
func GQLSubscribeSelf(p graphql.ResolveParams) (interface{}, error) {
|
||||
return GQLSubscribeFn(p, true, func(ctx *Context, server *Node, ext *GQLExt, signal Signal, p graphql.ResolveParams)(interface{}, error) {
|
||||
return server, nil
|
||||
})
|
||||
}
|
||||
|
||||
func GQLSubscribeFn(p graphql.ResolveParams, send_nil bool, fn func(*Context, *Node, *GQLExt, Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) {
|
||||
_, ctx, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := make(chan interface{})
|
||||
go func(c chan interface{}, ext *GQLExt, server *Node) {
|
||||
ctx.Context.Log.Logf("gqlws", "GQL_SUBSCRIBE_THREAD_START")
|
||||
sig_c := ext.NewSubscriptionChannel(1)
|
||||
if send_nil == true {
|
||||
sig_c <- nil
|
||||
}
|
||||
for {
|
||||
val, ok := <- sig_c
|
||||
if ok == false {
|
||||
return
|
||||
}
|
||||
ret, err := fn(ctx.Context, server, ext, val, p)
|
||||
if err != nil {
|
||||
ctx.Context.Log.Logf("gqlws", "type convertor error %s", err)
|
||||
return
|
||||
}
|
||||
c <- ret
|
||||
}
|
||||
}(c, ctx.Ext, ctx.Server)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
var GQLSubscriptionSelf = NewField(func()*graphql.Field{
|
||||
gql_subscription_self := &graphql.Field{
|
||||
Type: GQLTypeBaseThread.Type,
|
||||
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
||||
return p.Source, nil
|
||||
},
|
||||
Subscribe: GQLSubscribeSelf,
|
||||
}
|
||||
|
||||
return gql_subscription_self
|
||||
})
|
||||
|
||||
var GQLSubscriptionUpdate = NewField(func()*graphql.Field{
|
||||
gql_subscription_update := &graphql.Field{
|
||||
Type: GQLTypeSignal.Type,
|
||||
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
||||
return p.Source, nil
|
||||
},
|
||||
Subscribe: GQLSubscribeSignal,
|
||||
}
|
||||
return gql_subscription_update
|
||||
})
|
||||
|
@ -1,223 +1,152 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/net/websocket"
|
||||
"testing"
|
||||
"time"
|
||||
"errors"
|
||||
"crypto/rand"
|
||||
"crypto/ecdh"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
)
|
||||
|
||||
func TestGQLSubscribe(t *testing.T) {
|
||||
ctx := logTestContext(t, []string{"test", "gql"})
|
||||
func TestGQL(t * testing.T) {
|
||||
ctx := logTestContext(t, []string{"test", "db"})
|
||||
|
||||
n1, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil))
|
||||
TestUserNodeType := NodeType("TEST_USER")
|
||||
err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{ACLExtType, ACLPolicyExtType})
|
||||
fatalErr(t, err)
|
||||
|
||||
listener_ext := NewListenerExt(10)
|
||||
u1 := NewNode(ctx, RandID(), TestUserNodeType)
|
||||
u1_policy := NewPerNodePolicy(NodeActions{
|
||||
u1.ID: Actions{"users.write", "children.write", "parent.write", "dependencies.write", "requirements.write"},
|
||||
})
|
||||
u1.Extensions[ACLExtType] = NewACLExt(nil)
|
||||
u1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
|
||||
PerNodePolicyType: &u1_policy,
|
||||
})
|
||||
|
||||
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
|
||||
fatalErr(t, err)
|
||||
ctx.Log.Logf("test", "U1_ID: %s", u1.ID)
|
||||
|
||||
gql, err := ctx.NewNode(nil, "LockableNode", NewLockableExt([]NodeID{n1.ID}), gql_ext, listener_ext)
|
||||
ListenerNodeType := NodeType("LISTENER")
|
||||
err = ctx.RegisterNodeType(ListenerNodeType, []ExtType{ACLExtType, ACLPolicyExtType, ListenerExtType, LockableExtType})
|
||||
fatalErr(t, err)
|
||||
|
||||
query := "subscription { Self { ID, Type ... on Lockable { LockableState } } }"
|
||||
|
||||
ctx.Log.Logf("test", "GQL: %s", gql.ID)
|
||||
ctx.Log.Logf("test", "Node: %s", n1.ID)
|
||||
ctx.Log.Logf("test", "Query: %s", query)
|
||||
|
||||
sub_1 := GQLPayload{
|
||||
Query: query,
|
||||
}
|
||||
|
||||
port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port
|
||||
url := fmt.Sprintf("http://localhost:%d/gql", port)
|
||||
ws_url := fmt.Sprintf("ws://127.0.0.1:%d/gqlws", port)
|
||||
|
||||
SubGQL := func(payload GQLPayload) {
|
||||
config, err := websocket.NewConfig(ws_url, url)
|
||||
fatalErr(t, err)
|
||||
config.Protocol = append(config.Protocol, "graphql-ws")
|
||||
config.TlsConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
ws, err := websocket.DialConfig(config)
|
||||
|
||||
fatalErr(t, err)
|
||||
|
||||
type payload_struct struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
init := struct{
|
||||
ID uuid.UUID `json:"id"`
|
||||
Type string `json:"type"`
|
||||
}{
|
||||
uuid.New(),
|
||||
"connection_init",
|
||||
}
|
||||
|
||||
ser, err := json.Marshal(&init)
|
||||
fatalErr(t, err)
|
||||
|
||||
_, err = ws.Write(ser)
|
||||
fatalErr(t, err)
|
||||
l1 := NewNode(ctx, RandID(), ListenerNodeType)
|
||||
l1_policy := NewRequirementOfPolicy(NodeActions{
|
||||
l1.ID: Actions{"signal.status"},
|
||||
})
|
||||
|
||||
resp := make([]byte, 1024)
|
||||
n, err := ws.Read(resp)
|
||||
l1.Extensions[ACLExtType] = NewACLExt(NodeList(u1))
|
||||
listener_ext := NewListenerExt(10)
|
||||
l1.Extensions[ListenerExtType] = listener_ext
|
||||
l1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
|
||||
RequirementOfPolicyType: &l1_policy,
|
||||
})
|
||||
l1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil)
|
||||
|
||||
var init_resp GQLWSMsg
|
||||
err = json.Unmarshal(resp[:n], &init_resp)
|
||||
fatalErr(t, err)
|
||||
ctx.Log.Logf("test", "L1_ID: %s", l1.ID)
|
||||
|
||||
if init_resp.Type != "connection_ack" {
|
||||
t.Fatal("Didn't receive connection_ack")
|
||||
}
|
||||
TestThreadNodeType := NodeType("TEST_THREAD")
|
||||
err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLExtType, ACLPolicyExtType, ThreadExtType, LockableExtType})
|
||||
fatalErr(t, err)
|
||||
|
||||
sub := GQLWSMsg{
|
||||
ID: uuid.New().String(),
|
||||
Type: "subscribe",
|
||||
Payload: sub_1,
|
||||
}
|
||||
t1 := NewNode(ctx, RandID(), TestThreadNodeType)
|
||||
t1_policy := NewParentOfPolicy(NodeActions{
|
||||
t1.ID: Actions{"signal.abort", "state.write"},
|
||||
})
|
||||
t1.Extensions[ACLExtType] = NewACLExt(NodeList(u1))
|
||||
t1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
|
||||
ParentOfPolicyType: &t1_policy,
|
||||
})
|
||||
t1.Extensions[ThreadExtType], err = NewThreadExt(ctx, BaseThreadType, nil, nil, "init", nil)
|
||||
fatalErr(t, err)
|
||||
t1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil)
|
||||
|
||||
ser, err = json.Marshal(&sub)
|
||||
fatalErr(t, err)
|
||||
_, err = ws.Write(ser)
|
||||
fatalErr(t, err)
|
||||
ctx.Log.Logf("test", "T1_ID: %s", t1.ID)
|
||||
|
||||
n, err = ws.Read(resp)
|
||||
fatalErr(t, err)
|
||||
ctx.Log.Logf("test", "SUB1: %s", resp[:n])
|
||||
TestGQLNodeType := NodeType("TEST_GQL")
|
||||
err = ctx.RegisterNodeType(TestGQLNodeType, []ExtType{ACLExtType, ACLPolicyExtType, GroupExtType, GQLExtType, ThreadExtType, LockableExtType})
|
||||
fatalErr(t, err)
|
||||
|
||||
lock_id, err := LockLockable(ctx, gql)
|
||||
fatalErr(t, err)
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
fatalErr(t, err)
|
||||
|
||||
response, _, err := WaitForResponse(listener_ext.Chan, 100*time.Millisecond, lock_id)
|
||||
fatalErr(t, err)
|
||||
gql := NewNode(ctx, RandID(), TestGQLNodeType)
|
||||
gql_policy := NewChildOfPolicy(NodeActions{
|
||||
gql.ID: Actions{"signal.status"},
|
||||
})
|
||||
gql.Extensions[ACLExtType] = NewACLExt(NodeList(u1))
|
||||
gql.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
|
||||
ChildOfPolicyType: &gql_policy,
|
||||
})
|
||||
gql.Extensions[GroupExtType] = NewGroupExt(nil)
|
||||
gql.Extensions[GQLExtType] = NewGQLExt(":0", ecdh.P256(), key, nil, nil)
|
||||
gql.Extensions[ThreadExtType], err = NewThreadExt(ctx, GQLThreadType, nil, nil, "ini", nil)
|
||||
fatalErr(t, err)
|
||||
gql.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil)
|
||||
|
||||
switch response.(type) {
|
||||
case *SuccessSignal:
|
||||
ctx.Log.Logf("test", "Locked %s", gql.ID)
|
||||
default:
|
||||
t.Errorf("Unexpected lock response: %s", response)
|
||||
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
|
||||
info := ParentInfo{true, "start", "restore"}
|
||||
context := NewWriteContext(ctx)
|
||||
err = UpdateStates(context, u1, NewACLInfo(gql, []string{"users"}), func(context *StateContext) error {
|
||||
err := LinkThreads(context, u1, gql, ChildInfo{t1, map[InfoType]Info{
|
||||
ParentInfoType: &info,
|
||||
}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n, err = ws.Read(resp)
|
||||
fatalErr(t, err)
|
||||
ctx.Log.Logf("test", "SUB2: %s", resp[:n])
|
||||
|
||||
n, err = ws.Read(resp)
|
||||
fatalErr(t, err)
|
||||
ctx.Log.Logf("test", "SUB3: %s", resp[:n])
|
||||
|
||||
// TODO: check that there are no more messages sent to ws within a timeout
|
||||
}
|
||||
|
||||
SubGQL(sub_1)
|
||||
}
|
||||
|
||||
func TestGQLQuery(t *testing.T) {
|
||||
ctx := logTestContext(t, []string{"test", "lockable"})
|
||||
|
||||
n1_listener := NewListenerExt(10)
|
||||
n1, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil), n1_listener)
|
||||
return LinkLockables(context, u1, l1, []*Node{gql})
|
||||
})
|
||||
fatalErr(t, err)
|
||||
|
||||
gql_listener := NewListenerExt(10)
|
||||
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
|
||||
context = NewReadContext(ctx)
|
||||
err = SendSignal(context, gql, gql, NewStatusSignal("child_linked", t1.ID))
|
||||
fatalErr(t, err)
|
||||
|
||||
gql, err := ctx.NewNode(nil, "LockableNode", NewLockableExt([]NodeID{n1.ID}), gql_ext, gql_listener)
|
||||
context = NewReadContext(ctx)
|
||||
err = SendSignal(context, gql, gql, AbortSignal)
|
||||
fatalErr(t, err)
|
||||
|
||||
ctx.Log.Logf("test", "GQL: %s", gql.ID)
|
||||
ctx.Log.Logf("test", "NODE: %s", n1.ID)
|
||||
|
||||
skipVerifyTransport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
client := &http.Client{Transport: skipVerifyTransport}
|
||||
port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port
|
||||
url := fmt.Sprintf("http://localhost:%d/gql", port)
|
||||
|
||||
req_1 := GQLPayload{
|
||||
Query: "query Node($id:graphvent_NodeID) { Node(id:$id) { ID, Type, ... on Lockable { LockableState } } }",
|
||||
Variables: map[string]interface{}{
|
||||
"id": n1.ID.String(),
|
||||
},
|
||||
}
|
||||
|
||||
req_2 := GQLPayload{
|
||||
Query: "query Self { Self { ID, Type, ... on Lockable { LockableState, Requirements { Key { ID ... on Lockable { LockableState } } } } } }",
|
||||
}
|
||||
|
||||
SendGQL := func(payload GQLPayload) []byte {
|
||||
ser, err := json.MarshalIndent(&payload, "", " ")
|
||||
fatalErr(t, err)
|
||||
|
||||
req_data := bytes.NewBuffer(ser)
|
||||
req, err := http.NewRequest("GET", url, req_data)
|
||||
fatalErr(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
fatalErr(t, err)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
err = ThreadLoop(ctx, gql, "start")
|
||||
if errors.Is(err, ThreadAbortedError) == false {
|
||||
fatalErr(t, err)
|
||||
|
||||
resp.Body.Close()
|
||||
return body
|
||||
}
|
||||
|
||||
resp_1 := SendGQL(req_1)
|
||||
ctx.Log.Logf("test", "RESP_1: %s", resp_1)
|
||||
resp_2 := SendGQL(req_2)
|
||||
ctx.Log.Logf("test", "RESP_2: %s", resp_2)
|
||||
|
||||
lock_id, err := LockLockable(ctx, n1)
|
||||
fatalErr(t, err)
|
||||
|
||||
response, _, err := WaitForResponse(n1_listener.Chan, 100*time.Millisecond, lock_id)
|
||||
fatalErr(t, err)
|
||||
switch response := response.(type) {
|
||||
case *SuccessSignal:
|
||||
default:
|
||||
t.Fatalf("Wrong response: %s", reflect.TypeOf(response))
|
||||
}
|
||||
|
||||
resp_3 := SendGQL(req_1)
|
||||
ctx.Log.Logf("test", "RESP_3: %s", resp_3)
|
||||
(*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "aborted", 100*time.Millisecond, "Didn't receive aborted on listener")
|
||||
|
||||
resp_4 := SendGQL(req_2)
|
||||
ctx.Log.Logf("test", "RESP_4: %s", resp_4)
|
||||
}
|
||||
context = NewReadContext(ctx)
|
||||
err = UseStates(context, gql, ACLList([]*Node{gql, u1}, nil), func(context *StateContext) error {
|
||||
ser1, err := gql.Serialize()
|
||||
ser2, err := u1.Serialize()
|
||||
ctx.Log.Logf("test", "\n%s\n\n", ser1)
|
||||
ctx.Log.Logf("test", "\n%s\n\n", ser2)
|
||||
return err
|
||||
})
|
||||
|
||||
func TestGQLDB(t *testing.T) {
|
||||
ctx := logTestContext(t, []string{"test", "db", "node", "serialize"})
|
||||
|
||||
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
|
||||
fatalErr(t, err)
|
||||
listener_ext := NewListenerExt(10)
|
||||
|
||||
gql, err := ctx.NewNode(nil, "Node", gql_ext, listener_ext)
|
||||
fatalErr(t, err)
|
||||
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
|
||||
|
||||
err = ctx.Stop()
|
||||
// Clear all loaded nodes from the context so it loads them from the database
|
||||
ctx.Nodes = NodeMap{}
|
||||
gql_loaded, err := LoadNode(ctx, gql.ID)
|
||||
fatalErr(t, err)
|
||||
context = NewReadContext(ctx)
|
||||
err = UseStates(context, gql_loaded, NewACLInfo(gql_loaded, []string{"users", "children", "requirements"}), func(context *StateContext) error {
|
||||
ser, err := gql_loaded.Serialize()
|
||||
lockable_ext, err := GetExt[*LockableExt](gql_loaded)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Log.Logf("test", "\n%s\n\n", ser)
|
||||
dependency := lockable_ext.Dependencies[l1.ID]
|
||||
listener_ext, err = GetExt[*ListenerExt](dependency)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
SendSignal(context, gql_loaded, gql_loaded, StopSignal)
|
||||
return err
|
||||
})
|
||||
|
||||
gql_loaded, err := ctx.GetNode(gql.ID)
|
||||
err = ThreadLoop(ctx, gql_loaded, "start")
|
||||
fatalErr(t, err)
|
||||
(*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2")
|
||||
|
||||
listener_ext, err = GetExt[ListenerExt](gql_loaded)
|
||||
fatalErr(t, err)
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,148 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"github.com/graphql-go/graphql"
|
||||
)
|
||||
|
||||
func AddNodeFields(obj *graphql.Object) {
|
||||
obj.AddFieldConfig("ID", &graphql.Field{
|
||||
Type: graphql.String,
|
||||
Resolve: GQLNodeID,
|
||||
})
|
||||
|
||||
obj.AddFieldConfig("TypeHash", &graphql.Field{
|
||||
Type: graphql.String,
|
||||
Resolve: GQLNodeTypeHash,
|
||||
})
|
||||
}
|
||||
|
||||
func AddLockableFields(obj *graphql.Object) {
|
||||
AddNodeFields(obj)
|
||||
|
||||
obj.AddFieldConfig("Requirements", &graphql.Field{
|
||||
Type: GQLInterfaceLockable.List,
|
||||
Resolve: GQLLockableRequirements,
|
||||
})
|
||||
|
||||
obj.AddFieldConfig("Owner", &graphql.Field{
|
||||
Type: GQLInterfaceLockable.Type,
|
||||
Resolve: GQLLockableOwner,
|
||||
})
|
||||
|
||||
obj.AddFieldConfig("Dependencies", &graphql.Field{
|
||||
Type: GQLInterfaceLockable.List,
|
||||
Resolve: GQLLockableDependencies,
|
||||
})
|
||||
}
|
||||
|
||||
func AddThreadFields(obj *graphql.Object) {
|
||||
AddNodeFields(obj)
|
||||
|
||||
obj.AddFieldConfig("State", &graphql.Field{
|
||||
Type: graphql.String,
|
||||
Resolve: GQLThreadState,
|
||||
})
|
||||
|
||||
obj.AddFieldConfig("Children", &graphql.Field{
|
||||
Type: GQLInterfaceThread.List,
|
||||
Resolve: GQLThreadChildren,
|
||||
})
|
||||
|
||||
obj.AddFieldConfig("Parent", &graphql.Field{
|
||||
Type: GQLInterfaceThread.Type,
|
||||
Resolve: GQLThreadParent,
|
||||
})
|
||||
}
|
||||
|
||||
var GQLTypeBaseThread = NewSingleton(func() *graphql.Object {
|
||||
gql_type_simple_thread := graphql.NewObject(graphql.ObjectConfig{
|
||||
Name: "SimpleThread",
|
||||
Interfaces: []*graphql.Interface{
|
||||
GQLInterfaceNode.Type,
|
||||
GQLInterfaceThread.Type,
|
||||
GQLInterfaceLockable.Type,
|
||||
},
|
||||
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
|
||||
node, ok := p.Value.(*Node)
|
||||
if ok == false {
|
||||
return false
|
||||
}
|
||||
|
||||
_, err := GetExt[*ThreadExt](node)
|
||||
return err == nil
|
||||
},
|
||||
Fields: graphql.Fields{},
|
||||
})
|
||||
|
||||
AddThreadFields(gql_type_simple_thread)
|
||||
|
||||
return gql_type_simple_thread
|
||||
}, nil)
|
||||
|
||||
var GQLTypeBaseLockable = NewSingleton(func() *graphql.Object {
|
||||
gql_type_simple_lockable := graphql.NewObject(graphql.ObjectConfig{
|
||||
Name: "SimpleLockable",
|
||||
Interfaces: []*graphql.Interface{
|
||||
GQLInterfaceNode.Type,
|
||||
GQLInterfaceLockable.Type,
|
||||
},
|
||||
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
|
||||
node, ok := p.Value.(*Node)
|
||||
if ok == false {
|
||||
return false
|
||||
}
|
||||
|
||||
_, err := GetExt[*LockableExt](node)
|
||||
return err == nil
|
||||
},
|
||||
Fields: graphql.Fields{},
|
||||
})
|
||||
|
||||
AddLockableFields(gql_type_simple_lockable)
|
||||
|
||||
return gql_type_simple_lockable
|
||||
}, nil)
|
||||
|
||||
var GQLTypeBaseNode = NewSingleton(func() *graphql.Object {
|
||||
object := graphql.NewObject(graphql.ObjectConfig{
|
||||
Name: "SimpleNode",
|
||||
Interfaces: []*graphql.Interface{
|
||||
GQLInterfaceNode.Type,
|
||||
},
|
||||
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
|
||||
_, ok := p.Value.(*Node)
|
||||
return ok
|
||||
},
|
||||
Fields: graphql.Fields{},
|
||||
})
|
||||
|
||||
AddNodeFields(object)
|
||||
|
||||
return object
|
||||
}, nil)
|
||||
|
||||
var GQLTypeSignal = NewSingleton(func() *graphql.Object {
|
||||
gql_type_signal := graphql.NewObject(graphql.ObjectConfig{
|
||||
Name: "Signal",
|
||||
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
|
||||
_, ok := p.Value.(Signal)
|
||||
return ok
|
||||
},
|
||||
Fields: graphql.Fields{},
|
||||
})
|
||||
|
||||
gql_type_signal.AddFieldConfig("Type", &graphql.Field{
|
||||
Type: graphql.String,
|
||||
Resolve: GQLSignalType,
|
||||
})
|
||||
gql_type_signal.AddFieldConfig("Direction", &graphql.Field{
|
||||
Type: graphql.String,
|
||||
Resolve: GQLSignalDirection,
|
||||
})
|
||||
gql_type_signal.AddFieldConfig("String", &graphql.Field{
|
||||
Type: graphql.String,
|
||||
Resolve: GQLSignalString,
|
||||
})
|
||||
return gql_type_signal
|
||||
}, nil)
|
||||
|
@ -1,66 +0,0 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// A Listener extension provides a channel that can receive signals on a different thread
|
||||
type ListenerExt struct {
|
||||
Buffer int `gv:"buffer"`
|
||||
Chan chan Signal
|
||||
}
|
||||
|
||||
type LoadedSignal struct {
|
||||
SignalHeader
|
||||
}
|
||||
|
||||
func NewLoadedSignal() *LoadedSignal {
|
||||
return &LoadedSignal{
|
||||
SignalHeader: NewSignalHeader(),
|
||||
}
|
||||
}
|
||||
|
||||
type UnloadedSignal struct {
|
||||
SignalHeader
|
||||
}
|
||||
|
||||
func NewUnloadedSignal() *UnloadedSignal {
|
||||
return &UnloadedSignal{
|
||||
SignalHeader: NewSignalHeader(),
|
||||
}
|
||||
}
|
||||
|
||||
func (ext *ListenerExt) Load(ctx *Context, node *Node) error {
|
||||
ext.Chan = make(chan Signal, ext.Buffer)
|
||||
ext.Chan <- NewLoadedSignal()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ext *ListenerExt) Unload(ctx *Context, node *Node) {
|
||||
ext.Chan <- NewUnloadedSignal()
|
||||
close(ext.Chan)
|
||||
}
|
||||
|
||||
// Create a new listener extension with a given buffer size
|
||||
func NewListenerExt(buffer int) *ListenerExt {
|
||||
return &ListenerExt{
|
||||
Buffer: buffer,
|
||||
Chan: make(chan Signal, buffer),
|
||||
}
|
||||
}
|
||||
|
||||
// Send the signal to the channel, logging an overflow if it occurs
|
||||
func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]Message, Changes) {
|
||||
ctx.Log.Logf("listener", "%s - %+v", node.ID, reflect.TypeOf(signal))
|
||||
ctx.Log.Logf("listener_debug", "%s->%s - %+v", source, node.ID, signal)
|
||||
select {
|
||||
case ext.Chan <- signal:
|
||||
default:
|
||||
ctx.Log.Logf("listener", "LISTENER_OVERFLOW: %s", node.ID)
|
||||
}
|
||||
switch sig := signal.(type) {
|
||||
case *StatusSignal:
|
||||
ctx.Log.Logf("listener_status", "%s - %+v", sig.Source, sig.Fields)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
@ -1,411 +1,596 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
"fmt"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type ReqState byte
|
||||
const (
|
||||
Unlocked = ReqState(0)
|
||||
Unlocking = ReqState(1)
|
||||
Locked = ReqState(2)
|
||||
Locking = ReqState(3)
|
||||
AbortingLock = ReqState(4)
|
||||
)
|
||||
type ListenerExt struct {
|
||||
Buffer int
|
||||
Chan chan Signal
|
||||
}
|
||||
|
||||
func NewListenerExt(buffer int) *ListenerExt {
|
||||
return &ListenerExt{
|
||||
Buffer: buffer,
|
||||
Chan: make(chan Signal, buffer),
|
||||
}
|
||||
}
|
||||
|
||||
func LoadListenerExt(ctx *Context, data []byte) (Extension, error) {
|
||||
var j int
|
||||
err := json.Unmarshal(data, &j)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewListenerExt(j), nil
|
||||
}
|
||||
|
||||
var ReqStateStrings = map[ReqState]string {
|
||||
Unlocked: "Unlocked",
|
||||
Unlocking: "Unlocking",
|
||||
Locked: "Locked",
|
||||
Locking: "Locking",
|
||||
AbortingLock: "AbortingLock",
|
||||
const ListenerExtType = ExtType("LISTENER")
|
||||
func (listener *ListenerExt) Type() ExtType {
|
||||
return ListenerExtType
|
||||
}
|
||||
|
||||
func (state ReqState) String() string {
|
||||
str, mapped := ReqStateStrings[state]
|
||||
if mapped == false {
|
||||
return "UNKNOWN_REQSTATE"
|
||||
} else {
|
||||
return str
|
||||
func (ext *ListenerExt) Process(context *StateContext, node *Node, signal Signal) error {
|
||||
select {
|
||||
case ext.Chan <- signal:
|
||||
default:
|
||||
return fmt.Errorf("LISTENER_OVERFLOW - %+v", signal)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ext *ListenerExt) Serialize() ([]byte, error) {
|
||||
return json.MarshalIndent(ext.Buffer, "", " ")
|
||||
}
|
||||
|
||||
type LockableExt struct{
|
||||
State ReqState `gv:"state"`
|
||||
ReqID *uuid.UUID `gv:"req_id"`
|
||||
Owner *NodeID `gv:"owner"`
|
||||
PendingOwner *NodeID `gv:"pending_owner"`
|
||||
Requirements map[NodeID]ReqState `gv:"requirements" node:"Lockable:"`
|
||||
type LockableExt struct {
|
||||
Owner *Node
|
||||
Requirements map[NodeID]*Node
|
||||
Dependencies map[NodeID]*Node
|
||||
LocksHeld map[NodeID]*Node
|
||||
}
|
||||
|
||||
Locked map[NodeID]any
|
||||
Unlocked map[NodeID]any
|
||||
const LockableExtType = ExtType("LOCKABLE")
|
||||
func (ext *LockableExt) Type() ExtType {
|
||||
return LockableExtType
|
||||
}
|
||||
|
||||
Waiting WaitMap `gv:"waiting_locks" node:":Lockable"`
|
||||
type LockableExtJSON struct {
|
||||
Owner string `json:"owner"`
|
||||
Requirements []string `json:"requirements"`
|
||||
Dependencies []string `json:"dependencies"`
|
||||
LocksHeld map[string]string `json:"locks_held"`
|
||||
}
|
||||
|
||||
func NewLockableExt(requirements []NodeID) *LockableExt {
|
||||
var reqs map[NodeID]ReqState = nil
|
||||
var unlocked map[NodeID]any = map[NodeID]any{}
|
||||
func (ext *LockableExt) Serialize() ([]byte, error) {
|
||||
return json.MarshalIndent(&LockableExtJSON{
|
||||
Owner: SaveNode(ext.Owner),
|
||||
Requirements: SaveNodeList(ext.Requirements),
|
||||
Dependencies: SaveNodeList(ext.Dependencies),
|
||||
LocksHeld: SaveNodeMap(ext.LocksHeld),
|
||||
}, "", " ")
|
||||
}
|
||||
|
||||
if len(requirements) != 0 {
|
||||
reqs = map[NodeID]ReqState{}
|
||||
for _, req := range(requirements) {
|
||||
reqs[req] = Unlocked
|
||||
unlocked[req] = nil
|
||||
}
|
||||
func NewLockableExt(owner *Node, requirements NodeMap, dependencies NodeMap, locks_held NodeMap) *LockableExt {
|
||||
if requirements == nil {
|
||||
requirements = NodeMap{}
|
||||
}
|
||||
|
||||
return &LockableExt{
|
||||
State: Unlocked,
|
||||
Owner: nil,
|
||||
PendingOwner: nil,
|
||||
Requirements: reqs,
|
||||
Waiting: WaitMap{},
|
||||
if dependencies == nil {
|
||||
dependencies = NodeMap{}
|
||||
}
|
||||
|
||||
Locked: map[NodeID]any{},
|
||||
Unlocked: unlocked,
|
||||
if locks_held == nil {
|
||||
locks_held = NodeMap{}
|
||||
}
|
||||
}
|
||||
|
||||
func UnlockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
|
||||
signal := NewUnlockSignal()
|
||||
messages := []Message{{node.ID, signal}}
|
||||
return signal.ID(), ctx.Send(node, messages)
|
||||
return &LockableExt{
|
||||
Owner: owner,
|
||||
Requirements: requirements,
|
||||
Dependencies: dependencies,
|
||||
LocksHeld: locks_held,
|
||||
}
|
||||
}
|
||||
|
||||
func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
|
||||
signal := NewLockSignal()
|
||||
messages := []Message{{node.ID, signal}}
|
||||
return signal.ID(), ctx.Send(node, messages)
|
||||
}
|
||||
func LoadLockableExt(ctx *Context, data []byte) (Extension, error) {
|
||||
var j LockableExtJSON
|
||||
err := json.Unmarshal(data, &j)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (ext *LockableExt) Load(ctx *Context, node *Node) error {
|
||||
ext.Locked = map[NodeID]any{}
|
||||
ext.Unlocked = map[NodeID]any{}
|
||||
ctx.Log.Logf("db", "DB_LOADING_LOCKABLE_EXT_JSON: %+v", j)
|
||||
|
||||
for id, state := range(ext.Requirements) {
|
||||
if state == Unlocked {
|
||||
ext.Unlocked[id] = nil
|
||||
} else if state == Locked {
|
||||
ext.Locked[id] = nil
|
||||
}
|
||||
owner, err := RestoreNode(ctx, j.Owner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ext *LockableExt) Unload(ctx *Context, node *Node) {
|
||||
return
|
||||
requirements, err := RestoreNodeList(ctx, j.Requirements)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dependencies, err := RestoreNodeList(ctx, j.Dependencies)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
locks_held, err := RestoreNodeMap(ctx, j.LocksHeld)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
return NewLockableExt(owner, requirements, dependencies, locks_held), nil
|
||||
}
|
||||
|
||||
// Handle link signal by adding/removing the requested NodeID
|
||||
// returns an error if the node is not unlocked
|
||||
func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID, signal *LinkSignal) ([]Message, Changes) {
|
||||
var messages []Message = nil
|
||||
var changes Changes = nil
|
||||
|
||||
switch ext.State {
|
||||
case Unlocked:
|
||||
switch signal.Action {
|
||||
case "add":
|
||||
_, exists := ext.Requirements[signal.NodeID]
|
||||
if exists == true {
|
||||
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_requirement")})
|
||||
} else {
|
||||
if ext.Requirements == nil {
|
||||
ext.Requirements = map[NodeID]ReqState{}
|
||||
func (ext *LockableExt) Process(context *StateContext, node *Node, signal Signal) error {
|
||||
context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID)
|
||||
|
||||
var err error
|
||||
switch signal.Direction() {
|
||||
case Up:
|
||||
err = UseStates(context, node,
|
||||
NewACLInfo(node, []string{"dependencies", "owner"}), func(context *StateContext) error {
|
||||
owner_sent := false
|
||||
for _, dependency := range(ext.Dependencies) {
|
||||
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", node.ID, dependency.ID)
|
||||
SendSignal(context, dependency, node, signal)
|
||||
if ext.Owner != nil {
|
||||
if dependency.ID == ext.Owner.ID {
|
||||
owner_sent = true
|
||||
}
|
||||
}
|
||||
ext.Requirements[signal.NodeID] = Unlocked
|
||||
changes = append(changes, "requirements")
|
||||
messages = append(messages, Message{source, NewSuccessSignal(signal.ID())})
|
||||
}
|
||||
case "remove":
|
||||
_, exists := ext.Requirements[signal.NodeID]
|
||||
if exists == false {
|
||||
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")})
|
||||
} else {
|
||||
delete(ext.Requirements, signal.NodeID)
|
||||
changes = append(changes, "requirements")
|
||||
messages = append(messages, Message{source, NewSuccessSignal(signal.ID())})
|
||||
if ext.Owner != nil && owner_sent == false {
|
||||
if ext.Owner.ID != node.ID {
|
||||
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", node.ID, ext.Owner.ID)
|
||||
return SendSignal(context, ext.Owner, node, signal)
|
||||
}
|
||||
}
|
||||
default:
|
||||
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "unknown_action")})
|
||||
}
|
||||
return nil
|
||||
})
|
||||
case Down:
|
||||
err = UseStates(context, node, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error {
|
||||
for _, requirement := range(ext.Requirements) {
|
||||
err := SendSignal(context, requirement, node, signal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
case Direct:
|
||||
err = nil
|
||||
default:
|
||||
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_unlocked: %s", ext.State)})
|
||||
err = fmt.Errorf("invalid signal direction %d", signal.Direction())
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return messages, changes
|
||||
func (ext *LockableExt) RecordUnlock(node *Node) *Node {
|
||||
last_owner, exists := ext.LocksHeld[node.ID]
|
||||
if exists == false {
|
||||
panic("Attempted to take a get the original lock holder of a lockable we don't own")
|
||||
}
|
||||
delete(ext.LocksHeld, node.ID)
|
||||
return last_owner
|
||||
}
|
||||
|
||||
// Handle an UnlockSignal by either transitioning to Unlocked state,
|
||||
// sending unlock signals to requirements, or returning an error signal
|
||||
func (ext *LockableExt) HandleUnlockSignal(ctx *Context, node *Node, source NodeID, signal *UnlockSignal) ([]Message, Changes) {
|
||||
var messages []Message = nil
|
||||
var changes Changes = nil
|
||||
func (ext *LockableExt) RecordLock(node *Node, last_owner *Node) {
|
||||
_, exists := ext.LocksHeld[node.ID]
|
||||
if exists == true {
|
||||
panic("Attempted to lock a lockable we're already holding(lock cycle)")
|
||||
}
|
||||
ext.LocksHeld[node.ID] = last_owner
|
||||
}
|
||||
|
||||
switch ext.State {
|
||||
case Locked:
|
||||
if source != *ext.Owner {
|
||||
messages = append(messages, Message{source, NewErrorSignal(signal.Id, "not_owner")})
|
||||
} else {
|
||||
if len(ext.Requirements) == 0 {
|
||||
changes = append(changes, "state", "owner", "pending_owner")
|
||||
// Removes requirement as a requirement from lockable
|
||||
func UnlinkLockables(context *StateContext, princ *Node, lockable *Node, requirement *Node) error {
|
||||
lockable_ext, err := GetExt[*LockableExt](lockable)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
requirement_ext, err := GetExt[*LockableExt](requirement)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return UpdateStates(context, princ, ACLMap{
|
||||
lockable.ID: ACLInfo{Node: lockable, Resources: []string{"requirements"}},
|
||||
requirement.ID: ACLInfo{Node: requirement, Resources: []string{"dependencies"}},
|
||||
}, func(context *StateContext) error {
|
||||
var found *Node = nil
|
||||
for _, req := range(lockable_ext.Requirements) {
|
||||
if requirement.ID == req.ID {
|
||||
found = req
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
ext.Owner = nil
|
||||
if found == nil {
|
||||
return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID, lockable.ID)
|
||||
}
|
||||
|
||||
ext.PendingOwner = nil
|
||||
delete(requirement_ext.Dependencies, lockable.ID)
|
||||
delete(lockable_ext.Requirements, requirement.ID)
|
||||
|
||||
ext.State = Unlocked
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
messages = append(messages, Message{source, NewSuccessSignal(signal.Id)})
|
||||
} else {
|
||||
changes = append(changes, "state", "waiting", "requirements", "pending_owner")
|
||||
// Link requirements as requirements to lockable
|
||||
func LinkLockables(context *StateContext, princ *Node, lockable *Node, requirements []*Node) error {
|
||||
if lockable == nil {
|
||||
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements")
|
||||
}
|
||||
|
||||
ext.PendingOwner = nil
|
||||
if len(requirements) == 0 {
|
||||
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link no lockables in call")
|
||||
}
|
||||
|
||||
ext.ReqID = &signal.Id
|
||||
lockable_ext, err := GetExt[*LockableExt](lockable)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ext.State = Unlocking
|
||||
for id := range(ext.Requirements) {
|
||||
unlock_signal := NewUnlockSignal()
|
||||
req_exts := map[NodeID]*LockableExt{}
|
||||
for _, requirement := range(requirements) {
|
||||
if requirement == nil {
|
||||
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link nil to a Lockable as a requirement")
|
||||
}
|
||||
|
||||
ext.Waiting[unlock_signal.Id] = id
|
||||
ext.Requirements[id] = Unlocking
|
||||
if lockable.ID == requirement.ID {
|
||||
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s to itself", lockable.ID)
|
||||
}
|
||||
|
||||
messages = append(messages, Message{id, unlock_signal})
|
||||
}
|
||||
}
|
||||
_, exists := req_exts[requirement.ID]
|
||||
if exists == true {
|
||||
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s twice", requirement.ID)
|
||||
}
|
||||
default:
|
||||
messages = append(messages, Message{source, NewErrorSignal(signal.Id, "not_locked")})
|
||||
ext, err := GetExt[*LockableExt](requirement)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req_exts[requirement.ID] = ext
|
||||
}
|
||||
|
||||
return messages, changes
|
||||
}
|
||||
|
||||
// Handle a LockSignal by either transitioning to a locked state,
|
||||
// sending lock signals to requirements, or returning an error signal
|
||||
func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) ([]Message, Changes) {
|
||||
var messages []Message = nil
|
||||
var changes Changes = nil
|
||||
return UpdateStates(context, princ, NewACLMap(
|
||||
NewACLInfo(lockable, []string{"requirements"}),
|
||||
ACLList(requirements, []string{"dependencies"}),
|
||||
), func(context *StateContext) error {
|
||||
// Check that all the requirements can be added
|
||||
// If the lockable is already locked, need to lock this resource as well before we can add it
|
||||
for _, requirement := range(requirements) {
|
||||
requirement_ext := req_exts[requirement.ID]
|
||||
for _, req := range(requirements) {
|
||||
if req.ID == requirement.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
switch ext.State {
|
||||
case Unlocked:
|
||||
if len(ext.Requirements) == 0 {
|
||||
changes = append(changes, "state", "owner", "pending_owner")
|
||||
is_req, err := checkIfRequirement(context, req.ID, requirement_ext)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if is_req {
|
||||
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot add the same dependency", req.ID, requirement.ID)
|
||||
|
||||
ext.Owner = &source
|
||||
}
|
||||
}
|
||||
|
||||
ext.PendingOwner = &source
|
||||
is_req, err := checkIfRequirement(context, lockable.ID, requirement_ext)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if is_req {
|
||||
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID, lockable.ID)
|
||||
}
|
||||
|
||||
ext.State = Locked
|
||||
messages = append(messages, Message{source, NewSuccessSignal(signal.Id)})
|
||||
} else {
|
||||
changes = append(changes, "state", "requirements", "waiting", "pending_owner")
|
||||
is_req, err = checkIfRequirement(context, requirement.ID, lockable_ext)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if is_req {
|
||||
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID, requirement.ID)
|
||||
}
|
||||
|
||||
ext.PendingOwner = &source
|
||||
if lockable_ext.Owner == nil {
|
||||
// If the new owner isn't locked, we can add the requirement
|
||||
} else if requirement_ext.Owner == nil {
|
||||
// if the new requirement isn't already locked but the owner is, the requirement needs to be locked first
|
||||
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID, requirement.ID)
|
||||
} else {
|
||||
// If the new requirement is already locked and the owner is already locked, their owners need to match
|
||||
if requirement_ext.Owner.ID != lockable_ext.Owner.ID {
|
||||
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID, lockable.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Update the states of the requirements
|
||||
for _, requirement := range(requirements) {
|
||||
requirement_ext := req_exts[requirement.ID]
|
||||
requirement_ext.Dependencies[lockable.ID] = lockable
|
||||
lockable_ext.Requirements[lockable.ID] = requirement
|
||||
context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID, lockable.ID)
|
||||
}
|
||||
|
||||
ext.ReqID = &signal.Id
|
||||
// Return no error
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
ext.State = Locking
|
||||
for id := range(ext.Requirements) {
|
||||
lock_signal := NewLockSignal()
|
||||
func checkIfRequirement(context *StateContext, id NodeID, cur *LockableExt) (bool, error) {
|
||||
for _, req := range(cur.Requirements) {
|
||||
if req.ID == id {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
ext.Waiting[lock_signal.Id] = id
|
||||
ext.Requirements[id] = Locking
|
||||
req_ext, err := GetExt[*LockableExt](req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
messages = append(messages, Message{id, lock_signal})
|
||||
}
|
||||
var is_req bool
|
||||
err = UpdateStates(context, req, NewACLInfo(req, []string{"requirements"}), func(context *StateContext) error {
|
||||
is_req, err = checkIfRequirement(context, id, req_ext)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if is_req == true {
|
||||
return true, nil
|
||||
}
|
||||
default:
|
||||
messages = append(messages, Message{source, NewErrorSignal(signal.Id, "not_unlocked: %s", ext.State)})
|
||||
}
|
||||
|
||||
return messages, changes
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Handle an error signal by aborting the lock, or retrying the unlock
|
||||
func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) ([]Message, Changes) {
|
||||
var messages []Message = nil
|
||||
var changes Changes = nil
|
||||
// Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error
|
||||
// Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call
|
||||
func LockLockables(context *StateContext, to_lock NodeMap, new_owner *Node) error {
|
||||
if to_lock == nil {
|
||||
return fmt.Errorf("LOCKABLE_LOCK_ERR: no map provided")
|
||||
}
|
||||
|
||||
id, waiting := ext.Waiting[signal.ReqID]
|
||||
if waiting == true {
|
||||
delete(ext.Waiting, signal.ReqID)
|
||||
changes = append(changes, "waiting")
|
||||
req_exts := map[NodeID]*LockableExt{}
|
||||
for _, l := range(to_lock) {
|
||||
var err error
|
||||
if l == nil {
|
||||
return fmt.Errorf("LOCKABLE_LOCK_ERR: Can not lock nil")
|
||||
}
|
||||
|
||||
switch ext.State {
|
||||
case Locking:
|
||||
changes = append(changes, "state", "requirements")
|
||||
req_exts[l.ID], err = GetExt[*LockableExt](l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ext.Requirements[id] = Unlocked
|
||||
if new_owner == nil {
|
||||
return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks")
|
||||
}
|
||||
|
||||
unlocked := 0
|
||||
for req_id, req_state := range(ext.Requirements) {
|
||||
// Unlock locked requirements, and count unlocked requirements
|
||||
switch req_state {
|
||||
case Locked:
|
||||
unlock_signal := NewUnlockSignal()
|
||||
new_owner_ext, err := GetExt[*LockableExt](new_owner)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ext.Waiting[unlock_signal.Id] = req_id
|
||||
ext.Requirements[req_id] = Unlocking
|
||||
// Called with no requirements to lock, success
|
||||
if len(to_lock) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
messages = append(messages, Message{req_id, unlock_signal})
|
||||
case Unlocked:
|
||||
unlocked += 1
|
||||
return UpdateStates(context, new_owner, NewACLMap(
|
||||
ACLListM(to_lock, []string{"lock"}),
|
||||
NewACLInfo(new_owner, nil),
|
||||
), func(context *StateContext) error {
|
||||
// First loop is to check that the states can be locked, and locks all requirements
|
||||
for _, req := range(to_lock) {
|
||||
req_ext := req_exts[req.ID]
|
||||
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID, new_owner.ID)
|
||||
|
||||
// If req is alreay locked, check that we can pass the lock
|
||||
if req_ext.Owner != nil {
|
||||
owner := req_ext.Owner
|
||||
if owner.ID == new_owner.ID {
|
||||
continue
|
||||
} else {
|
||||
err := UpdateStates(context, new_owner, NewACLInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){
|
||||
return LockLockables(context, req_ext.Requirements, req)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
err := LockLockables(context, req_ext.Requirements, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if unlocked == len(ext.Requirements) {
|
||||
changes = append(changes, "owner", "state")
|
||||
ext.State = Unlocked
|
||||
ext.Owner = nil
|
||||
// At this point state modification will be started, so no errors can be returned
|
||||
for _, req := range(to_lock) {
|
||||
req_ext := req_exts[req.ID]
|
||||
old_owner := req_ext.Owner
|
||||
// If the lockable was previously unowned, update the state
|
||||
if old_owner == nil {
|
||||
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID, req.ID)
|
||||
req_ext.Owner = new_owner
|
||||
new_owner_ext.RecordLock(req, old_owner)
|
||||
// Otherwise if the new owner already owns it, no need to update state
|
||||
} else if old_owner.ID == new_owner.ID {
|
||||
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID, req.ID)
|
||||
// Otherwise update the state
|
||||
} else {
|
||||
changes = append(changes, "state")
|
||||
ext.State = AbortingLock
|
||||
req_ext.Owner = new_owner
|
||||
new_owner_ext.RecordLock(req, old_owner)
|
||||
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID, req.ID, old_owner.ID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
case Unlocking:
|
||||
unlock_signal := NewUnlockSignal()
|
||||
ext.Waiting[unlock_signal.Id] = id
|
||||
messages = append(messages, Message{id, unlock_signal})
|
||||
|
||||
case AbortingLock:
|
||||
req_state := ext.Requirements[id]
|
||||
// Mark failed lock as Unlocked, or retry unlock
|
||||
switch req_state {
|
||||
case Locking:
|
||||
ext.Requirements[id] = Unlocked
|
||||
|
||||
// Check if all requirements unlocked now
|
||||
unlocked := 0
|
||||
for _, req_state := range(ext.Requirements) {
|
||||
if req_state == Unlocked {
|
||||
unlocked += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if unlocked == len(ext.Requirements) {
|
||||
changes = append(changes, "owner", "state")
|
||||
ext.State = Unlocked
|
||||
ext.Owner = nil
|
||||
}
|
||||
case Unlocking:
|
||||
// Handle error for unlocking requirement while unlocking by retrying unlock
|
||||
unlock_signal := NewUnlockSignal()
|
||||
ext.Waiting[unlock_signal.Id] = id
|
||||
messages = append(messages, Message{id, unlock_signal})
|
||||
}
|
||||
}
|
||||
func UnlockLockables(context *StateContext, to_unlock NodeMap, old_owner *Node) error {
|
||||
if to_unlock == nil {
|
||||
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided")
|
||||
}
|
||||
|
||||
return messages, changes
|
||||
}
|
||||
req_exts := map[NodeID]*LockableExt{}
|
||||
for _, l := range(to_unlock) {
|
||||
if l == nil {
|
||||
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: Can not unlock nil")
|
||||
}
|
||||
|
||||
// Handle a success signal by checking if all requirements have been locked/unlocked
|
||||
func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source NodeID, signal *SuccessSignal) ([]Message, Changes) {
|
||||
var messages []Message = nil
|
||||
var changes Changes = nil
|
||||
var err error
|
||||
req_exts[l.ID], err = GetExt[*LockableExt](l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
id, waiting := ext.Waiting[signal.ReqID]
|
||||
if waiting == true {
|
||||
delete(ext.Waiting, signal.ReqID)
|
||||
changes = append(changes, "waiting")
|
||||
if old_owner == nil {
|
||||
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks")
|
||||
}
|
||||
|
||||
switch ext.State {
|
||||
case Locking:
|
||||
ext.Requirements[id] = Locked
|
||||
ext.Locked[id] = nil
|
||||
delete(ext.Unlocked, id)
|
||||
old_owner_ext, err := GetExt[*LockableExt](old_owner)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(ext.Locked) == len(ext.Requirements) {
|
||||
ctx.Log.Logf("lockable", "%s FULL_LOCK: %d", node.ID, len(ext.Locked))
|
||||
changes = append(changes, "state", "owner", "req_id")
|
||||
ext.State = Locked
|
||||
|
||||
ext.Owner = ext.PendingOwner
|
||||
// Called with no requirements to unlock, success
|
||||
if len(to_unlock) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
messages = append(messages, Message{*ext.Owner, NewSuccessSignal(*ext.ReqID)})
|
||||
ext.ReqID = nil
|
||||
return UpdateStates(context, old_owner, NewACLMap(
|
||||
ACLListM(to_unlock, []string{"lock"}),
|
||||
NewACLInfo(old_owner, nil),
|
||||
), func(context *StateContext) error {
|
||||
// First loop is to check that the states can be locked, and locks all requirements
|
||||
for _, req := range(to_unlock) {
|
||||
req_ext := req_exts[req.ID]
|
||||
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID, old_owner.ID)
|
||||
|
||||
// Check if the owner is correct
|
||||
if req_ext.Owner != nil {
|
||||
if req_ext.Owner.ID != old_owner.ID {
|
||||
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID, old_owner.ID)
|
||||
}
|
||||
} else {
|
||||
ctx.Log.Logf("lockable", "%s PARTIAL_LOCK: %d/%d", node.ID, len(ext.Locked), len(ext.Requirements))
|
||||
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID)
|
||||
}
|
||||
case AbortingLock:
|
||||
req_state := ext.Requirements[id]
|
||||
switch req_state {
|
||||
case Locking:
|
||||
ext.Requirements[id] = Unlocking
|
||||
unlock_signal := NewUnlockSignal()
|
||||
ext.Waiting[unlock_signal.Id] = id
|
||||
messages = append(messages, Message{id, unlock_signal})
|
||||
case Unlocking:
|
||||
ext.Requirements[id] = Unlocked
|
||||
ext.Unlocked[id] = nil
|
||||
delete(ext.Locked, id)
|
||||
|
||||
unlocked := 0
|
||||
for _, req_state := range(ext.Requirements) {
|
||||
switch req_state {
|
||||
case Unlocked:
|
||||
unlocked += 1
|
||||
}
|
||||
}
|
||||
|
||||
if unlocked == len(ext.Requirements) {
|
||||
changes = append(changes, "state", "pending_owner", "req_id")
|
||||
err := UnlockLockables(context, req_ext.Requirements, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, Message{*ext.PendingOwner, NewErrorSignal(*ext.ReqID, "not_unlocked: %s", ext.State)})
|
||||
ext.State = Unlocked
|
||||
ext.ReqID = nil
|
||||
ext.PendingOwner = nil
|
||||
}
|
||||
// At this point state modification will be started, so no errors can be returned
|
||||
for _, req := range(to_unlock) {
|
||||
req_ext := req_exts[req.ID]
|
||||
new_owner := old_owner_ext.RecordUnlock(req)
|
||||
req_ext.Owner = new_owner
|
||||
if new_owner == nil {
|
||||
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID, req.ID)
|
||||
} else {
|
||||
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID, req.ID, new_owner.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func SaveNode(node *Node) string {
|
||||
str := ""
|
||||
if node != nil {
|
||||
str = node.ID.String()
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
func RestoreNode(ctx *Context, id_str string) (*Node, error) {
|
||||
if id_str == "" {
|
||||
return nil, nil
|
||||
}
|
||||
id, err := ParseID(id_str)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return LoadNode(ctx, id)
|
||||
}
|
||||
|
||||
case Unlocking:
|
||||
ext.Requirements[id] = Unlocked
|
||||
ext.Unlocked[id] = Unlocked
|
||||
delete(ext.Locked, id)
|
||||
func SaveNodeMap(nodes NodeMap) map[string]string {
|
||||
m := map[string]string{}
|
||||
for id, node := range(nodes) {
|
||||
m[id.String()] = SaveNode(node)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
if len(ext.Unlocked) == len(ext.Requirements) {
|
||||
changes = append(changes, "state", "owner", "req_id")
|
||||
func RestoreNodeMap(ctx *Context, ids map[string]string) (NodeMap, error) {
|
||||
nodes := NodeMap{}
|
||||
for id_str_1, id_str_2 := range(ids) {
|
||||
id_1, err := ParseID(id_str_1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
messages = append(messages, Message{*ext.Owner, NewSuccessSignal(*ext.ReqID)})
|
||||
ext.State = Unlocked
|
||||
ext.ReqID = nil
|
||||
ext.Owner = nil
|
||||
node_1, err := LoadNode(ctx, id_1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
var node_2 *Node = nil
|
||||
if id_str_2 != "" {
|
||||
id_2, err := ParseID(id_str_2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
node_2, err = LoadNode(ctx, id_2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
nodes[node_1.ID] = node_2
|
||||
}
|
||||
|
||||
return messages, changes
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]Message, Changes) {
|
||||
var messages []Message = nil
|
||||
var changes Changes = nil
|
||||
func SaveNodeList(nodes NodeMap) []string {
|
||||
ids := make([]string, len(nodes))
|
||||
i := 0
|
||||
for id, _ := range(nodes) {
|
||||
ids[i] = id.String()
|
||||
i += 1
|
||||
}
|
||||
|
||||
switch sig := signal.(type) {
|
||||
case *StatusSignal:
|
||||
// Forward StatusSignals up to the owner(unless that would be a cycle)
|
||||
if ext.Owner != nil {
|
||||
if *ext.Owner != node.ID {
|
||||
messages = append(messages, Message{*ext.Owner, signal})
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func RestoreNodeList(ctx *Context, ids []string) (NodeMap, error) {
|
||||
nodes := NodeMap{}
|
||||
|
||||
for _, id_str := range(ids) {
|
||||
node, err := RestoreNode(ctx, id_str)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case *LinkSignal:
|
||||
messages, changes = ext.HandleLinkSignal(ctx, node, source, sig)
|
||||
case *LockSignal:
|
||||
messages, changes = ext.HandleLockSignal(ctx, node, source, sig)
|
||||
case *UnlockSignal:
|
||||
messages, changes = ext.HandleUnlockSignal(ctx, node, source, sig)
|
||||
case *ErrorSignal:
|
||||
messages, changes = ext.HandleErrorSignal(ctx, node, source, sig)
|
||||
case *SuccessSignal:
|
||||
messages, changes = ext.HandleSuccessSignal(ctx, node, source, sig)
|
||||
}
|
||||
|
||||
return messages, changes
|
||||
nodes[node.ID] = node
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
|
@ -1,148 +0,0 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLink(t *testing.T) {
|
||||
ctx := logTestContext(t, []string{"lockable", "listener"})
|
||||
|
||||
|
||||
l2_listener := NewListenerExt(10)
|
||||
l2, err := ctx.NewNode(nil, "LockableNode", l2_listener, NewLockableExt(nil))
|
||||
fatalErr(t, err)
|
||||
|
||||
l1_lockable := NewLockableExt(nil)
|
||||
l1_listener := NewListenerExt(10)
|
||||
l1, err := ctx.NewNode(nil, "LockableNode", l1_listener, l1_lockable)
|
||||
fatalErr(t, err)
|
||||
|
||||
link_signal := NewLinkSignal("add", l2.ID)
|
||||
msgs := []Message{{l1.ID, link_signal}}
|
||||
err = ctx.Send(l1, msgs)
|
||||
fatalErr(t, err)
|
||||
|
||||
_, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, link_signal.ID())
|
||||
fatalErr(t, err)
|
||||
|
||||
state, exists := l1_lockable.Requirements[l2.ID]
|
||||
if exists == false {
|
||||
t.Fatal("l2 not in l1 requirements")
|
||||
} else if state != Unlocked {
|
||||
t.Fatalf("l2 in bad requirement state in l1: %+v", state)
|
||||
}
|
||||
|
||||
unlink_signal := NewLinkSignal("remove", l2.ID)
|
||||
msgs = []Message{{l1.ID, unlink_signal}}
|
||||
err = ctx.Send(l1, msgs)
|
||||
fatalErr(t, err)
|
||||
|
||||
_, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, unlink_signal.ID())
|
||||
fatalErr(t, err)
|
||||
}
|
||||
|
||||
func Test10Lock(t *testing.T) {
|
||||
testLockN(t, 10)
|
||||
}
|
||||
|
||||
func Test100Lock(t *testing.T) {
|
||||
testLockN(t, 100)
|
||||
}
|
||||
|
||||
func Test1000Lock(t *testing.T) {
|
||||
testLockN(t, 1000)
|
||||
}
|
||||
|
||||
func Test10000Lock(t *testing.T) {
|
||||
testLockN(t, 10000)
|
||||
}
|
||||
|
||||
func testLockN(t *testing.T, n int) {
|
||||
ctx := logTestContext(t, []string{"test"})
|
||||
|
||||
NewLockable := func()(*Node) {
|
||||
l, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil))
|
||||
fatalErr(t, err)
|
||||
return l
|
||||
}
|
||||
|
||||
reqs := make([]NodeID, n)
|
||||
for i := range(reqs) {
|
||||
new_lockable := NewLockable()
|
||||
reqs[i] = new_lockable.ID
|
||||
}
|
||||
ctx.Log.Logf("test", "CREATED_%d", n)
|
||||
|
||||
listener := NewListenerExt(50000)
|
||||
node, err := ctx.NewNode(nil, "LockableNode", listener, NewLockableExt(reqs))
|
||||
fatalErr(t, err)
|
||||
ctx.Log.Logf("test", "CREATED_LISTENER")
|
||||
|
||||
lock_id, err := LockLockable(ctx, node)
|
||||
fatalErr(t, err)
|
||||
|
||||
response, _, err := WaitForResponse(listener.Chan, time.Second*60, lock_id)
|
||||
fatalErr(t, err)
|
||||
|
||||
switch resp := response.(type) {
|
||||
case *SuccessSignal:
|
||||
default:
|
||||
t.Fatalf("Unexpected response to lock - %s", resp)
|
||||
}
|
||||
|
||||
ctx.Log.Logf("test", "LOCKED_%d", n)
|
||||
}
|
||||
|
||||
func TestLock(t *testing.T) {
|
||||
ctx := logTestContext(t, []string{"test", "lockable"})
|
||||
|
||||
NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) {
|
||||
listener := NewListenerExt(10000)
|
||||
l, err := ctx.NewNode(nil, "LockableNode", listener, NewLockableExt(reqs))
|
||||
fatalErr(t, err)
|
||||
return l, listener
|
||||
}
|
||||
|
||||
l2, _ := NewLockable(nil)
|
||||
l3, _ := NewLockable(nil)
|
||||
l4, _ := NewLockable(nil)
|
||||
l5, _ := NewLockable(nil)
|
||||
l0, l0_listener := NewLockable([]NodeID{l5.ID})
|
||||
l1, l1_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID})
|
||||
|
||||
ctx.Log.Logf("test", "l0: %s", l0.ID)
|
||||
ctx.Log.Logf("test", "l1: %s", l1.ID)
|
||||
ctx.Log.Logf("test", "l2: %s", l2.ID)
|
||||
ctx.Log.Logf("test", "l3: %s", l3.ID)
|
||||
ctx.Log.Logf("test", "l4: %s", l4.ID)
|
||||
ctx.Log.Logf("test", "l5: %s", l5.ID)
|
||||
|
||||
ctx.Log.Logf("test", "locking l0")
|
||||
id_1, err := LockLockable(ctx, l0)
|
||||
fatalErr(t, err)
|
||||
response, _, err := WaitForResponse(l0_listener.Chan, time.Millisecond*10, id_1)
|
||||
fatalErr(t, err)
|
||||
ctx.Log.Logf("test", "l0 lock: %+v", response)
|
||||
|
||||
ctx.Log.Logf("test", "locking l1")
|
||||
id_2, err := LockLockable(ctx, l1)
|
||||
fatalErr(t, err)
|
||||
response, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10000, id_2)
|
||||
fatalErr(t, err)
|
||||
ctx.Log.Logf("test", "l1 lock: %+v", response)
|
||||
|
||||
ctx.Log.Logf("test", "unlocking l0")
|
||||
id_3, err := UnlockLockable(ctx, l0)
|
||||
fatalErr(t, err)
|
||||
response, _, err = WaitForResponse(l0_listener.Chan, time.Millisecond*10, id_3)
|
||||
fatalErr(t, err)
|
||||
ctx.Log.Logf("test", "l0 unlock: %+v", response)
|
||||
|
||||
ctx.Log.Logf("test", "locking l1")
|
||||
id_4, err := LockLockable(ctx, l1)
|
||||
fatalErr(t, err)
|
||||
response, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, id_4)
|
||||
fatalErr(t, err)
|
||||
ctx.Log.Logf("test", "l1 lock: %+v", response)
|
||||
}
|
@ -1,68 +0,0 @@
|
||||
package graphvent
|
||||
|
||||
type Message struct {
|
||||
Node NodeID
|
||||
Signal Signal
|
||||
}
|
||||
|
||||
type MessageQueue struct {
|
||||
out chan<- Message
|
||||
in <-chan Message
|
||||
buffer []Message
|
||||
write_cursor int
|
||||
read_cursor int
|
||||
}
|
||||
|
||||
func (queue *MessageQueue) ProcessIncoming(message Message) {
|
||||
if (queue.write_cursor + 1) == queue.read_cursor || ((queue.write_cursor + 1) == len(queue.buffer) && queue.read_cursor == 0) {
|
||||
new_buffer := make([]Message, len(queue.buffer) * 2)
|
||||
|
||||
copy(new_buffer, queue.buffer[queue.read_cursor:])
|
||||
first_chunk := len(queue.buffer) - queue.read_cursor
|
||||
copy(new_buffer[first_chunk:], queue.buffer[0:queue.write_cursor])
|
||||
|
||||
queue.write_cursor = len(queue.buffer) - 1
|
||||
queue.read_cursor = 0
|
||||
queue.buffer = new_buffer
|
||||
}
|
||||
|
||||
queue.buffer[queue.write_cursor] = message
|
||||
queue.write_cursor += 1
|
||||
if queue.write_cursor >= len(queue.buffer) {
|
||||
queue.write_cursor = 0
|
||||
}
|
||||
}
|
||||
|
||||
func NewMessageQueue(initial int) (chan<- Message, <-chan Message) {
|
||||
in := make(chan Message, 0)
|
||||
out := make(chan Message, 0)
|
||||
|
||||
queue := MessageQueue{
|
||||
out: out,
|
||||
in: in,
|
||||
buffer: make([]Message, initial),
|
||||
write_cursor: 0,
|
||||
read_cursor: 0,
|
||||
}
|
||||
|
||||
go func(queue *MessageQueue) {
|
||||
for true {
|
||||
if queue.write_cursor != queue.read_cursor {
|
||||
select {
|
||||
case incoming := <-queue.in:
|
||||
queue.ProcessIncoming(incoming)
|
||||
case queue.out <- queue.buffer[queue.read_cursor]:
|
||||
queue.read_cursor += 1
|
||||
if queue.read_cursor >= len(queue.buffer) {
|
||||
queue.read_cursor = 0
|
||||
}
|
||||
}
|
||||
} else {
|
||||
message := <-queue.in
|
||||
queue.ProcessIncoming(message)
|
||||
}
|
||||
}
|
||||
}(&queue)
|
||||
|
||||
return in, out
|
||||
}
|
@ -1,35 +0,0 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func sendBatch(start, end uint64, in chan<- Message) {
|
||||
for i := start; i <= end; i++ {
|
||||
var id NodeID
|
||||
binary.BigEndian.PutUint64(id[:], i)
|
||||
in <- Message{id, nil}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageQueue(t *testing.T) {
|
||||
in, out := NewMessageQueue(10)
|
||||
|
||||
for i := uint64(0); i < 1000; i++ {
|
||||
go sendBatch(1000*i, (1000*(i+1))-1, in)
|
||||
}
|
||||
|
||||
seen := map[NodeID]any{}
|
||||
for i := uint64(0); i < 1000*1000; i++ {
|
||||
read := <-out
|
||||
_, already_seen := seen[read.Node]
|
||||
if already_seen {
|
||||
t.Fatalf("Signal %d had duplicate NodeID %s", i, read.Node)
|
||||
} else {
|
||||
seen[read.Node] = nil
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Processed 1M signals through queue")
|
||||
}
|
@ -0,0 +1,396 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Policy interface {
|
||||
Serializable[PolicyType]
|
||||
Allows(context *StateContext, principal *Node, action string, node *Node) bool
|
||||
}
|
||||
|
||||
const RequirementOfPolicyType = PolicyType("REQUIREMENT_OF")
|
||||
type RequirementOfPolicy struct {
|
||||
PerNodePolicy
|
||||
}
|
||||
func (policy *RequirementOfPolicy) Type() PolicyType {
|
||||
return RequirementOfPolicyType
|
||||
}
|
||||
|
||||
func NewRequirementOfPolicy(nodes NodeActions) RequirementOfPolicy {
|
||||
return RequirementOfPolicy{
|
||||
PerNodePolicy: NewPerNodePolicy(nodes),
|
||||
}
|
||||
}
|
||||
|
||||
// Check if any of principals dependencies are in the policy
|
||||
func (policy *RequirementOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
|
||||
lockable_ext, err := GetExt[*LockableExt](principal)
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for dep_id, _ := range(lockable_ext.Dependencies) {
|
||||
for node_id, actions := range(policy.NodeActions) {
|
||||
if node_id == dep_id {
|
||||
if actions.Allows(action) == true {
|
||||
return true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const ChildOfPolicyType = PolicyType("CHILD_OF")
|
||||
type ChildOfPolicy struct {
|
||||
PerNodePolicy
|
||||
}
|
||||
func (policy *ChildOfPolicy) Type() PolicyType {
|
||||
return ChildOfPolicyType
|
||||
}
|
||||
|
||||
func (policy *ChildOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
|
||||
context.Graph.Log.Logf("policy", "CHILD_OF_POLICY: %+v", policy)
|
||||
thread_ext, err := GetExt[*ThreadExt](principal)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
parent := thread_ext.Parent
|
||||
if parent != nil {
|
||||
actions, exists := policy.NodeActions[parent.ID]
|
||||
if exists == false {
|
||||
return false
|
||||
}
|
||||
for _, a := range(actions) {
|
||||
if a == action {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
type Actions []string
|
||||
|
||||
func (actions Actions) Allows(action string) bool {
|
||||
for _, a := range(actions) {
|
||||
if a == action {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type NodeActions map[NodeID]Actions
|
||||
|
||||
func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context, []byte)(Policy, error) {
|
||||
return func(ctx *Context, data []byte)(Policy, error){
|
||||
var j PerNodePolicyJSON
|
||||
err := json.Unmarshal(data, &j)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
node_actions := NodeActions{}
|
||||
for id_str, actions := range(j.NodeActions) {
|
||||
id, err := ParseID(id_str)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = LoadNode(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
node_actions[id] = actions
|
||||
}
|
||||
|
||||
return init_fn(node_actions)
|
||||
}
|
||||
}
|
||||
|
||||
func NewChildOfPolicy(node_actions NodeActions) ChildOfPolicy {
|
||||
return ChildOfPolicy{
|
||||
PerNodePolicy: NewPerNodePolicy(node_actions),
|
||||
}
|
||||
}
|
||||
|
||||
const ParentOfPolicyType = PolicyType("PARENT_OF")
|
||||
type ParentOfPolicy struct {
|
||||
PerNodePolicy
|
||||
}
|
||||
func (policy *ParentOfPolicy) Type() PolicyType {
|
||||
return ParentOfPolicyType
|
||||
}
|
||||
|
||||
func (policy *ParentOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
|
||||
context.Graph.Log.Logf("policy", "PARENT_OF_POLICY: %+v", policy)
|
||||
for id, actions := range(policy.NodeActions) {
|
||||
thread_ext, err := GetExt[*ThreadExt](context.Graph.Nodes[id])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
context.Graph.Log.Logf("policy", "PARENT_OF_PARENT: %s %+v", id, thread_ext.Parent)
|
||||
if thread_ext.Parent != nil {
|
||||
if thread_ext.Parent.ID == principal.ID {
|
||||
for _, a := range(actions) {
|
||||
if a == action {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func NewParentOfPolicy(node_actions NodeActions) ParentOfPolicy {
|
||||
return ParentOfPolicy{
|
||||
PerNodePolicy: NewPerNodePolicy(node_actions),
|
||||
}
|
||||
}
|
||||
|
||||
func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy {
|
||||
if node_actions == nil {
|
||||
node_actions = NodeActions{}
|
||||
}
|
||||
|
||||
return PerNodePolicy{
|
||||
NodeActions: node_actions,
|
||||
}
|
||||
}
|
||||
|
||||
type PerNodePolicy struct {
|
||||
NodeActions NodeActions
|
||||
}
|
||||
|
||||
type PerNodePolicyJSON struct {
|
||||
NodeActions map[string][]string `json:"node_actions"`
|
||||
}
|
||||
|
||||
const PerNodePolicyType = PolicyType("PER_NODE")
|
||||
func (policy *PerNodePolicy) Type() PolicyType {
|
||||
return PerNodePolicyType
|
||||
}
|
||||
|
||||
func (policy *PerNodePolicy) Serialize() ([]byte, error) {
|
||||
node_actions := map[string][]string{}
|
||||
for id, actions := range(policy.NodeActions) {
|
||||
node_actions[id.String()] = actions
|
||||
}
|
||||
|
||||
return json.MarshalIndent(&PerNodePolicyJSON{
|
||||
NodeActions: node_actions,
|
||||
}, "", " ")
|
||||
}
|
||||
|
||||
func (policy *PerNodePolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
|
||||
for id, actions := range(policy.NodeActions) {
|
||||
if id != principal.ID {
|
||||
continue
|
||||
}
|
||||
for _, a := range(actions) {
|
||||
if a == action {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
// Extension to allow a node to hold ACL policies
|
||||
type ACLPolicyExt struct {
|
||||
Policies map[PolicyType]Policy
|
||||
}
|
||||
|
||||
// The ACL extension stores a map of nodes to delegate ACL to, and a list of policies
|
||||
type ACLExt struct {
|
||||
Delegations NodeMap
|
||||
}
|
||||
|
||||
func (ext *ACLExt) Process(context *StateContext, node *Node, signal Signal) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func LoadACLExt(ctx *Context, data []byte) (Extension, error) {
|
||||
var j struct {
|
||||
Delegations []string `json:"delegation"`
|
||||
}
|
||||
|
||||
err := json.Unmarshal(data, &j)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
delegations, err := RestoreNodeList(ctx, j.Delegations)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ACLExt{
|
||||
Delegations: delegations,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NodeList(nodes ...*Node) NodeMap {
|
||||
m := NodeMap{}
|
||||
for _, node := range(nodes) {
|
||||
m[node.ID] = node
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func NewACLExt(delegations NodeMap) *ACLExt {
|
||||
if delegations == nil {
|
||||
delegations = NodeMap{}
|
||||
}
|
||||
|
||||
return &ACLExt{
|
||||
Delegations: delegations,
|
||||
}
|
||||
}
|
||||
|
||||
func (ext *ACLExt) Serialize() ([]byte, error) {
|
||||
delegations := make([]string, len(ext.Delegations))
|
||||
i := 0
|
||||
for id, _ := range(ext.Delegations) {
|
||||
delegations[i] = id.String()
|
||||
i += 1
|
||||
}
|
||||
|
||||
return json.MarshalIndent(&struct{
|
||||
Delegations []string `json:"delegations"`
|
||||
}{
|
||||
Delegations: delegations,
|
||||
}, "", " ")
|
||||
}
|
||||
|
||||
const ACLExtType = ExtType("ACL")
|
||||
func (ext *ACLExt) Type() ExtType {
|
||||
return ACLExtType
|
||||
}
|
||||
|
||||
type PolicyLoadFunc func(*Context, []byte) (Policy, error)
|
||||
type PolicyInfo struct {
|
||||
Load PolicyLoadFunc
|
||||
}
|
||||
|
||||
type ACLPolicyExtContext struct {
|
||||
Types map[PolicyType]PolicyInfo
|
||||
}
|
||||
|
||||
func NewACLPolicyExtContext() *ACLPolicyExtContext {
|
||||
return &ACLPolicyExtContext{
|
||||
Types: map[PolicyType]PolicyInfo{
|
||||
PerNodePolicyType: PolicyInfo{
|
||||
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
|
||||
policy := NewPerNodePolicy(nodes)
|
||||
return &policy, nil
|
||||
}),
|
||||
},
|
||||
ParentOfPolicyType: PolicyInfo{
|
||||
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
|
||||
policy := NewParentOfPolicy(nodes)
|
||||
return &policy, nil
|
||||
}),
|
||||
},
|
||||
ChildOfPolicyType: PolicyInfo{
|
||||
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
|
||||
policy := NewChildOfPolicy(nodes)
|
||||
return &policy, nil
|
||||
}),
|
||||
},
|
||||
RequirementOfPolicyType: PolicyInfo{
|
||||
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
|
||||
policy := NewRequirementOfPolicy(nodes)
|
||||
return &policy, nil
|
||||
}),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (ext *ACLPolicyExt) Serialize() ([]byte, error) {
|
||||
policies := map[string][]byte{}
|
||||
for name, policy := range(ext.Policies) {
|
||||
ser, err := policy.Serialize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
policies[string(name)] = ser
|
||||
}
|
||||
|
||||
return json.MarshalIndent(&struct{
|
||||
Policies map[string][]byte `json:"policies"`
|
||||
}{
|
||||
Policies: policies,
|
||||
}, "", " ")
|
||||
}
|
||||
|
||||
func (ext *ACLPolicyExt) Process(context *StateContext, node *Node, signal Signal) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewACLPolicyExt(policies map[PolicyType]Policy) *ACLPolicyExt {
|
||||
if policies == nil {
|
||||
policies = map[PolicyType]Policy{}
|
||||
}
|
||||
|
||||
return &ACLPolicyExt{
|
||||
Policies: policies,
|
||||
}
|
||||
}
|
||||
|
||||
func LoadACLPolicyExt(ctx *Context, data []byte) (Extension, error) {
|
||||
var j struct {
|
||||
Policies map[string][]byte `json:"policies"`
|
||||
}
|
||||
err := json.Unmarshal(data, &j)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
policies := map[PolicyType]Policy{}
|
||||
acl_ctx := ctx.ExtByType(ACLPolicyExtType).Data.(*ACLPolicyExtContext)
|
||||
for name, ser := range(j.Policies) {
|
||||
policy_def, exists := acl_ctx.Types[PolicyType(name)]
|
||||
if exists == false {
|
||||
return nil, fmt.Errorf("%s is not a known policy type", name)
|
||||
}
|
||||
policy, err := policy_def.Load(ctx, ser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
policies[PolicyType(name)] = policy
|
||||
}
|
||||
|
||||
return NewACLPolicyExt(policies), nil
|
||||
}
|
||||
|
||||
const ACLPolicyExtType = ExtType("ACL_POLICIES")
|
||||
func (ext *ACLPolicyExt) Type() ExtType {
|
||||
return ACLPolicyExtType
|
||||
}
|
||||
|
||||
// Check if the extension allows the principal to perform action on node
|
||||
func (ext *ACLPolicyExt) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
|
||||
context.Graph.Log.Logf("policy", "POLICY_EXT_ALLOWED: %+v", ext)
|
||||
for _, policy := range(ext.Policies) {
|
||||
context.Graph.Log.Logf("policy", "POLICY_CHECK_POLICY: %+v", policy)
|
||||
if policy.Allows(context, principal, action, node) == true {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
@ -1,744 +0,0 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"crypto/sha512"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"math"
|
||||
)
|
||||
|
||||
type SerializedType uint64
|
||||
|
||||
func (t SerializedType) String() string {
|
||||
return fmt.Sprintf("0x%x", uint64(t))
|
||||
}
|
||||
|
||||
type ExtType SerializedType
|
||||
|
||||
func (t ExtType) String() string {
|
||||
return fmt.Sprintf("0x%x", uint64(t))
|
||||
}
|
||||
|
||||
type NodeType SerializedType
|
||||
|
||||
func (t NodeType) String() string {
|
||||
return fmt.Sprintf("0x%x", uint64(t))
|
||||
}
|
||||
|
||||
type SignalType SerializedType
|
||||
|
||||
func (t SignalType) String() string {
|
||||
return fmt.Sprintf("0x%x", uint64(t))
|
||||
}
|
||||
|
||||
type FieldTag SerializedType
|
||||
|
||||
func (t FieldTag) String() string {
|
||||
return fmt.Sprintf("0x%x", uint64(t))
|
||||
}
|
||||
|
||||
func NodeTypeFor(name string) NodeType {
|
||||
digest := []byte("GRAPHVENT_NODE - " + name)
|
||||
|
||||
hash := sha512.Sum512(digest)
|
||||
return NodeType(binary.BigEndian.Uint64(hash[0:8]))
|
||||
}
|
||||
|
||||
func SerializeType(t fmt.Stringer) SerializedType {
|
||||
digest := []byte(t.String())
|
||||
hash := sha512.Sum512(digest)
|
||||
return SerializedType(binary.BigEndian.Uint64(hash[0:8]))
|
||||
}
|
||||
|
||||
func SerializedTypeFor[T any]() SerializedType {
|
||||
return SerializeType(reflect.TypeFor[T]())
|
||||
}
|
||||
|
||||
func ExtTypeFor[E any, T interface { *E; Extension}]() ExtType {
|
||||
return ExtType(SerializedTypeFor[E]())
|
||||
}
|
||||
|
||||
func ExtTypeOf(t reflect.Type) ExtType {
|
||||
return ExtType(SerializeType(t.Elem()))
|
||||
}
|
||||
|
||||
func SignalTypeFor[S Signal]() SignalType {
|
||||
return SignalType(SerializedTypeFor[S]())
|
||||
}
|
||||
|
||||
func Hash(base, data string) SerializedType {
|
||||
digest := []byte(base + ":" + data)
|
||||
hash := sha512.Sum512(digest)
|
||||
return SerializedType(binary.BigEndian.Uint64(hash[0:8]))
|
||||
}
|
||||
|
||||
func GetFieldTag(tag string) FieldTag {
|
||||
return FieldTag(Hash("GRAPHVENT_FIELD_TAG", tag))
|
||||
}
|
||||
|
||||
func TypeStack(ctx *Context, t reflect.Type, data []byte) (int, error) {
|
||||
info, registered := ctx.Types[t]
|
||||
if registered {
|
||||
binary.BigEndian.PutUint64(data, uint64(info.Serialized))
|
||||
return 8, nil
|
||||
} else {
|
||||
switch t.Kind() {
|
||||
case reflect.Map:
|
||||
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Map)))
|
||||
|
||||
key_written, err := TypeStack(ctx, t.Key(), data[8:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
elem_written, err := TypeStack(ctx, t.Elem(), data[8 + key_written:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return 8 + key_written + elem_written, nil
|
||||
|
||||
case reflect.Pointer:
|
||||
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Pointer)))
|
||||
|
||||
elem_written, err := TypeStack(ctx, t.Elem(), data[8:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return 8 + elem_written, nil
|
||||
case reflect.Slice:
|
||||
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Slice)))
|
||||
|
||||
elem_written, err := TypeStack(ctx, t.Elem(), data[8:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return 8 + elem_written, nil
|
||||
case reflect.Array:
|
||||
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Array)))
|
||||
binary.BigEndian.PutUint64(data[8:], uint64(t.Len()))
|
||||
|
||||
elem_written, err := TypeStack(ctx, t.Elem(), data[16:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return 16 + elem_written, nil
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("Hit %s, which is not a registered type", t.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func UnwrapStack(ctx *Context, stack []byte) (reflect.Type, []byte, error) {
|
||||
first_bytes, left := split(stack, 8)
|
||||
first := SerializedType(binary.BigEndian.Uint64(first_bytes))
|
||||
|
||||
info, registered := ctx.TypesReverse[first]
|
||||
if registered {
|
||||
return info.Reflect, left, nil
|
||||
} else {
|
||||
switch first {
|
||||
case SerializeType(reflect.Map):
|
||||
key_type, after_key, err := UnwrapStack(ctx, left)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
elem_type, after_elem, err := UnwrapStack(ctx, after_key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return reflect.MapOf(key_type, elem_type), after_elem, nil
|
||||
case SerializeType(reflect.Pointer):
|
||||
elem_type, rest, err := UnwrapStack(ctx, left)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return reflect.PointerTo(elem_type), rest, nil
|
||||
case SerializeType(reflect.Slice):
|
||||
elem_type, rest, err := UnwrapStack(ctx, left)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return reflect.SliceOf(elem_type), rest, nil
|
||||
case SerializeType(reflect.Array):
|
||||
length_bytes, left := split(left, 8)
|
||||
length := int(binary.BigEndian.Uint64(length_bytes))
|
||||
|
||||
elem_type, rest, err := UnwrapStack(ctx, left)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return reflect.ArrayOf(length, elem_type), rest, nil
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("Type stack %+v not recognized", stack)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Serialize[T any](ctx *Context, value T, data []byte) (int, error) {
|
||||
return SerializeValue(ctx, reflect.ValueOf(&value).Elem(), data)
|
||||
}
|
||||
|
||||
func Deserialize[T any](ctx *Context, data []byte) (T, error) {
|
||||
reflect_type := reflect.TypeFor[T]()
|
||||
var zero T
|
||||
value, left, err := DeserializeValue(ctx, data, reflect_type)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
} else if len(left) != 0 {
|
||||
return zero, fmt.Errorf("%d/%d bytes left after deserializing %+v", len(left), len(data), value)
|
||||
} else if value.Type() != reflect_type {
|
||||
return zero, fmt.Errorf("Deserialized type %s does not match %s", value.Type(), reflect_type)
|
||||
}
|
||||
|
||||
return value.Interface().(T), nil
|
||||
}
|
||||
|
||||
func SerializedSize(ctx *Context, value reflect.Value) (int, error) {
|
||||
var sizefn SerializedSizeFn = nil
|
||||
|
||||
info, registered := ctx.Types[value.Type()]
|
||||
if registered {
|
||||
sizefn = info.SerializedSize
|
||||
}
|
||||
|
||||
if sizefn == nil {
|
||||
switch value.Type().Kind() {
|
||||
case reflect.Bool:
|
||||
return 1, nil
|
||||
|
||||
case reflect.Int8:
|
||||
return 1, nil
|
||||
case reflect.Int16:
|
||||
return 2, nil
|
||||
case reflect.Int32:
|
||||
return 4, nil
|
||||
case reflect.Int64:
|
||||
fallthrough
|
||||
case reflect.Int:
|
||||
return 8, nil
|
||||
|
||||
case reflect.Uint8:
|
||||
return 1, nil
|
||||
case reflect.Uint16:
|
||||
return 2, nil
|
||||
case reflect.Uint32:
|
||||
return 4, nil
|
||||
case reflect.Uint64:
|
||||
fallthrough
|
||||
case reflect.Uint:
|
||||
return 8, nil
|
||||
|
||||
case reflect.Float32:
|
||||
return 4, nil
|
||||
case reflect.Float64:
|
||||
return 8, nil
|
||||
|
||||
case reflect.String:
|
||||
return 8 + value.Len(), nil
|
||||
|
||||
case reflect.Pointer:
|
||||
if value.IsNil() {
|
||||
return 1, nil
|
||||
} else {
|
||||
elem_len, err := SerializedSize(ctx, value.Elem())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
return 1 + elem_len, nil
|
||||
}
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
if value.IsNil() {
|
||||
return 1, nil
|
||||
} else {
|
||||
elem_total := 0
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
elem_len, err := SerializedSize(ctx, value.Index(i))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
elem_total += elem_len
|
||||
}
|
||||
return 9 + elem_total, nil
|
||||
}
|
||||
|
||||
case reflect.Array:
|
||||
total := 0
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
elem_len, err := SerializedSize(ctx, value.Index(i))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
total += elem_len
|
||||
}
|
||||
return total, nil
|
||||
|
||||
case reflect.Map:
|
||||
if value.IsNil() {
|
||||
return 1, nil
|
||||
} else {
|
||||
key := reflect.New(value.Type().Key()).Elem()
|
||||
val := reflect.New(value.Type().Elem()).Elem()
|
||||
iter := value.MapRange()
|
||||
|
||||
total := 0
|
||||
for iter.Next() {
|
||||
key.SetIterKey(iter)
|
||||
k, err := SerializedSize(ctx, key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
total += k
|
||||
|
||||
val.SetIterValue(iter)
|
||||
v, err := SerializedSize(ctx, val)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
total += v
|
||||
}
|
||||
|
||||
return 9 + total, nil
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
if registered == false {
|
||||
return 0, fmt.Errorf("Can't serialize unregistered struct %s", value.Type())
|
||||
} else {
|
||||
field_total := 0
|
||||
for _, field_info := range(info.Fields) {
|
||||
field_size, err := SerializedSize(ctx, value.FieldByIndex(field_info.Index))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
field_total += 8
|
||||
field_total += field_size
|
||||
}
|
||||
|
||||
return 8 + field_total, nil
|
||||
}
|
||||
|
||||
case reflect.Interface:
|
||||
// TODO get size of TypeStack instead of just using 128
|
||||
elem_size, err := SerializedSize(ctx, value.Elem())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return 128 + elem_size, nil
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("Don't know how to serialize %s", value.Type())
|
||||
}
|
||||
} else {
|
||||
return sizefn(ctx, value)
|
||||
}
|
||||
}
|
||||
|
||||
func SerializeValue(ctx *Context, value reflect.Value, data []byte) (int, error) {
|
||||
var serialize SerializeFn = nil
|
||||
|
||||
info, registered := ctx.Types[value.Type()]
|
||||
if registered {
|
||||
serialize = info.Serialize
|
||||
}
|
||||
|
||||
if serialize == nil {
|
||||
switch value.Type().Kind() {
|
||||
case reflect.Bool:
|
||||
if value.Bool() {
|
||||
data[0] = 0xFF
|
||||
} else {
|
||||
data[0] = 0x00
|
||||
}
|
||||
return 1, nil
|
||||
|
||||
case reflect.Int8:
|
||||
data[0] = byte(value.Int())
|
||||
return 1, nil
|
||||
case reflect.Int16:
|
||||
binary.BigEndian.PutUint16(data, uint16(value.Int()))
|
||||
return 2, nil
|
||||
case reflect.Int32:
|
||||
binary.BigEndian.PutUint32(data, uint32(value.Int()))
|
||||
return 4, nil
|
||||
case reflect.Int64:
|
||||
fallthrough
|
||||
case reflect.Int:
|
||||
binary.BigEndian.PutUint64(data, uint64(value.Int()))
|
||||
return 8, nil
|
||||
|
||||
case reflect.Uint8:
|
||||
data[0] = byte(value.Uint())
|
||||
return 1, nil
|
||||
case reflect.Uint16:
|
||||
binary.BigEndian.PutUint16(data, uint16(value.Uint()))
|
||||
return 2, nil
|
||||
case reflect.Uint32:
|
||||
binary.BigEndian.PutUint32(data, uint32(value.Uint()))
|
||||
return 4, nil
|
||||
case reflect.Uint64:
|
||||
fallthrough
|
||||
case reflect.Uint:
|
||||
binary.BigEndian.PutUint64(data, value.Uint())
|
||||
return 8, nil
|
||||
|
||||
case reflect.Float32:
|
||||
binary.BigEndian.PutUint32(data, math.Float32bits(float32(value.Float())))
|
||||
return 4, nil
|
||||
case reflect.Float64:
|
||||
binary.BigEndian.PutUint64(data, math.Float64bits(value.Float()))
|
||||
return 8, nil
|
||||
|
||||
case reflect.String:
|
||||
binary.BigEndian.PutUint64(data, uint64(value.Len()))
|
||||
copy(data[8:], []byte(value.String()))
|
||||
return 8 + value.Len(), nil
|
||||
|
||||
case reflect.Pointer:
|
||||
if value.IsNil() {
|
||||
data[0] = 0x00
|
||||
return 1, nil
|
||||
} else {
|
||||
data[0] = 0x01
|
||||
written, err := SerializeValue(ctx, value.Elem(), data[1:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return 1 + written, nil
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
if value.IsNil() {
|
||||
data[0] = 0x00
|
||||
return 8, nil
|
||||
} else {
|
||||
data[0] = 0x01
|
||||
binary.BigEndian.PutUint64(data[1:], uint64(value.Len()))
|
||||
total_written := 0
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
written, err := SerializeValue(ctx, value.Index(i), data[9+total_written:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
total_written += written
|
||||
}
|
||||
return 9 + total_written, nil
|
||||
}
|
||||
|
||||
case reflect.Array:
|
||||
total_written := 0
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
written, err := SerializeValue(ctx, value.Index(i), data[total_written:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
total_written += written
|
||||
}
|
||||
return total_written, nil
|
||||
|
||||
case reflect.Map:
|
||||
if value.IsNil() {
|
||||
data[0] = 0x00
|
||||
return 1, nil
|
||||
} else {
|
||||
data[0] = 0x01
|
||||
binary.BigEndian.PutUint64(data[1:], uint64(value.Len()))
|
||||
|
||||
key := reflect.New(value.Type().Key()).Elem()
|
||||
val := reflect.New(value.Type().Elem()).Elem()
|
||||
iter := value.MapRange()
|
||||
total_written := 0
|
||||
for iter.Next() {
|
||||
key.SetIterKey(iter)
|
||||
val.SetIterValue(iter)
|
||||
|
||||
k, err := SerializeValue(ctx, key, data[9+total_written:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
total_written += k
|
||||
|
||||
v, err := SerializeValue(ctx, val, data[9+total_written:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
total_written += v
|
||||
}
|
||||
return 9 + total_written, nil
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
if registered == false {
|
||||
return 0, fmt.Errorf("Cannot serialize unregistered struct %s", value.Type())
|
||||
} else {
|
||||
binary.BigEndian.PutUint64(data, uint64(len(info.Fields)))
|
||||
|
||||
total_written := 0
|
||||
for field_tag, field_info := range(info.Fields) {
|
||||
binary.BigEndian.PutUint64(data[8+total_written:], uint64(field_tag))
|
||||
total_written += 8
|
||||
written, err := SerializeValue(ctx, value.FieldByIndex(field_info.Index), data[8+total_written:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
total_written += written
|
||||
}
|
||||
return 8 + total_written, nil
|
||||
}
|
||||
|
||||
case reflect.Interface:
|
||||
type_written, err := TypeStack(ctx, value.Elem().Type(), data)
|
||||
|
||||
elem_written, err := SerializeValue(ctx, value.Elem(), data[type_written:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return type_written + elem_written, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("Don't know how to serialize %s", value.Type())
|
||||
}
|
||||
} else {
|
||||
return serialize(ctx, value, data)
|
||||
}
|
||||
}
|
||||
|
||||
func split(data []byte, n int) ([]byte, []byte) {
|
||||
return data[:n], data[n:]
|
||||
}
|
||||
|
||||
func DeserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, []byte, error) {
|
||||
var deserialize DeserializeFn = nil
|
||||
|
||||
info, registered := ctx.Types[t]
|
||||
if registered {
|
||||
deserialize = info.Deserialize
|
||||
}
|
||||
|
||||
if deserialize == nil {
|
||||
switch t.Kind() {
|
||||
case reflect.Bool:
|
||||
used, left := split(data, 1)
|
||||
value := reflect.New(t).Elem()
|
||||
value.SetBool(used[0] != 0x00)
|
||||
return value, left, nil
|
||||
|
||||
case reflect.Int8:
|
||||
used, left := split(data, 1)
|
||||
value := reflect.New(t).Elem()
|
||||
value.SetInt(int64(used[0]))
|
||||
return value, left, nil
|
||||
case reflect.Int16:
|
||||
used, left := split(data, 2)
|
||||
value := reflect.New(t).Elem()
|
||||
value.SetInt(int64(binary.BigEndian.Uint16(used)))
|
||||
return value, left, nil
|
||||
case reflect.Int32:
|
||||
used, left := split(data, 4)
|
||||
value := reflect.New(t).Elem()
|
||||
value.SetInt(int64(binary.BigEndian.Uint32(used)))
|
||||
return value, left, nil
|
||||
case reflect.Int64:
|
||||
fallthrough
|
||||
case reflect.Int:
|
||||
used, left := split(data, 8)
|
||||
value := reflect.New(t).Elem()
|
||||
value.SetInt(int64(binary.BigEndian.Uint64(used)))
|
||||
return value, left, nil
|
||||
|
||||
case reflect.Uint8:
|
||||
used, left := split(data, 1)
|
||||
value := reflect.New(t).Elem()
|
||||
value.SetUint(uint64(used[0]))
|
||||
return value, left, nil
|
||||
case reflect.Uint16:
|
||||
used, left := split(data, 2)
|
||||
value := reflect.New(t).Elem()
|
||||
value.SetUint(uint64(binary.BigEndian.Uint16(used)))
|
||||
return value, left, nil
|
||||
case reflect.Uint32:
|
||||
used, left := split(data, 4)
|
||||
value := reflect.New(t).Elem()
|
||||
value.SetUint(uint64(binary.BigEndian.Uint32(used)))
|
||||
return value, left, nil
|
||||
case reflect.Uint64:
|
||||
fallthrough
|
||||
case reflect.Uint:
|
||||
used, left := split(data, 8)
|
||||
value := reflect.New(t).Elem()
|
||||
value.SetUint(binary.BigEndian.Uint64(used))
|
||||
return value, left, nil
|
||||
|
||||
case reflect.Float32:
|
||||
used, left := split(data, 4)
|
||||
value := reflect.New(t).Elem()
|
||||
value.SetFloat(float64(math.Float32frombits(binary.BigEndian.Uint32(used))))
|
||||
return value, left, nil
|
||||
case reflect.Float64:
|
||||
used, left := split(data, 8)
|
||||
value := reflect.New(t).Elem()
|
||||
value.SetFloat(math.Float64frombits(binary.BigEndian.Uint64(used)))
|
||||
return value, left, nil
|
||||
|
||||
case reflect.String:
|
||||
length, after_len := split(data, 8)
|
||||
used, left := split(after_len, int(binary.BigEndian.Uint64(length)))
|
||||
value := reflect.New(t).Elem()
|
||||
value.SetString(string(used))
|
||||
return value, left, nil
|
||||
|
||||
case reflect.Pointer:
|
||||
flags, after_flags := split(data, 1)
|
||||
value := reflect.New(t).Elem()
|
||||
if flags[0] == 0x00 {
|
||||
value.SetZero()
|
||||
return value, after_flags, nil
|
||||
} else {
|
||||
elem_value, after_elem, err := DeserializeValue(ctx, after_flags, t.Elem())
|
||||
if err != nil {
|
||||
return reflect.Value{}, nil, err
|
||||
}
|
||||
value.Set(elem_value.Addr())
|
||||
return value, after_elem, nil
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
nil_byte := data[0]
|
||||
data = data[1:]
|
||||
if nil_byte == 0x00 {
|
||||
return reflect.New(t).Elem(), data, nil
|
||||
} else {
|
||||
len_bytes, left := split(data, 8)
|
||||
length := int(binary.BigEndian.Uint64(len_bytes))
|
||||
value := reflect.MakeSlice(t, length, length)
|
||||
for i := 0; i < length; i++ {
|
||||
var elem_value reflect.Value
|
||||
var err error
|
||||
elem_value, left, err = DeserializeValue(ctx, left, t.Elem())
|
||||
if err != nil {
|
||||
return reflect.Value{}, nil, err
|
||||
}
|
||||
value.Index(i).Set(elem_value)
|
||||
}
|
||||
return value, left, nil
|
||||
}
|
||||
|
||||
case reflect.Array:
|
||||
value := reflect.New(t).Elem()
|
||||
left := data
|
||||
for i := 0; i < t.Len(); i++ {
|
||||
var elem_value reflect.Value
|
||||
var err error
|
||||
elem_value, left, err = DeserializeValue(ctx, left, t.Elem())
|
||||
if err != nil {
|
||||
return reflect.Value{}, nil, err
|
||||
}
|
||||
value.Index(i).Set(elem_value)
|
||||
}
|
||||
return value, left, nil
|
||||
|
||||
case reflect.Map:
|
||||
flags, after_flags := split(data, 1)
|
||||
if flags[0] == 0x00 {
|
||||
return reflect.New(t).Elem(), after_flags, nil
|
||||
} else {
|
||||
len_bytes, left := split(after_flags, 8)
|
||||
length := int(binary.BigEndian.Uint64(len_bytes))
|
||||
|
||||
value := reflect.MakeMapWithSize(t, length)
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
var key_value reflect.Value
|
||||
var val_value reflect.Value
|
||||
var err error
|
||||
|
||||
key_value, left, err = DeserializeValue(ctx, left, t.Key())
|
||||
if err != nil {
|
||||
return reflect.Value{}, nil, err
|
||||
}
|
||||
|
||||
val_value, left, err = DeserializeValue(ctx, left, t.Elem())
|
||||
if err != nil {
|
||||
return reflect.Value{}, nil, err
|
||||
}
|
||||
|
||||
value.SetMapIndex(key_value, val_value)
|
||||
}
|
||||
|
||||
return value, left, nil
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
info, mapped := ctx.Types[t]
|
||||
if mapped {
|
||||
value := reflect.New(t).Elem()
|
||||
|
||||
num_field_bytes, left := split(data, 8)
|
||||
num_fields := int(binary.BigEndian.Uint64(num_field_bytes))
|
||||
|
||||
for i := 0; i < num_fields; i++ {
|
||||
var tag_bytes []byte
|
||||
|
||||
tag_bytes, left = split(left, 8)
|
||||
field_tag := FieldTag(binary.BigEndian.Uint64(tag_bytes))
|
||||
|
||||
field_info, mapped := info.Fields[field_tag]
|
||||
if mapped {
|
||||
var field_val reflect.Value
|
||||
var err error
|
||||
field_val, left, err = DeserializeValue(ctx, left, field_info.Type)
|
||||
if err != nil {
|
||||
return reflect.Value{}, nil, err
|
||||
}
|
||||
value.FieldByIndex(field_info.Index).Set(field_val)
|
||||
} else {
|
||||
return reflect.Value{}, nil, fmt.Errorf("Unknown field %s on struct %s", field_tag, t)
|
||||
}
|
||||
}
|
||||
if info.PostDeserializeIndex != -1 {
|
||||
post_deserialize_method := value.Addr().Method(info.PostDeserializeIndex)
|
||||
post_deserialize_method.Call([]reflect.Value{reflect.ValueOf(ctx)})
|
||||
}
|
||||
return value, left, nil
|
||||
} else {
|
||||
return reflect.Value{}, nil, fmt.Errorf("Cannot deserialize unregistered struct %s", t)
|
||||
}
|
||||
|
||||
case reflect.Interface:
|
||||
elem_type, rest, err := UnwrapStack(ctx, data)
|
||||
if err != nil {
|
||||
return reflect.Value{}, nil, err
|
||||
}
|
||||
|
||||
elem_val, left, err := DeserializeValue(ctx, rest, elem_type)
|
||||
if err != nil {
|
||||
return reflect.Value{}, nil, err
|
||||
}
|
||||
|
||||
val := reflect.New(t).Elem()
|
||||
val.Set(elem_val)
|
||||
|
||||
return val, left, nil
|
||||
|
||||
default:
|
||||
return reflect.Value{}, nil, fmt.Errorf("Don't know how to deserialize %s", t)
|
||||
}
|
||||
} else {
|
||||
return deserialize(ctx, data)
|
||||
}
|
||||
}
|
||||
|
@ -1,176 +0,0 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"reflect"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func testTypeStack[T any](t *testing.T, ctx *Context) {
|
||||
buffer := [1024]byte{}
|
||||
reflect_type := reflect.TypeFor[T]()
|
||||
written, err := TypeStack(ctx, reflect_type, buffer[:])
|
||||
fatalErr(t, err)
|
||||
|
||||
stack := buffer[:written]
|
||||
|
||||
ctx.Log.Logf("test", "TypeStack[%s]: %+v", reflect_type, stack)
|
||||
|
||||
unwrapped_type, rest, err := UnwrapStack(ctx, stack)
|
||||
fatalErr(t, err)
|
||||
|
||||
if len(rest) != 0 {
|
||||
t.Errorf("Types remaining after unwrapping %s: %+v", unwrapped_type, stack)
|
||||
}
|
||||
|
||||
if unwrapped_type != reflect_type {
|
||||
t.Errorf("Unwrapped type[%+v] doesn't match original[%+v]", unwrapped_type, reflect_type)
|
||||
}
|
||||
|
||||
ctx.Log.Logf("test", "Unwrapped type[%s]: %s", reflect_type, reflect_type)
|
||||
}
|
||||
|
||||
func TestSerializeTypes(t *testing.T) {
|
||||
ctx := logTestContext(t, []string{"test"})
|
||||
|
||||
testTypeStack[int](t, ctx)
|
||||
testTypeStack[map[int]string](t, ctx)
|
||||
testTypeStack[string](t, ctx)
|
||||
testTypeStack[*string](t, ctx)
|
||||
testTypeStack[*map[string]*map[*string]int](t, ctx)
|
||||
testTypeStack[[5]int](t, ctx)
|
||||
testTypeStack[uuid.UUID](t, ctx)
|
||||
testTypeStack[NodeID](t, ctx)
|
||||
}
|
||||
|
||||
func testSerializeCompare[T comparable](t *testing.T, ctx *Context, value T) {
|
||||
buffer := [1024]byte{}
|
||||
written, err := Serialize(ctx, value, buffer[:])
|
||||
fatalErr(t, err)
|
||||
|
||||
serialized := buffer[:written]
|
||||
|
||||
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
|
||||
|
||||
deserialized, err := Deserialize[T](ctx, serialized)
|
||||
fatalErr(t, err)
|
||||
|
||||
if value != deserialized {
|
||||
t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized)
|
||||
}
|
||||
|
||||
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
|
||||
}
|
||||
|
||||
func testSerializeList[L []T, T comparable](t *testing.T, ctx *Context, value L) {
|
||||
buffer := [1024]byte{}
|
||||
written, err := Serialize(ctx, value, buffer[:])
|
||||
fatalErr(t, err)
|
||||
|
||||
serialized := buffer[:written]
|
||||
|
||||
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[L](), value, serialized)
|
||||
|
||||
deserialized, err := Deserialize[L](ctx, serialized)
|
||||
fatalErr(t, err)
|
||||
|
||||
for i, item := range(value) {
|
||||
if item != deserialized[i] {
|
||||
t.Fatalf("Deserialized list %+v does not match original %+v", value, deserialized)
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
|
||||
}
|
||||
|
||||
func testSerializePointer[P interface {*T}, T comparable](t *testing.T, ctx *Context, value P) {
|
||||
buffer := [1024]byte{}
|
||||
|
||||
written, err := Serialize(ctx, value, buffer[:])
|
||||
fatalErr(t, err)
|
||||
|
||||
serialized := buffer[:written]
|
||||
|
||||
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[P](), value, serialized)
|
||||
|
||||
deserialized, err := Deserialize[P](ctx, serialized)
|
||||
fatalErr(t, err)
|
||||
|
||||
if value == nil && deserialized == nil {
|
||||
ctx.Log.Logf("test", "Deserialized nil")
|
||||
} else if value == nil {
|
||||
t.Fatalf("Non-nil value[%+v] returned for nil value", deserialized)
|
||||
} else if deserialized == nil {
|
||||
t.Fatalf("Nil value returned for non-nil value[%+v]", value)
|
||||
} else if *deserialized != *value {
|
||||
t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized)
|
||||
} else {
|
||||
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", *value, *deserialized)
|
||||
}
|
||||
}
|
||||
|
||||
func testSerialize[T any](t *testing.T, ctx *Context, value T) {
|
||||
buffer := [1024]byte{}
|
||||
written, err := Serialize(ctx, value, buffer[:])
|
||||
fatalErr(t, err)
|
||||
|
||||
serialized := buffer[:written]
|
||||
|
||||
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
|
||||
|
||||
deserialized, err := Deserialize[T](ctx, serialized)
|
||||
fatalErr(t, err)
|
||||
|
||||
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
|
||||
}
|
||||
|
||||
func TestSerializeValues(t *testing.T) {
|
||||
ctx := logTestContext(t, []string{"test"})
|
||||
|
||||
testSerialize(t, ctx, Extension(NewLockableExt(nil)))
|
||||
|
||||
testSerializeCompare[int8](t, ctx, -64)
|
||||
testSerializeCompare[int16](t, ctx, -64)
|
||||
testSerializeCompare[int32](t, ctx, -64)
|
||||
testSerializeCompare[int64](t, ctx, -64)
|
||||
testSerializeCompare[int](t, ctx, -64)
|
||||
|
||||
testSerializeCompare[uint8](t, ctx, 64)
|
||||
testSerializeCompare[uint16](t, ctx, 64)
|
||||
testSerializeCompare[uint32](t, ctx, 64)
|
||||
testSerializeCompare[uint64](t, ctx, 64)
|
||||
testSerializeCompare[uint](t, ctx, 64)
|
||||
|
||||
testSerializeCompare[string](t, ctx, "test")
|
||||
|
||||
a := 12
|
||||
testSerializePointer[*int](t, ctx, &a)
|
||||
|
||||
b := "test"
|
||||
testSerializePointer[*string](t, ctx, nil)
|
||||
testSerializePointer[*string](t, ctx, &b)
|
||||
|
||||
testSerializeList(t, ctx, []int{1, 2, 3, 4, 5})
|
||||
|
||||
testSerializeCompare[bool](t, ctx, true)
|
||||
testSerializeCompare[bool](t, ctx, false)
|
||||
testSerializeCompare[int](t, ctx, -1)
|
||||
testSerializeCompare[uint](t, ctx, 1)
|
||||
testSerializeCompare[NodeID](t, ctx, RandID())
|
||||
testSerializeCompare[*int](t, ctx, nil)
|
||||
testSerializeCompare(t, ctx, "string")
|
||||
|
||||
testSerialize(t, ctx, map[string]string{
|
||||
"Test": "Test",
|
||||
"key": "String",
|
||||
"": "",
|
||||
})
|
||||
|
||||
testSerialize[map[string]string](t, ctx, nil)
|
||||
|
||||
testSerialize(t, ctx, NewListenerExt(10))
|
||||
|
||||
node, err := ctx.NewNode(nil, "Node")
|
||||
fatalErr(t, err)
|
||||
testSerialize(t, ctx, node)
|
||||
}
|
@ -0,0 +1,736 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"sync"
|
||||
"errors"
|
||||
"encoding/json"
|
||||
"crypto/sha512"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
type ThreadAction func(*Context, *Node, *ThreadExt)(string, error)
|
||||
type ThreadActions map[string]ThreadAction
|
||||
type ThreadHandler func(*Context, *Node, *ThreadExt, Signal)(string, error)
|
||||
type ThreadHandlers map[SignalType]ThreadHandler
|
||||
|
||||
type InfoType string
|
||||
func (t InfoType) String() string {
|
||||
return string(t)
|
||||
}
|
||||
|
||||
type Info interface {
|
||||
Serializable[InfoType]
|
||||
}
|
||||
|
||||
// Data required by a parent thread to restore it's children
|
||||
type ParentInfo struct {
|
||||
Start bool `json:"start"`
|
||||
StartAction string `json:"start_action"`
|
||||
RestoreAction string `json:"restore_action"`
|
||||
}
|
||||
|
||||
const ParentInfoType = InfoType("PARENT")
|
||||
func (info *ParentInfo) Type() InfoType {
|
||||
return ParentInfoType
|
||||
}
|
||||
|
||||
func (info *ParentInfo) Serialize() ([]byte, error) {
|
||||
return json.MarshalIndent(info, "", " ")
|
||||
}
|
||||
|
||||
type QueuedAction struct {
|
||||
Timeout time.Time `json:"time"`
|
||||
Action string `json:"action"`
|
||||
}
|
||||
|
||||
type ThreadType string
|
||||
func (thread ThreadType) Hash() uint64 {
|
||||
hash := sha512.Sum512([]byte(fmt.Sprintf("THREAD: %s", string(thread))))
|
||||
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
|
||||
}
|
||||
|
||||
type ThreadInfo struct {
|
||||
Actions ThreadActions
|
||||
Handlers ThreadHandlers
|
||||
}
|
||||
|
||||
type InfoLoadFunc func([]byte)(Info, error)
|
||||
type ThreadExtContext struct {
|
||||
Types map[ThreadType]ThreadInfo
|
||||
Loads map[InfoType]InfoLoadFunc
|
||||
}
|
||||
|
||||
const BaseThreadType = ThreadType("BASE")
|
||||
func NewThreadExtContext() *ThreadExtContext {
|
||||
return &ThreadExtContext{
|
||||
Types: map[ThreadType]ThreadInfo{
|
||||
BaseThreadType: ThreadInfo{
|
||||
Actions: BaseThreadActions,
|
||||
Handlers: BaseThreadHandlers,
|
||||
},
|
||||
},
|
||||
Loads: map[InfoType]InfoLoadFunc{
|
||||
ParentInfoType: func(data []byte) (Info, error) {
|
||||
var info ParentInfo
|
||||
err := json.Unmarshal(data, &info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &info, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *ThreadExtContext) RegisterThreadType(thread_type ThreadType, actions ThreadActions, handlers ThreadHandlers) error {
|
||||
if actions == nil || handlers == nil {
|
||||
return fmt.Errorf("Cannot register ThreadType %s with nil actions or handlers", thread_type)
|
||||
}
|
||||
|
||||
_, exists := ctx.Types[thread_type]
|
||||
if exists == true {
|
||||
return fmt.Errorf("ThreadType %s already registered in ThreadExtContext, cannot register again", thread_type)
|
||||
}
|
||||
ctx.Types[thread_type] = ThreadInfo{
|
||||
Actions: actions,
|
||||
Handlers: handlers,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ctx *ThreadExtContext) RegisterInfoType(info_type InfoType, load_fn InfoLoadFunc) error {
|
||||
if load_fn == nil {
|
||||
return fmt.Errorf("Cannot register %s with nil load_fn", info_type)
|
||||
}
|
||||
|
||||
_, exists := ctx.Loads[info_type]
|
||||
if exists == true {
|
||||
return fmt.Errorf("InfoType %s is already registered in ThreadExtContext, cannot register again", info_type)
|
||||
}
|
||||
|
||||
ctx.Loads[info_type] = load_fn
|
||||
return nil
|
||||
}
|
||||
|
||||
type ThreadExt struct {
|
||||
Actions ThreadActions
|
||||
Handlers ThreadHandlers
|
||||
|
||||
ThreadType ThreadType
|
||||
|
||||
SignalChan chan Signal
|
||||
TimeoutChan <-chan time.Time
|
||||
|
||||
ChildWaits sync.WaitGroup
|
||||
|
||||
ActiveLock sync.Mutex
|
||||
Active bool
|
||||
State string
|
||||
|
||||
Parent *Node
|
||||
Children map[NodeID]ChildInfo
|
||||
|
||||
ActionQueue []QueuedAction
|
||||
NextAction *QueuedAction
|
||||
}
|
||||
|
||||
type ThreadExtJSON struct {
|
||||
State string `json:"state"`
|
||||
Type string `json:"type"`
|
||||
Parent string `json:"parent"`
|
||||
Children map[string]map[string][]byte `json:"children"`
|
||||
ActionQueue []QueuedAction
|
||||
}
|
||||
|
||||
func (ext *ThreadExt) Serialize() ([]byte, error) {
|
||||
children := map[string]map[string][]byte{}
|
||||
for id, child := range(ext.Children) {
|
||||
id_str := id.String()
|
||||
children[id_str] = map[string][]byte{}
|
||||
for info_type, info := range(child.Infos) {
|
||||
var err error
|
||||
children[id_str][string(info_type)], err = info.Serialize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return json.MarshalIndent(&ThreadExtJSON{
|
||||
State: ext.State,
|
||||
Type: string(ext.ThreadType),
|
||||
Parent: SaveNode(ext.Parent),
|
||||
Children: children,
|
||||
ActionQueue: ext.ActionQueue,
|
||||
}, "", " ")
|
||||
}
|
||||
|
||||
func NewThreadExt(ctx*Context, thread_type ThreadType, parent *Node, children map[NodeID]ChildInfo, state string, action_queue []QueuedAction) (*ThreadExt, error) {
|
||||
if children == nil {
|
||||
children = map[NodeID]ChildInfo{}
|
||||
}
|
||||
|
||||
if action_queue == nil {
|
||||
action_queue = []QueuedAction{}
|
||||
}
|
||||
|
||||
thread_ctx, err := GetCtx[*ThreadExt, *ThreadExtContext](ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
type_info, exists := thread_ctx.Types[thread_type]
|
||||
if exists == false {
|
||||
return nil, fmt.Errorf("Tried to load thread type %s which is not in context", thread_type)
|
||||
}
|
||||
next_action, timeout_chan := SoonestAction(action_queue)
|
||||
|
||||
return &ThreadExt{
|
||||
ThreadType: thread_type,
|
||||
Actions: type_info.Actions,
|
||||
Handlers: type_info.Handlers,
|
||||
SignalChan: make(chan Signal, THREAD_BUFFER_SIZE),
|
||||
TimeoutChan: timeout_chan,
|
||||
Active: false,
|
||||
State: state,
|
||||
Parent: parent,
|
||||
Children: children,
|
||||
ActionQueue: action_queue,
|
||||
NextAction: next_action,
|
||||
}, nil
|
||||
}
|
||||
|
||||
const THREAD_BUFFER_SIZE int = 1024
|
||||
func LoadThreadExt(ctx *Context, data []byte) (Extension, error) {
|
||||
var j ThreadExtJSON
|
||||
err := json.Unmarshal(data, &j)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx.Log.Logf("db", "DB_LOAD_THREAD_EXT_JSON: %+v", j)
|
||||
|
||||
parent, err := RestoreNode(ctx, j.Parent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
thread_ctx, err := GetCtx[*ThreadExt, *ThreadExtContext](ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
children := map[NodeID]ChildInfo{}
|
||||
for id_str, infos := range(j.Children) {
|
||||
child_node, err := RestoreNode(ctx, id_str)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
child_infos := map[InfoType]Info{}
|
||||
for info_type_str, info_data := range(infos) {
|
||||
info_type := InfoType(info_type_str)
|
||||
info_load, exists := thread_ctx.Loads[info_type]
|
||||
if exists == false {
|
||||
return nil, fmt.Errorf("%s is not a known InfoType in ThreacExrContxt", info_type)
|
||||
}
|
||||
|
||||
info, err := info_load(info_data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
child_infos[info_type] = info
|
||||
}
|
||||
|
||||
children[child_node.ID] = ChildInfo{
|
||||
Child: child_node,
|
||||
Infos: child_infos,
|
||||
}
|
||||
}
|
||||
|
||||
return NewThreadExt(ctx, ThreadType(j.Type), parent, children, j.State, j.ActionQueue)
|
||||
}
|
||||
|
||||
const ThreadExtType = ExtType("THREAD")
|
||||
func (ext *ThreadExt) Type() ExtType {
|
||||
return ThreadExtType
|
||||
}
|
||||
|
||||
func (ext *ThreadExt) QueueAction(end time.Time, action string) {
|
||||
ext.ActionQueue = append(ext.ActionQueue, QueuedAction{end, action})
|
||||
ext.NextAction, ext.TimeoutChan = SoonestAction(ext.ActionQueue)
|
||||
}
|
||||
|
||||
func (ext *ThreadExt) ClearActionQueue() {
|
||||
ext.ActionQueue = []QueuedAction{}
|
||||
ext.NextAction = nil
|
||||
ext.TimeoutChan = nil
|
||||
}
|
||||
|
||||
func SoonestAction(actions []QueuedAction) (*QueuedAction, <-chan time.Time) {
|
||||
var soonest_action *QueuedAction
|
||||
var soonest_time time.Time
|
||||
for _, action := range(actions) {
|
||||
if action.Timeout.Compare(soonest_time) == -1 || soonest_action == nil {
|
||||
soonest_action = &action
|
||||
soonest_time = action.Timeout
|
||||
}
|
||||
}
|
||||
if soonest_action != nil {
|
||||
return soonest_action, time.After(time.Until(soonest_action.Timeout))
|
||||
} else {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ext *ThreadExt) ChildList() []*Node {
|
||||
ret := make([]*Node, len(ext.Children))
|
||||
i := 0
|
||||
for _, info := range(ext.Children) {
|
||||
ret[i] = info.Child
|
||||
i += 1
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// Assumed that thread is already locked for signal
|
||||
func (ext *ThreadExt) Process(context *StateContext, node *Node, signal Signal) error {
|
||||
context.Graph.Log.Logf("signal", "THREAD_PROCESS: %s", node.ID)
|
||||
|
||||
var err error
|
||||
switch signal.Direction() {
|
||||
case Up:
|
||||
err = UseStates(context, node, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error {
|
||||
if ext.Parent != nil {
|
||||
if ext.Parent.ID != node.ID {
|
||||
return SendSignal(context, ext.Parent, node, signal)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
case Down:
|
||||
err = UseStates(context, node, NewACLInfo(node, []string{"children"}), func(context *StateContext) error {
|
||||
for _, info := range(ext.Children) {
|
||||
err := SendSignal(context, info.Child, node, signal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
case Direct:
|
||||
err = nil
|
||||
default:
|
||||
return fmt.Errorf("Invalid signal direction %d", signal.Direction())
|
||||
}
|
||||
ext.SignalChan <- signal
|
||||
return err
|
||||
}
|
||||
|
||||
func UnlinkThreads(context *StateContext, principal *Node, thread *Node, child *Node) error {
|
||||
thread_ext, err := GetExt[*ThreadExt](thread)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
child_ext, err := GetExt[*ThreadExt](child)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return UpdateStates(context, principal, ACLMap{
|
||||
thread.ID: ACLInfo{thread, []string{"children"}},
|
||||
child.ID: ACLInfo{child, []string{"parent"}},
|
||||
}, func(context *StateContext) error {
|
||||
_, is_child := thread_ext.Children[child.ID]
|
||||
if is_child == false {
|
||||
return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID, thread.ID)
|
||||
}
|
||||
|
||||
delete(thread_ext.Children, child.ID)
|
||||
child_ext.Parent = nil
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func checkIfChild(context *StateContext, id NodeID, cur *ThreadExt) (bool, error) {
|
||||
for _, info := range(cur.Children) {
|
||||
child := info.Child
|
||||
if child.ID == id {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
child_ext, err := GetExt[*ThreadExt](child)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var is_child bool
|
||||
err = UpdateStates(context, child, NewACLInfo(child, []string{"children"}), func(context *StateContext) error {
|
||||
is_child, err = checkIfChild(context, id, child_ext)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if is_child {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Links child to parent with info as the associated info
|
||||
// Continues the write context with princ, getting children for thread and parent for child
|
||||
func LinkThreads(context *StateContext, principal *Node, thread *Node, info ChildInfo) error {
|
||||
if context == nil || principal == nil || thread == nil || info.Child == nil {
|
||||
return fmt.Errorf("invalid input")
|
||||
}
|
||||
|
||||
child := info.Child
|
||||
if thread.ID == child.ID {
|
||||
return fmt.Errorf("Will not link %s as a child of itself", thread.ID)
|
||||
}
|
||||
|
||||
thread_ext, err := GetExt[*ThreadExt](thread)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
child_ext, err := GetExt[*ThreadExt](child)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return UpdateStates(context, principal, ACLMap{
|
||||
child.ID: ACLInfo{Node: child, Resources: []string{"parent"}},
|
||||
thread.ID: ACLInfo{Node: thread, Resources: []string{"children"}},
|
||||
}, func(context *StateContext) error {
|
||||
if child_ext.Parent != nil {
|
||||
return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID)
|
||||
}
|
||||
|
||||
is_child, err := checkIfChild(context, thread.ID, child_ext)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if is_child == true {
|
||||
return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID, child.ID)
|
||||
}
|
||||
|
||||
is_child, err = checkIfChild(context, child.ID, thread_ext)
|
||||
if err != nil {
|
||||
|
||||
return err
|
||||
} else if is_child == true {
|
||||
return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID, child.ID)
|
||||
}
|
||||
|
||||
// TODO check for info types
|
||||
|
||||
thread_ext.Children[child.ID] = info
|
||||
child_ext.Parent = thread
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
type ChildInfo struct {
|
||||
Child *Node
|
||||
Infos map[InfoType]Info
|
||||
}
|
||||
|
||||
func NewChildInfo(child *Node, infos map[InfoType]Info) ChildInfo {
|
||||
if infos == nil {
|
||||
infos = map[InfoType]Info{}
|
||||
}
|
||||
|
||||
return ChildInfo{
|
||||
Child: child,
|
||||
Infos: infos,
|
||||
}
|
||||
}
|
||||
|
||||
func (ext *ThreadExt) SetActive(active bool) error {
|
||||
ext.ActiveLock.Lock()
|
||||
defer ext.ActiveLock.Unlock()
|
||||
if ext.Active == true && active == true {
|
||||
return fmt.Errorf("alreday active, cannot set active")
|
||||
} else if ext.Active == false && active == false {
|
||||
return fmt.Errorf("already inactive, canot set inactive")
|
||||
}
|
||||
ext.Active = active
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ext *ThreadExt) SetState(state string) error {
|
||||
ext.State = state
|
||||
return nil
|
||||
}
|
||||
|
||||
// Requires the read permission of threads children
|
||||
func FindChild(context *StateContext, principal *Node, thread *Node, id NodeID) (*Node, error) {
|
||||
if thread == nil {
|
||||
panic("cannot recurse through nil")
|
||||
}
|
||||
|
||||
if id == thread.ID {
|
||||
return thread, nil
|
||||
}
|
||||
|
||||
thread_ext, err := GetExt[*ThreadExt](thread)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var found *Node = nil
|
||||
err = UseStates(context, principal, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
|
||||
for _, info := range(thread_ext.Children) {
|
||||
found, err = FindChild(context, principal, info.Child, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if found != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return found, err
|
||||
}
|
||||
|
||||
func ChildGo(ctx * Context, thread_ext *ThreadExt, child *Node, first_action string) {
|
||||
thread_ext.ChildWaits.Add(1)
|
||||
go func(child *Node) {
|
||||
defer thread_ext.ChildWaits.Done()
|
||||
err := ThreadLoop(ctx, child, first_action)
|
||||
if err != nil {
|
||||
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %s", child.ID, err)
|
||||
} else {
|
||||
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_DONE: %s", child.ID)
|
||||
}
|
||||
}(child)
|
||||
}
|
||||
|
||||
// Main Loop for Threads, starts a write context, so cannot be called from a write or read context
|
||||
func ThreadLoop(ctx * Context, thread *Node, first_action string) error {
|
||||
thread_ext, err := GetExt[*ThreadExt](thread)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID, first_action)
|
||||
|
||||
err = thread_ext.SetActive(true)
|
||||
if err != nil {
|
||||
ctx.Log.Logf("thread", "THREAD_LOOP_START_ERR: %e", err)
|
||||
return err
|
||||
}
|
||||
next_action := first_action
|
||||
for next_action != "" {
|
||||
action, exists := thread_ext.Actions[next_action]
|
||||
if exists == false {
|
||||
error_str := fmt.Sprintf("%s is not a valid action", next_action)
|
||||
return errors.New(error_str)
|
||||
}
|
||||
|
||||
ctx.Log.Logf("thread", "THREAD_ACTION: %s - %s", thread.ID, next_action)
|
||||
next_action, err = action(ctx, thread, thread_ext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = thread_ext.SetActive(false)
|
||||
if err != nil {
|
||||
ctx.Log.Logf("thread", "THREAD_LOOP_STOP_ERR: %e", err)
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Log.Logf("thread", "THREAD_LOOP_DONE: %s", thread.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ThreadChildLinked(ctx *Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
|
||||
ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %s - %+v", thread.ID, signal)
|
||||
context := NewWriteContext(ctx)
|
||||
err := UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
|
||||
sig, ok := signal.(IDSignal)
|
||||
if ok == false {
|
||||
ctx.Log.Logf("thread", "THREAD_NODE_LINKED_BAD_CAST")
|
||||
return nil
|
||||
}
|
||||
info, exists := thread_ext.Children[sig.ID]
|
||||
if exists == false {
|
||||
ctx.Log.Logf("thread", "THREAD_NODE_LINKED: %s is not a child of %s", sig.ID)
|
||||
return nil
|
||||
}
|
||||
parent_info, exists := info.Infos["parent"].(*ParentInfo)
|
||||
if exists == false {
|
||||
panic("ran ThreadChildLinked from a thread that doesn't require 'parent' child info. library party foul")
|
||||
}
|
||||
|
||||
if parent_info.Start == true {
|
||||
ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
||||
} else {
|
||||
|
||||
}
|
||||
return "wait", nil
|
||||
}
|
||||
|
||||
// Helper function to start a child from a thread during a signal handler
|
||||
// Starts a write context, so cannot be called from either a write or read context
|
||||
func ThreadStartChild(ctx *Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
|
||||
sig, ok := signal.(StartChildSignal)
|
||||
if ok == false {
|
||||
return "wait", nil
|
||||
}
|
||||
|
||||
context := NewWriteContext(ctx)
|
||||
return "wait", UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
|
||||
info, exists:= thread_ext.Children[sig.ID]
|
||||
if exists == false {
|
||||
return fmt.Errorf("%s is not a child of %s", sig.ID, thread.ID)
|
||||
}
|
||||
|
||||
parent_info, exists := info.Infos["parent"].(*ParentInfo)
|
||||
if exists == false {
|
||||
return fmt.Errorf("Called ThreadStartChild from a thread that doesn't require parent child info")
|
||||
}
|
||||
parent_info.Start = true
|
||||
ChildGo(ctx, thread_ext, info.Child, sig.Action)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to restore threads that should be running from a parents restore action
|
||||
// Starts a write context, so cannot be called from either a write or read context
|
||||
func ThreadRestore(ctx * Context, thread *Node, thread_ext *ThreadExt, start bool) error {
|
||||
context := NewWriteContext(ctx)
|
||||
return UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
|
||||
return UpdateStates(context, thread, ACLList(thread_ext.ChildList(), []string{"state"}), func(context *StateContext) error {
|
||||
for _, info := range(thread_ext.Children) {
|
||||
child_ext, err := GetExt[*ThreadExt](info.Child)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
parent_info := info.Infos[ParentInfoType].(*ParentInfo)
|
||||
if parent_info.Start == true && child_ext.State != "finished" {
|
||||
ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID, info.Child.ID)
|
||||
if start == true {
|
||||
ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction)
|
||||
} else {
|
||||
ChildGo(ctx, thread_ext, info.Child, parent_info.RestoreAction)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to be called during a threads start action, sets the thread state to started
|
||||
// Starts a write context, so cannot be called from either a write or read context
|
||||
// Returns "wait", nil on success, so the first return value can be ignored safely
|
||||
func ThreadStart(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
|
||||
context := NewWriteContext(ctx)
|
||||
err := UpdateStates(context, thread, NewACLInfo(thread, []string{"state"}), func(context *StateContext) error {
|
||||
err := LockLockables(context, map[NodeID]*Node{thread.ID: thread}, thread)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return thread_ext.SetState("started")
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
context = NewReadContext(ctx)
|
||||
return "wait", SendSignal(context, thread, thread, NewStatusSignal("started", thread.ID))
|
||||
}
|
||||
|
||||
func ThreadWait(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
|
||||
ctx.Log.Logf("thread", "THREAD_WAIT: %s - %+v", thread.ID, thread_ext.ActionQueue)
|
||||
for {
|
||||
select {
|
||||
case signal := <- thread_ext.SignalChan:
|
||||
ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID, signal)
|
||||
signal_fn, exists := thread_ext.Handlers[signal.Type()]
|
||||
if exists == true {
|
||||
ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID, signal.Type())
|
||||
return signal_fn(ctx, thread, thread_ext, signal)
|
||||
} else {
|
||||
ctx.Log.Logf("thread", "THREAD_NOHANDLER: %s - %s", thread.ID, signal.Type())
|
||||
}
|
||||
case <- thread_ext.TimeoutChan:
|
||||
timeout_action := ""
|
||||
context := NewWriteContext(ctx)
|
||||
err := UpdateStates(context, thread, NewACLMap(NewACLInfo(thread, []string{"timeout"})), func(context *StateContext) error {
|
||||
timeout_action = thread_ext.NextAction.Action
|
||||
thread_ext.NextAction, thread_ext.TimeoutChan = SoonestAction(thread_ext.ActionQueue)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
ctx.Log.Logf("thread", "THREAD_TIMEOUT_ERR: %s - %e", thread.ID, err)
|
||||
}
|
||||
ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID, timeout_action)
|
||||
return timeout_action, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ThreadFinish(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) {
|
||||
context := NewWriteContext(ctx)
|
||||
return "", UpdateStates(context, thread, NewACLInfo(thread, []string{"state"}), func(context *StateContext) error {
|
||||
err := thread_ext.SetState("finished")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return UnlockLockables(context, map[NodeID]*Node{thread.ID: thread}, thread)
|
||||
})
|
||||
}
|
||||
|
||||
var ThreadAbortedError = errors.New("Thread aborted by signal")
|
||||
|
||||
// Default thread action function for "abort", sends a signal and returns a ThreadAbortedError
|
||||
func ThreadAbort(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
|
||||
context := NewReadContext(ctx)
|
||||
err := SendSignal(context, thread, thread, NewStatusSignal("aborted", thread.ID))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "", ThreadAbortedError
|
||||
}
|
||||
|
||||
// Default thread action for "stop", sends a signal and returns no error
|
||||
func ThreadStop(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
|
||||
context := NewReadContext(ctx)
|
||||
err := SendSignal(context, thread, thread, NewStatusSignal("stopped", thread.ID))
|
||||
return "finish", err
|
||||
}
|
||||
|
||||
// Default thread actions
|
||||
var BaseThreadActions = ThreadActions{
|
||||
"wait": ThreadWait,
|
||||
"start": ThreadStart,
|
||||
"finish": ThreadFinish,
|
||||
}
|
||||
|
||||
// Default thread signal handlers
|
||||
var BaseThreadHandlers = ThreadHandlers{
|
||||
"abort": ThreadAbort,
|
||||
"stop": ThreadStop,
|
||||
}
|
@ -0,0 +1,120 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"time"
|
||||
"fmt"
|
||||
"encoding/json"
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
)
|
||||
|
||||
type ECDHExt struct {
|
||||
Granted time.Time
|
||||
Pubkey *ecdsa.PublicKey
|
||||
Shared []byte
|
||||
}
|
||||
|
||||
type ECDHExtJSON struct {
|
||||
Granted time.Time `json:"granted"`
|
||||
Pubkey []byte `json:"pubkey"`
|
||||
Shared []byte `json:"shared"`
|
||||
}
|
||||
|
||||
func (ext *ECDHExt) Process(context *StateContext, node *Node, signal Signal) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
const ECDHExtType = ExtType("ECDH")
|
||||
func (ext *ECDHExt) Type() ExtType {
|
||||
return ECDHExtType
|
||||
}
|
||||
|
||||
func (ext *ECDHExt) Serialize() ([]byte, error) {
|
||||
pubkey, err := x509.MarshalPKIXPublicKey(ext.Pubkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return json.MarshalIndent(&ECDHExtJSON{
|
||||
Granted: ext.Granted,
|
||||
Pubkey: pubkey,
|
||||
Shared: ext.Shared,
|
||||
}, "", " ")
|
||||
}
|
||||
|
||||
func LoadECDHExt(ctx *Context, data []byte) (Extension, error) {
|
||||
var j ECDHExtJSON
|
||||
err := json.Unmarshal(data, &j)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pub, err := x509.ParsePKIXPublicKey(j.Pubkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var pubkey *ecdsa.PublicKey
|
||||
switch pub.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
pubkey = pub.(*ecdsa.PublicKey)
|
||||
default:
|
||||
return nil, fmt.Errorf("Invalid key type: %+v", pub)
|
||||
}
|
||||
|
||||
extension := ECDHExt{
|
||||
Granted: j.Granted,
|
||||
Pubkey: pubkey,
|
||||
Shared: j.Shared,
|
||||
}
|
||||
|
||||
return &extension, nil
|
||||
}
|
||||
|
||||
type GroupExt struct {
|
||||
Members NodeMap
|
||||
}
|
||||
|
||||
const GroupExtType = ExtType("GROUP")
|
||||
func (ext *GroupExt) Type() ExtType {
|
||||
return GroupExtType
|
||||
}
|
||||
|
||||
func (ext *GroupExt) Serialize() ([]byte, error) {
|
||||
return json.MarshalIndent(&struct{
|
||||
Members []string `json:"members"`
|
||||
}{
|
||||
Members: SaveNodeList(ext.Members),
|
||||
}, "", " ")
|
||||
}
|
||||
|
||||
func NewGroupExt(members NodeMap) *GroupExt {
|
||||
if members == nil {
|
||||
members = NodeMap{}
|
||||
}
|
||||
return &GroupExt{
|
||||
Members: members,
|
||||
}
|
||||
}
|
||||
|
||||
func LoadGroupExt(ctx *Context, data []byte) (Extension, error) {
|
||||
var j struct {
|
||||
Members []string `json:"members"`
|
||||
}
|
||||
|
||||
err := json.Unmarshal(data, &j)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
members, err := RestoreNodeList(ctx, j.Members)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewGroupExt(members), nil
|
||||
}
|
||||
|
||||
func (ext *GroupExt) Process(context *StateContext, node *Node, signal Signal) error {
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue