Compare commits
No commits in common. "master" and "graph-rework-2" have entirely different histories.
master
...
graph-rewo
@ -1,3 +0,0 @@
|
|||||||
[submodule "graphql"]
|
|
||||||
path = graphql
|
|
||||||
url = https://github.com/graphql-go/graphql
|
|
@ -1,48 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
badger "github.com/dgraph-io/badger/v3"
|
|
||||||
gv "github.com/mekkanized/graphvent"
|
|
||||||
)
|
|
||||||
|
|
||||||
func check(err error) {
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true))
|
|
||||||
check(err)
|
|
||||||
|
|
||||||
ctx, err := gv.NewContext(&gv.BadgerDB{
|
|
||||||
DB: db,
|
|
||||||
}, gv.NewConsoleLogger([]string{"test"}))
|
|
||||||
check(err)
|
|
||||||
|
|
||||||
gql_ext, err := gv.NewGQLExt(ctx, ":8080", nil, nil)
|
|
||||||
check(err)
|
|
||||||
|
|
||||||
listener_ext := gv.NewListenerExt(1000)
|
|
||||||
|
|
||||||
n1, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt(nil))
|
|
||||||
check(err)
|
|
||||||
|
|
||||||
n2, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt([]gv.NodeID{n1.ID}))
|
|
||||||
check(err)
|
|
||||||
|
|
||||||
n3, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt(nil))
|
|
||||||
check(err)
|
|
||||||
|
|
||||||
_, err = gv.NewNode(ctx, nil, "LockableNode", 1000, gql_ext, listener_ext, gv.NewLockableExt([]gv.NodeID{n2.ID, n3.ID}))
|
|
||||||
check(err)
|
|
||||||
|
|
||||||
for true {
|
|
||||||
select {
|
|
||||||
case message := <- listener_ext.Chan:
|
|
||||||
fmt.Printf("Listener Message: %+v\n", message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load Diff
@ -1,284 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
badger "github.com/dgraph-io/badger/v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Database interface {
|
|
||||||
WriteNodeInit(*Context, *Node) error
|
|
||||||
WriteNodeChanges(*Context, *Node, map[ExtType]Changes) error
|
|
||||||
LoadNode(*Context, NodeID) (*Node, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
const WRITE_BUFFER_SIZE = 1000000
|
|
||||||
type BadgerDB struct {
|
|
||||||
*badger.DB
|
|
||||||
sync.Mutex
|
|
||||||
buffer [WRITE_BUFFER_SIZE]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *BadgerDB) WriteNodeInit(ctx *Context, node *Node) error {
|
|
||||||
if node == nil {
|
|
||||||
return fmt.Errorf("Cannot serialize nil *Node")
|
|
||||||
}
|
|
||||||
|
|
||||||
return db.Update(func(tx *badger.Txn) error {
|
|
||||||
db.Lock()
|
|
||||||
defer db.Unlock()
|
|
||||||
|
|
||||||
// Get the base key bytes
|
|
||||||
id_ser, err := node.ID.MarshalBinary()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cur := 0
|
|
||||||
|
|
||||||
// Write Node value
|
|
||||||
written, err := Serialize(ctx, node, db.buffer[cur:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.Set(id_ser, db.buffer[cur:cur+written])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cur += written
|
|
||||||
|
|
||||||
// Write empty signal queue
|
|
||||||
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
|
|
||||||
written, err = Serialize(ctx, node.SignalQueue, db.buffer[cur:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.Set(sigqueue_id, db.buffer[cur:cur+written])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cur += written
|
|
||||||
|
|
||||||
// Write node extension list
|
|
||||||
ext_list := []ExtType{}
|
|
||||||
for ext_type := range(node.Extensions) {
|
|
||||||
ext_list = append(ext_list, ext_type)
|
|
||||||
}
|
|
||||||
written, err = Serialize(ctx, ext_list, db.buffer[cur:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
|
|
||||||
err = tx.Set(ext_list_id, db.buffer[cur:cur+written])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cur += written
|
|
||||||
|
|
||||||
// For each extension:
|
|
||||||
for ext_type, ext := range(node.Extensions) {
|
|
||||||
ext_info, exists := ctx.Extensions[ext_type]
|
|
||||||
if exists == false {
|
|
||||||
return fmt.Errorf("Cannot serialize node with unknown extension %s", reflect.TypeOf(ext))
|
|
||||||
}
|
|
||||||
|
|
||||||
ext_value := reflect.ValueOf(ext).Elem()
|
|
||||||
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
|
|
||||||
|
|
||||||
// Write each field to a seperate key
|
|
||||||
for field_tag, field_info := range(ext_info.Fields) {
|
|
||||||
field_value := ext_value.FieldByIndex(field_info.Index)
|
|
||||||
|
|
||||||
field_id := make([]byte, len(ext_id) + 8)
|
|
||||||
tmp := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(field_tag))))
|
|
||||||
copy(field_id, tmp)
|
|
||||||
|
|
||||||
written, err := SerializeValue(ctx, field_value, db.buffer[cur:])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.Set(field_id, db.buffer[cur:cur+written])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err)
|
|
||||||
}
|
|
||||||
cur += written
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *BadgerDB) WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error {
|
|
||||||
return db.Update(func(tx *badger.Txn) error {
|
|
||||||
db.Lock()
|
|
||||||
defer db.Unlock()
|
|
||||||
|
|
||||||
// Get the base key bytes
|
|
||||||
id_bytes := ([16]byte)(node.ID)
|
|
||||||
|
|
||||||
cur := 0
|
|
||||||
|
|
||||||
// Write the signal queue if it needs to be written
|
|
||||||
if node.writeSignalQueue {
|
|
||||||
node.writeSignalQueue = false
|
|
||||||
|
|
||||||
sigqueue_id := append(id_bytes[:], []byte(" - SIGQUEUE")...)
|
|
||||||
written, err := Serialize(ctx, node.SignalQueue, db.buffer[cur:])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("SignalQueue Serialize Error: %+v, %w", node.SignalQueue, err)
|
|
||||||
}
|
|
||||||
err = tx.Set(sigqueue_id, db.buffer[cur:cur+written])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("SignalQueue set error: %+v, %w", node.SignalQueue, err)
|
|
||||||
}
|
|
||||||
cur += written
|
|
||||||
}
|
|
||||||
|
|
||||||
// For each ext in changes
|
|
||||||
for ext_type, changes := range(changes) {
|
|
||||||
ext_info, exists := ctx.Extensions[ext_type]
|
|
||||||
if exists == false {
|
|
||||||
return fmt.Errorf("%s is not an extension in ctx", ext_type)
|
|
||||||
}
|
|
||||||
|
|
||||||
ext, exists := node.Extensions[ext_type]
|
|
||||||
if exists == false {
|
|
||||||
return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID)
|
|
||||||
}
|
|
||||||
ext_id := binary.BigEndian.AppendUint64(id_bytes[:], uint64(ext_type))
|
|
||||||
ext_value := reflect.ValueOf(ext)
|
|
||||||
|
|
||||||
// Write each field
|
|
||||||
for _, tag := range(changes) {
|
|
||||||
field_info, exists := ext_info.Fields[tag]
|
|
||||||
if exists == false {
|
|
||||||
return fmt.Errorf("Cannot serialize field %s of extension %s, does not exist", tag, ext_type)
|
|
||||||
}
|
|
||||||
|
|
||||||
field_value := ext_value.FieldByIndex(field_info.Index)
|
|
||||||
|
|
||||||
field_id := make([]byte, len(ext_id) + 8)
|
|
||||||
tmp := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(tag))))
|
|
||||||
copy(field_id, tmp)
|
|
||||||
|
|
||||||
written, err := SerializeValue(ctx, field_value, db.buffer[cur:])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.Set(field_id, db.buffer[cur:cur+written])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err)
|
|
||||||
}
|
|
||||||
cur += written
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *BadgerDB) LoadNode(ctx *Context, id NodeID) (*Node, error) {
|
|
||||||
var node *Node = nil
|
|
||||||
|
|
||||||
err := db.View(func(tx *badger.Txn) error {
|
|
||||||
// Get the base key bytes
|
|
||||||
id_ser, err := id.MarshalBinary()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Failed to serialize node_id: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the node value
|
|
||||||
node_item, err := tx.Get(id_ser)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Failed to get node_item: %w", NodeNotFoundError)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = node_item.Value(func(val []byte) error {
|
|
||||||
ctx.Log.Logf("db", "DESERIALIZE_NODE(%d bytes): %+v", len(val), val)
|
|
||||||
node, err = Deserialize[*Node](ctx, val)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Failed to deserialize Node %s - %w", id, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the signal queue
|
|
||||||
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
|
|
||||||
sigqueue_item, err := tx.Get(sigqueue_id)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Failed to get sigqueue_id: %w", err)
|
|
||||||
}
|
|
||||||
err = sigqueue_item.Value(func(val []byte) error {
|
|
||||||
node.SignalQueue, err = Deserialize[[]QueuedSignal](ctx, val)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Failed to deserialize []QueuedSignal for %s: %w", id, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the extension list
|
|
||||||
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
|
|
||||||
ext_list_item, err := tx.Get(ext_list_id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var ext_list []ExtType
|
|
||||||
ext_list_item.Value(func(val []byte) error {
|
|
||||||
ext_list, err = Deserialize[[]ExtType](ctx, val)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
|
|
||||||
// Get the extensions
|
|
||||||
for _, ext_type := range(ext_list) {
|
|
||||||
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
|
|
||||||
ext_info, exists := ctx.Extensions[ext_type]
|
|
||||||
if exists == false {
|
|
||||||
return fmt.Errorf("Extension %s not in context", ext_type)
|
|
||||||
}
|
|
||||||
|
|
||||||
ext := reflect.New(ext_info.Type)
|
|
||||||
for field_tag, field_info := range(ext_info.Fields) {
|
|
||||||
field_id := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(field_tag))))
|
|
||||||
field_item, err := tx.Get(field_id)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Failed to find key for %s:%s(%x) - %w", ext_type, field_tag, field_id, err)
|
|
||||||
}
|
|
||||||
err = field_item.Value(func(val []byte) error {
|
|
||||||
value, _, err := DeserializeValue(ctx, val, field_info.Type)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ext.Elem().FieldByIndex(field_info.Index).Set(value)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
node.Extensions[ext_type] = ext.Interface().(Extension)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if node == nil {
|
|
||||||
return nil, fmt.Errorf("Tried to return nil *Node from BadgerDB.LoadNode without error")
|
|
||||||
}
|
|
||||||
|
|
||||||
return node, nil
|
|
||||||
}
|
|
@ -1,16 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
type Tag string
|
|
||||||
type Changes []Tag
|
|
||||||
|
|
||||||
// Extensions are data attached to nodes that process signals
|
|
||||||
type Extension interface {
|
|
||||||
// Called to process incoming signals, returning changes and messages to send
|
|
||||||
Process(*Context, *Node, NodeID, Signal) ([]Message, Changes)
|
|
||||||
|
|
||||||
// Called when the node is loaded into a context(creation or move), so extension data can be initialized
|
|
||||||
Load(*Context, *Node) error
|
|
||||||
|
|
||||||
// Called when the node is unloaded from a context(deletion or move), so extension data can be cleaned up
|
|
||||||
Unload(*Context, *Node)
|
|
||||||
}
|
|
@ -0,0 +1,172 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/graphql-go/graphql"
|
||||||
|
"reflect"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewField(init func()*graphql.Field) *graphql.Field {
|
||||||
|
return init()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Singleton[K graphql.Type] struct {
|
||||||
|
Type K
|
||||||
|
List *graphql.List
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSingleton[K graphql.Type](init func() K, post_init func(K, *graphql.List)) *Singleton[K] {
|
||||||
|
val := init()
|
||||||
|
list := graphql.NewList(val)
|
||||||
|
if post_init != nil {
|
||||||
|
post_init(val, list)
|
||||||
|
}
|
||||||
|
return &Singleton[K]{
|
||||||
|
Type: val,
|
||||||
|
List: list,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddNodeInterfaceFields(i *graphql.Interface) {
|
||||||
|
i.AddFieldConfig("ID", &graphql.Field{
|
||||||
|
Type: graphql.String,
|
||||||
|
})
|
||||||
|
|
||||||
|
i.AddFieldConfig("TypeHash", &graphql.Field{
|
||||||
|
Type: graphql.String,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func PrepTypeResolve(p graphql.ResolveTypeParams) (*ResolveContext, error) {
|
||||||
|
resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
|
||||||
|
if ok == false {
|
||||||
|
return nil, fmt.Errorf("Bad resolve in params context")
|
||||||
|
}
|
||||||
|
return resolve_context, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var GQLInterfaceNode = NewSingleton(func() *graphql.Interface {
|
||||||
|
i := graphql.NewInterface(graphql.InterfaceConfig{
|
||||||
|
Name: "Node",
|
||||||
|
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
|
||||||
|
ctx, err := PrepTypeResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
valid_nodes := ctx.GQLContext.ValidNodes
|
||||||
|
p_type := reflect.TypeOf(p.Value)
|
||||||
|
|
||||||
|
for key, value := range(valid_nodes) {
|
||||||
|
if p_type == key {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok := p.Value.(Node)
|
||||||
|
if ok == true {
|
||||||
|
return ctx.GQLContext.BaseNodeType
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Fields: graphql.Fields{},
|
||||||
|
})
|
||||||
|
|
||||||
|
AddNodeInterfaceFields(i)
|
||||||
|
|
||||||
|
return i
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
var GQLInterfaceLockable = NewSingleton(func() *graphql.Interface {
|
||||||
|
gql_interface_lockable := graphql.NewInterface(graphql.InterfaceConfig{
|
||||||
|
Name: "Lockable",
|
||||||
|
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
|
||||||
|
ctx, err := PrepTypeResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
valid_lockables := ctx.GQLContext.ValidLockables
|
||||||
|
p_type := reflect.TypeOf(p.Value)
|
||||||
|
|
||||||
|
for key, value := range(valid_lockables) {
|
||||||
|
if p_type == key {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok := p.Value.(*Node)
|
||||||
|
if ok == false {
|
||||||
|
return ctx.GQLContext.BaseLockableType
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Fields: graphql.Fields{},
|
||||||
|
})
|
||||||
|
|
||||||
|
return gql_interface_lockable
|
||||||
|
}, func(lockable *graphql.Interface, lockable_list *graphql.List) {
|
||||||
|
lockable.AddFieldConfig("Requirements", &graphql.Field{
|
||||||
|
Type: lockable_list,
|
||||||
|
})
|
||||||
|
|
||||||
|
lockable.AddFieldConfig("Dependencies", &graphql.Field{
|
||||||
|
Type: lockable_list,
|
||||||
|
})
|
||||||
|
|
||||||
|
lockable.AddFieldConfig("Owner", &graphql.Field{
|
||||||
|
Type: lockable,
|
||||||
|
})
|
||||||
|
AddNodeInterfaceFields(lockable)
|
||||||
|
})
|
||||||
|
|
||||||
|
var GQLInterfaceThread = NewSingleton(func() *graphql.Interface {
|
||||||
|
gql_interface_thread := graphql.NewInterface(graphql.InterfaceConfig{
|
||||||
|
Name: "Thread",
|
||||||
|
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
|
||||||
|
ctx, err := PrepTypeResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
valid_threads := ctx.GQLContext.ValidThreads
|
||||||
|
p_type := reflect.TypeOf(p.Value)
|
||||||
|
|
||||||
|
for key, value := range(valid_threads) {
|
||||||
|
if p_type == key {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
node, ok := p.Value.(*Node)
|
||||||
|
if ok == false {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = GetExt[*ThreadExt](node)
|
||||||
|
if err == nil {
|
||||||
|
return ctx.GQLContext.BaseThreadType
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Fields: graphql.Fields{},
|
||||||
|
})
|
||||||
|
|
||||||
|
return gql_interface_thread
|
||||||
|
}, func(thread *graphql.Interface, thread_list *graphql.List) {
|
||||||
|
thread.AddFieldConfig("Children", &graphql.Field{
|
||||||
|
Type: thread_list,
|
||||||
|
})
|
||||||
|
|
||||||
|
thread.AddFieldConfig("Parent", &graphql.Field{
|
||||||
|
Type: thread,
|
||||||
|
})
|
||||||
|
|
||||||
|
thread.AddFieldConfig("State", &graphql.Field{
|
||||||
|
Type: graphql.String,
|
||||||
|
})
|
||||||
|
|
||||||
|
AddNodeInterfaceFields(thread)
|
||||||
|
})
|
@ -0,0 +1,114 @@
|
|||||||
|
package graphvent
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/graphql-go/graphql"
|
||||||
|
)
|
||||||
|
|
||||||
|
var GQLMutationAbort = NewField(func()*graphql.Field {
|
||||||
|
gql_mutation_abort := &graphql.Field{
|
||||||
|
Type: GQLTypeSignal.Type,
|
||||||
|
Args: graphql.FieldConfigArgument{
|
||||||
|
"id": &graphql.ArgumentConfig{
|
||||||
|
Type: graphql.String,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
_, ctx, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := ExtractID(p, "id")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var node *Node = nil
|
||||||
|
context := NewReadContext(ctx.Context)
|
||||||
|
err = UseStates(context, ctx.User, NewACLMap(
|
||||||
|
NewACLInfo(ctx.Server, []string{"children"}),
|
||||||
|
), func(context *StateContext) (error){
|
||||||
|
node, err = FindChild(context, ctx.User, ctx.Server, id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if node == nil {
|
||||||
|
return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
|
||||||
|
}
|
||||||
|
return SendSignal(context, node, ctx.User, AbortSignal)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return AbortSignal, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return gql_mutation_abort
|
||||||
|
})
|
||||||
|
|
||||||
|
var GQLMutationStartChild = NewField(func()*graphql.Field{
|
||||||
|
gql_mutation_start_child := &graphql.Field{
|
||||||
|
Type: GQLTypeSignal.Type,
|
||||||
|
Args: graphql.FieldConfigArgument{
|
||||||
|
"parent_id": &graphql.ArgumentConfig{
|
||||||
|
Type: graphql.String,
|
||||||
|
},
|
||||||
|
"child_id": &graphql.ArgumentConfig{
|
||||||
|
Type: graphql.String,
|
||||||
|
},
|
||||||
|
"action": &graphql.ArgumentConfig{
|
||||||
|
Type: graphql.String,
|
||||||
|
DefaultValue: "start",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
_, ctx, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
parent_id, err := ExtractID(p, "parent_id")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
child_id, err := ExtractID(p, "child_id")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
action, err := ExtractParam[string](p, "action")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var signal Signal
|
||||||
|
context := NewWriteContext(ctx.Context)
|
||||||
|
err = UseStates(context, ctx.User, NewACLMap(
|
||||||
|
NewACLInfo(ctx.Server, []string{"children"}),
|
||||||
|
), func(context *StateContext) error {
|
||||||
|
parent, err := FindChild(context, ctx.User, ctx.Server, parent_id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if parent == nil {
|
||||||
|
return fmt.Errorf("%s is not a child of %s", parent_id, ctx.Server.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
signal = NewStartChildSignal(child_id, action)
|
||||||
|
return SendSignal(context, ctx.User, parent, signal)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: wait for the result of the signal to send back instead of just the signal
|
||||||
|
return signal, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return gql_mutation_start_child
|
||||||
|
})
|
||||||
|
|
@ -1,147 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
"github.com/graphql-go/graphql"
|
|
||||||
"github.com/graphql-go/graphql/language/ast"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
node, ok := p.Source.(NodeResult)
|
|
||||||
if ok == false {
|
|
||||||
return nil, fmt.Errorf("Can't get NodeID from %+v", reflect.TypeOf(p.Source))
|
|
||||||
}
|
|
||||||
|
|
||||||
return node.NodeID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ResolveNodeType(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
node, ok := p.Source.(NodeResult)
|
|
||||||
if ok == false {
|
|
||||||
return nil, fmt.Errorf("Can't get TypeHash from %+v", reflect.TypeOf(p.Source))
|
|
||||||
}
|
|
||||||
|
|
||||||
return uint64(node.NodeType), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type FieldIndex struct {
|
|
||||||
Extension ExtType
|
|
||||||
Tag string
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetFields(selection_set *ast.SelectionSet) []string {
|
|
||||||
names := []string{}
|
|
||||||
if selection_set == nil {
|
|
||||||
return names
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, sel := range(selection_set.Selections) {
|
|
||||||
switch field := sel.(type) {
|
|
||||||
case *ast.Field:
|
|
||||||
if field.Name.Value == "ID" || field.Name.Value == "Type" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
names = append(names, field.Name.Value)
|
|
||||||
case *ast.InlineFragment:
|
|
||||||
names = append(names, GetFields(field.SelectionSet)...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return names
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the fields that need to be resolved
|
|
||||||
func GetResolveFields(p graphql.ResolveParams) []string {
|
|
||||||
fields := []string{}
|
|
||||||
for _, field := range(p.Info.FieldASTs) {
|
|
||||||
fields = append(fields, GetFields(field.SelectionSet)...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fields
|
|
||||||
}
|
|
||||||
|
|
||||||
func ResolveNode(id NodeID, p graphql.ResolveParams) (NodeResult, error) {
|
|
||||||
ctx, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return NodeResult{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch source := p.Source.(type) {
|
|
||||||
case *StatusSignal:
|
|
||||||
cached_node, cached := ctx.NodeCache[source.Source]
|
|
||||||
if cached {
|
|
||||||
for _, field_name := range(source.Fields) {
|
|
||||||
_, cached := cached_node.Data[field_name]
|
|
||||||
if cached {
|
|
||||||
delete(cached_node.Data, field_name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ctx.NodeCache[source.Source] = cached_node
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cache, node_cached := ctx.NodeCache[id]
|
|
||||||
fields := GetResolveFields(p)
|
|
||||||
var not_cached []string
|
|
||||||
if node_cached {
|
|
||||||
not_cached = []string{}
|
|
||||||
for _, field := range(fields) {
|
|
||||||
if node_cached {
|
|
||||||
_, field_cached := cache.Data[field]
|
|
||||||
if field_cached {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
not_cached = append(not_cached, field)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
not_cached = fields
|
|
||||||
}
|
|
||||||
|
|
||||||
if (len(not_cached) == 0) && (node_cached == true) {
|
|
||||||
ctx.Context.Log.Logf("gql", "No new fields to resolve for %s", id)
|
|
||||||
return cache, nil
|
|
||||||
} else {
|
|
||||||
ctx.Context.Log.Logf("gql", "Resolving fields %+v on node %s", not_cached, id)
|
|
||||||
|
|
||||||
signal := NewReadSignal(not_cached)
|
|
||||||
response_chan := ctx.Ext.GetResponseChannel(signal.ID())
|
|
||||||
// TODO: TIMEOUT DURATION
|
|
||||||
err = ctx.Context.Send(ctx.Server, []Message{{
|
|
||||||
Node: id,
|
|
||||||
Signal: signal,
|
|
||||||
}})
|
|
||||||
if err != nil {
|
|
||||||
ctx.Ext.FreeResponseChannel(signal.ID())
|
|
||||||
return NodeResult{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID())
|
|
||||||
ctx.Ext.FreeResponseChannel(signal.ID())
|
|
||||||
if err != nil {
|
|
||||||
return NodeResult{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch response := response.(type) {
|
|
||||||
case *ReadResultSignal:
|
|
||||||
if node_cached == false {
|
|
||||||
cache = NodeResult{
|
|
||||||
NodeID: id,
|
|
||||||
NodeType: response.NodeType,
|
|
||||||
Data: response.Fields,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for field_name, field_value := range(response.Fields) {
|
|
||||||
cache.Data[field_name] = field_value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.NodeCache[id] = cache
|
|
||||||
return ctx.NodeCache[id], nil
|
|
||||||
default:
|
|
||||||
return NodeResult{}, fmt.Errorf("Bad read response: %+v", response)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -0,0 +1,28 @@
|
|||||||
|
package graphvent
|
||||||
|
import (
|
||||||
|
"github.com/graphql-go/graphql"
|
||||||
|
)
|
||||||
|
|
||||||
|
var GQLQuerySelf = &graphql.Field{
|
||||||
|
Type: GQLTypeBaseThread.Type,
|
||||||
|
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
_, ctx, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctx.Server, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var GQLQueryUser = &graphql.Field{
|
||||||
|
Type: GQLTypeBaseNode.Type,
|
||||||
|
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
_, ctx, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctx.User, nil
|
||||||
|
},
|
||||||
|
}
|
@ -0,0 +1,338 @@
|
|||||||
|
package graphvent
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"github.com/graphql-go/graphql"
|
||||||
|
)
|
||||||
|
|
||||||
|
func PrepResolve(p graphql.ResolveParams) (*Node, *ResolveContext, error) {
|
||||||
|
resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
|
||||||
|
if ok == false {
|
||||||
|
return nil, nil, fmt.Errorf("Bad resolve in params context")
|
||||||
|
}
|
||||||
|
|
||||||
|
node, ok := p.Source.(*Node)
|
||||||
|
if ok == false {
|
||||||
|
return nil, nil, fmt.Errorf("Source is not a *Node in PrepResolve")
|
||||||
|
}
|
||||||
|
|
||||||
|
return node, resolve_context, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Make composabe by checkinf if K is a slice, then recursing in the same way that ExtractList does
|
||||||
|
func ExtractParam[K interface{}](p graphql.ResolveParams, name string) (K, error) {
|
||||||
|
var zero K
|
||||||
|
arg_if, ok := p.Args[name]
|
||||||
|
if ok == false {
|
||||||
|
return zero, fmt.Errorf("No Arg of name %s", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
arg, ok := arg_if.(K)
|
||||||
|
if ok == false {
|
||||||
|
return zero, fmt.Errorf("Failed to cast arg %s(%+v) to %+v", name, arg_if, reflect.TypeOf(zero))
|
||||||
|
}
|
||||||
|
|
||||||
|
return arg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtractList[K interface{}](p graphql.ResolveParams, name string) ([]K, error) {
|
||||||
|
var zero K
|
||||||
|
|
||||||
|
arg_list, err := ExtractParam[[]interface{}](p, name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := make([]K, len(arg_list))
|
||||||
|
for i, val := range(arg_list) {
|
||||||
|
val_conv, ok := arg_list[i].(K)
|
||||||
|
if ok == false {
|
||||||
|
return nil, fmt.Errorf("Failed to cast arg %s[%d](%+v) to %+v", name, i, val, reflect.TypeOf(zero))
|
||||||
|
}
|
||||||
|
ret[i] = val_conv
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) {
|
||||||
|
id_str, err := ExtractParam[string](p, name)
|
||||||
|
if err != nil {
|
||||||
|
return ZeroID, err
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := ParseID(id_str)
|
||||||
|
if err != nil {
|
||||||
|
return ZeroID, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: think about what permissions should be needed to read ID, and if there's ever a case where they haven't already been granted
|
||||||
|
func GQLNodeID(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
node, _, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return node.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GQLNodeTypeHash(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
node, _, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(node.Type), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
node, ctx, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
gql_ext, err := GetExt[*GQLExt](node)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
listen := ""
|
||||||
|
context := NewReadContext(ctx.Context)
|
||||||
|
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"listen"}), func(context *StateContext) error {
|
||||||
|
listen = gql_ext.Listen
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return listen, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
node, ctx, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
thread_ext, err := GetExt[*ThreadExt](node)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var parent *Node = nil
|
||||||
|
context := NewReadContext(ctx.Context)
|
||||||
|
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error {
|
||||||
|
parent = thread_ext.Parent
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return parent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GQLThreadState(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
node, ctx, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
thread_ext, err := GetExt[*ThreadExt](node)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var state string
|
||||||
|
context := NewReadContext(ctx.Context)
|
||||||
|
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"state"}), func(context *StateContext) error {
|
||||||
|
state = thread_ext.State
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
node, ctx, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
thread_ext, err := GetExt[*ThreadExt](node)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var children []*Node = nil
|
||||||
|
context := NewReadContext(ctx.Context)
|
||||||
|
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"children"}), func(context *StateContext) error {
|
||||||
|
children = thread_ext.ChildList()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return children, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
node, ctx, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
lockable_ext, err := GetExt[*LockableExt](node)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var requirements []*Node = nil
|
||||||
|
context := NewReadContext(ctx.Context)
|
||||||
|
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error {
|
||||||
|
requirements = make([]*Node, len(lockable_ext.Requirements))
|
||||||
|
i := 0
|
||||||
|
for _, req := range(lockable_ext.Requirements) {
|
||||||
|
requirements[i] = req
|
||||||
|
i += 1
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return requirements, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
node, ctx, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
lockable_ext, err := GetExt[*LockableExt](node)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var dependencies []*Node = nil
|
||||||
|
context := NewReadContext(ctx.Context)
|
||||||
|
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"dependencies"}), func(context *StateContext) error {
|
||||||
|
dependencies = make([]*Node, len(lockable_ext.Dependencies))
|
||||||
|
i := 0
|
||||||
|
for _, dep := range(lockable_ext.Dependencies) {
|
||||||
|
dependencies[i] = dep
|
||||||
|
i += 1
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return dependencies, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
node, ctx, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
lockable_ext, err := GetExt[*LockableExt](node)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var owner *Node = nil
|
||||||
|
context := NewReadContext(ctx.Context)
|
||||||
|
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"owner"}), func(context *StateContext) error {
|
||||||
|
owner = lockable_ext.Owner
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return owner, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GQLGroupMembers(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
node, ctx, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
group_ext, err := GetExt[*GroupExt](node)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var members []*Node
|
||||||
|
context := NewReadContext(ctx.Context)
|
||||||
|
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"users"}), func(context *StateContext) error {
|
||||||
|
members = make([]*Node, len(group_ext.Members))
|
||||||
|
i := 0
|
||||||
|
for _, member := range(group_ext.Members) {
|
||||||
|
members[i] = member
|
||||||
|
i += 1
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return members, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GQLSignalFn(p graphql.ResolveParams, fn func(Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) {
|
||||||
|
if signal, ok := p.Source.(Signal); ok {
|
||||||
|
return fn(signal, p)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("Failed to cast source to event")
|
||||||
|
}
|
||||||
|
|
||||||
|
func GQLSignalType(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
|
||||||
|
return signal.Type(), nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GQLSignalDirection(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
|
||||||
|
direction := signal.Direction()
|
||||||
|
if direction == Up {
|
||||||
|
return "up", nil
|
||||||
|
} else if direction == Down {
|
||||||
|
return "down", nil
|
||||||
|
} else if direction == Direct {
|
||||||
|
return "direct", nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("Invalid direction: %+v", direction)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GQLSignalString(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
|
||||||
|
ser, err := signal.Serialize()
|
||||||
|
return string(ser), err
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,69 @@
|
|||||||
|
package graphvent
|
||||||
|
import (
|
||||||
|
"github.com/graphql-go/graphql"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GQLSubscribeSignal(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
return GQLSubscribeFn(p, false, func(ctx *Context, server *Node, ext *GQLExt, signal Signal, p graphql.ResolveParams)(interface{}, error) {
|
||||||
|
return signal, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GQLSubscribeSelf(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
return GQLSubscribeFn(p, true, func(ctx *Context, server *Node, ext *GQLExt, signal Signal, p graphql.ResolveParams)(interface{}, error) {
|
||||||
|
return server, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GQLSubscribeFn(p graphql.ResolveParams, send_nil bool, fn func(*Context, *Node, *GQLExt, Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) {
|
||||||
|
_, ctx, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c := make(chan interface{})
|
||||||
|
go func(c chan interface{}, ext *GQLExt, server *Node) {
|
||||||
|
ctx.Context.Log.Logf("gqlws", "GQL_SUBSCRIBE_THREAD_START")
|
||||||
|
sig_c := ext.NewSubscriptionChannel(1)
|
||||||
|
if send_nil == true {
|
||||||
|
sig_c <- nil
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
val, ok := <- sig_c
|
||||||
|
if ok == false {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ret, err := fn(ctx.Context, server, ext, val, p)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Context.Log.Logf("gqlws", "type convertor error %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c <- ret
|
||||||
|
}
|
||||||
|
}(c, ctx.Ext, ctx.Server)
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var GQLSubscriptionSelf = NewField(func()*graphql.Field{
|
||||||
|
gql_subscription_self := &graphql.Field{
|
||||||
|
Type: GQLTypeBaseThread.Type,
|
||||||
|
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
return p.Source, nil
|
||||||
|
},
|
||||||
|
Subscribe: GQLSubscribeSelf,
|
||||||
|
}
|
||||||
|
|
||||||
|
return gql_subscription_self
|
||||||
|
})
|
||||||
|
|
||||||
|
var GQLSubscriptionUpdate = NewField(func()*graphql.Field{
|
||||||
|
gql_subscription_update := &graphql.Field{
|
||||||
|
Type: GQLTypeSignal.Type,
|
||||||
|
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
return p.Source, nil
|
||||||
|
},
|
||||||
|
Subscribe: GQLSubscribeSignal,
|
||||||
|
}
|
||||||
|
return gql_subscription_update
|
||||||
|
})
|
||||||
|
|
@ -0,0 +1,148 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/graphql-go/graphql"
|
||||||
|
)
|
||||||
|
|
||||||
|
func AddNodeFields(obj *graphql.Object) {
|
||||||
|
obj.AddFieldConfig("ID", &graphql.Field{
|
||||||
|
Type: graphql.String,
|
||||||
|
Resolve: GQLNodeID,
|
||||||
|
})
|
||||||
|
|
||||||
|
obj.AddFieldConfig("TypeHash", &graphql.Field{
|
||||||
|
Type: graphql.String,
|
||||||
|
Resolve: GQLNodeTypeHash,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddLockableFields(obj *graphql.Object) {
|
||||||
|
AddNodeFields(obj)
|
||||||
|
|
||||||
|
obj.AddFieldConfig("Requirements", &graphql.Field{
|
||||||
|
Type: GQLInterfaceLockable.List,
|
||||||
|
Resolve: GQLLockableRequirements,
|
||||||
|
})
|
||||||
|
|
||||||
|
obj.AddFieldConfig("Owner", &graphql.Field{
|
||||||
|
Type: GQLInterfaceLockable.Type,
|
||||||
|
Resolve: GQLLockableOwner,
|
||||||
|
})
|
||||||
|
|
||||||
|
obj.AddFieldConfig("Dependencies", &graphql.Field{
|
||||||
|
Type: GQLInterfaceLockable.List,
|
||||||
|
Resolve: GQLLockableDependencies,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddThreadFields(obj *graphql.Object) {
|
||||||
|
AddNodeFields(obj)
|
||||||
|
|
||||||
|
obj.AddFieldConfig("State", &graphql.Field{
|
||||||
|
Type: graphql.String,
|
||||||
|
Resolve: GQLThreadState,
|
||||||
|
})
|
||||||
|
|
||||||
|
obj.AddFieldConfig("Children", &graphql.Field{
|
||||||
|
Type: GQLInterfaceThread.List,
|
||||||
|
Resolve: GQLThreadChildren,
|
||||||
|
})
|
||||||
|
|
||||||
|
obj.AddFieldConfig("Parent", &graphql.Field{
|
||||||
|
Type: GQLInterfaceThread.Type,
|
||||||
|
Resolve: GQLThreadParent,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
var GQLTypeBaseThread = NewSingleton(func() *graphql.Object {
|
||||||
|
gql_type_simple_thread := graphql.NewObject(graphql.ObjectConfig{
|
||||||
|
Name: "SimpleThread",
|
||||||
|
Interfaces: []*graphql.Interface{
|
||||||
|
GQLInterfaceNode.Type,
|
||||||
|
GQLInterfaceThread.Type,
|
||||||
|
GQLInterfaceLockable.Type,
|
||||||
|
},
|
||||||
|
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
|
||||||
|
node, ok := p.Value.(*Node)
|
||||||
|
if ok == false {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := GetExt[*ThreadExt](node)
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
Fields: graphql.Fields{},
|
||||||
|
})
|
||||||
|
|
||||||
|
AddThreadFields(gql_type_simple_thread)
|
||||||
|
|
||||||
|
return gql_type_simple_thread
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
var GQLTypeBaseLockable = NewSingleton(func() *graphql.Object {
|
||||||
|
gql_type_simple_lockable := graphql.NewObject(graphql.ObjectConfig{
|
||||||
|
Name: "SimpleLockable",
|
||||||
|
Interfaces: []*graphql.Interface{
|
||||||
|
GQLInterfaceNode.Type,
|
||||||
|
GQLInterfaceLockable.Type,
|
||||||
|
},
|
||||||
|
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
|
||||||
|
node, ok := p.Value.(*Node)
|
||||||
|
if ok == false {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := GetExt[*LockableExt](node)
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
Fields: graphql.Fields{},
|
||||||
|
})
|
||||||
|
|
||||||
|
AddLockableFields(gql_type_simple_lockable)
|
||||||
|
|
||||||
|
return gql_type_simple_lockable
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
var GQLTypeBaseNode = NewSingleton(func() *graphql.Object {
|
||||||
|
object := graphql.NewObject(graphql.ObjectConfig{
|
||||||
|
Name: "SimpleNode",
|
||||||
|
Interfaces: []*graphql.Interface{
|
||||||
|
GQLInterfaceNode.Type,
|
||||||
|
},
|
||||||
|
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
|
||||||
|
_, ok := p.Value.(*Node)
|
||||||
|
return ok
|
||||||
|
},
|
||||||
|
Fields: graphql.Fields{},
|
||||||
|
})
|
||||||
|
|
||||||
|
AddNodeFields(object)
|
||||||
|
|
||||||
|
return object
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
var GQLTypeSignal = NewSingleton(func() *graphql.Object {
|
||||||
|
gql_type_signal := graphql.NewObject(graphql.ObjectConfig{
|
||||||
|
Name: "Signal",
|
||||||
|
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
|
||||||
|
_, ok := p.Value.(Signal)
|
||||||
|
return ok
|
||||||
|
},
|
||||||
|
Fields: graphql.Fields{},
|
||||||
|
})
|
||||||
|
|
||||||
|
gql_type_signal.AddFieldConfig("Type", &graphql.Field{
|
||||||
|
Type: graphql.String,
|
||||||
|
Resolve: GQLSignalType,
|
||||||
|
})
|
||||||
|
gql_type_signal.AddFieldConfig("Direction", &graphql.Field{
|
||||||
|
Type: graphql.String,
|
||||||
|
Resolve: GQLSignalDirection,
|
||||||
|
})
|
||||||
|
gql_type_signal.AddFieldConfig("String", &graphql.Field{
|
||||||
|
Type: graphql.String,
|
||||||
|
Resolve: GQLSignalString,
|
||||||
|
})
|
||||||
|
return gql_type_signal
|
||||||
|
}, nil)
|
||||||
|
|
@ -1,66 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A Listener extension provides a channel that can receive signals on a different thread
|
|
||||||
type ListenerExt struct {
|
|
||||||
Buffer int `gv:"buffer"`
|
|
||||||
Chan chan Signal
|
|
||||||
}
|
|
||||||
|
|
||||||
type LoadedSignal struct {
|
|
||||||
SignalHeader
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLoadedSignal() *LoadedSignal {
|
|
||||||
return &LoadedSignal{
|
|
||||||
SignalHeader: NewSignalHeader(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type UnloadedSignal struct {
|
|
||||||
SignalHeader
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewUnloadedSignal() *UnloadedSignal {
|
|
||||||
return &UnloadedSignal{
|
|
||||||
SignalHeader: NewSignalHeader(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ListenerExt) Load(ctx *Context, node *Node) error {
|
|
||||||
ext.Chan = make(chan Signal, ext.Buffer)
|
|
||||||
ext.Chan <- NewLoadedSignal()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ListenerExt) Unload(ctx *Context, node *Node) {
|
|
||||||
ext.Chan <- NewUnloadedSignal()
|
|
||||||
close(ext.Chan)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new listener extension with a given buffer size
|
|
||||||
func NewListenerExt(buffer int) *ListenerExt {
|
|
||||||
return &ListenerExt{
|
|
||||||
Buffer: buffer,
|
|
||||||
Chan: make(chan Signal, buffer),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send the signal to the channel, logging an overflow if it occurs
|
|
||||||
func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]Message, Changes) {
|
|
||||||
ctx.Log.Logf("listener", "%s - %+v", node.ID, reflect.TypeOf(signal))
|
|
||||||
ctx.Log.Logf("listener_debug", "%s->%s - %+v", source, node.ID, signal)
|
|
||||||
select {
|
|
||||||
case ext.Chan <- signal:
|
|
||||||
default:
|
|
||||||
ctx.Log.Logf("listener", "LISTENER_OVERFLOW: %s", node.ID)
|
|
||||||
}
|
|
||||||
switch sig := signal.(type) {
|
|
||||||
case *StatusSignal:
|
|
||||||
ctx.Log.Logf("listener_status", "%s - %+v", sig.Source, sig.Fields)
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
@ -1,411 +1,596 @@
|
|||||||
package graphvent
|
package graphvent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/google/uuid"
|
"fmt"
|
||||||
|
"encoding/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ReqState byte
|
type ListenerExt struct {
|
||||||
const (
|
Buffer int
|
||||||
Unlocked = ReqState(0)
|
Chan chan Signal
|
||||||
Unlocking = ReqState(1)
|
}
|
||||||
Locked = ReqState(2)
|
|
||||||
Locking = ReqState(3)
|
|
||||||
AbortingLock = ReqState(4)
|
|
||||||
)
|
|
||||||
|
|
||||||
var ReqStateStrings = map[ReqState]string {
|
func NewListenerExt(buffer int) *ListenerExt {
|
||||||
Unlocked: "Unlocked",
|
return &ListenerExt{
|
||||||
Unlocking: "Unlocking",
|
Buffer: buffer,
|
||||||
Locked: "Locked",
|
Chan: make(chan Signal, buffer),
|
||||||
Locking: "Locking",
|
}
|
||||||
AbortingLock: "AbortingLock",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state ReqState) String() string {
|
func LoadListenerExt(ctx *Context, data []byte) (Extension, error) {
|
||||||
str, mapped := ReqStateStrings[state]
|
var j int
|
||||||
if mapped == false {
|
err := json.Unmarshal(data, &j)
|
||||||
return "UNKNOWN_REQSTATE"
|
if err != nil {
|
||||||
} else {
|
return nil, err
|
||||||
return str
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return NewListenerExt(j), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const ListenerExtType = ExtType("LISTENER")
|
||||||
|
func (listener *ListenerExt) Type() ExtType {
|
||||||
|
return ListenerExtType
|
||||||
|
}
|
||||||
|
|
||||||
type LockableExt struct{
|
func (ext *ListenerExt) Process(context *StateContext, node *Node, signal Signal) error {
|
||||||
State ReqState `gv:"state"`
|
select {
|
||||||
ReqID *uuid.UUID `gv:"req_id"`
|
case ext.Chan <- signal:
|
||||||
Owner *NodeID `gv:"owner"`
|
default:
|
||||||
PendingOwner *NodeID `gv:"pending_owner"`
|
return fmt.Errorf("LISTENER_OVERFLOW - %+v", signal)
|
||||||
Requirements map[NodeID]ReqState `gv:"requirements" node:"Lockable:"`
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
Locked map[NodeID]any
|
func (ext *ListenerExt) Serialize() ([]byte, error) {
|
||||||
Unlocked map[NodeID]any
|
return json.MarshalIndent(ext.Buffer, "", " ")
|
||||||
|
}
|
||||||
|
|
||||||
Waiting WaitMap `gv:"waiting_locks" node:":Lockable"`
|
type LockableExt struct {
|
||||||
|
Owner *Node
|
||||||
|
Requirements map[NodeID]*Node
|
||||||
|
Dependencies map[NodeID]*Node
|
||||||
|
LocksHeld map[NodeID]*Node
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLockableExt(requirements []NodeID) *LockableExt {
|
const LockableExtType = ExtType("LOCKABLE")
|
||||||
var reqs map[NodeID]ReqState = nil
|
func (ext *LockableExt) Type() ExtType {
|
||||||
var unlocked map[NodeID]any = map[NodeID]any{}
|
return LockableExtType
|
||||||
|
}
|
||||||
|
|
||||||
if len(requirements) != 0 {
|
type LockableExtJSON struct {
|
||||||
reqs = map[NodeID]ReqState{}
|
Owner string `json:"owner"`
|
||||||
for _, req := range(requirements) {
|
Requirements []string `json:"requirements"`
|
||||||
reqs[req] = Unlocked
|
Dependencies []string `json:"dependencies"`
|
||||||
unlocked[req] = nil
|
LocksHeld map[string]string `json:"locks_held"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ext *LockableExt) Serialize() ([]byte, error) {
|
||||||
|
return json.MarshalIndent(&LockableExtJSON{
|
||||||
|
Owner: SaveNode(ext.Owner),
|
||||||
|
Requirements: SaveNodeList(ext.Requirements),
|
||||||
|
Dependencies: SaveNodeList(ext.Dependencies),
|
||||||
|
LocksHeld: SaveNodeMap(ext.LocksHeld),
|
||||||
|
}, "", " ")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &LockableExt{
|
func NewLockableExt(owner *Node, requirements NodeMap, dependencies NodeMap, locks_held NodeMap) *LockableExt {
|
||||||
State: Unlocked,
|
if requirements == nil {
|
||||||
Owner: nil,
|
requirements = NodeMap{}
|
||||||
PendingOwner: nil,
|
}
|
||||||
Requirements: reqs,
|
|
||||||
Waiting: WaitMap{},
|
|
||||||
|
|
||||||
Locked: map[NodeID]any{},
|
if dependencies == nil {
|
||||||
Unlocked: unlocked,
|
dependencies = NodeMap{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if locks_held == nil {
|
||||||
|
locks_held = NodeMap{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnlockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
|
return &LockableExt{
|
||||||
signal := NewUnlockSignal()
|
Owner: owner,
|
||||||
messages := []Message{{node.ID, signal}}
|
Requirements: requirements,
|
||||||
return signal.ID(), ctx.Send(node, messages)
|
Dependencies: dependencies,
|
||||||
|
LocksHeld: locks_held,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
|
func LoadLockableExt(ctx *Context, data []byte) (Extension, error) {
|
||||||
signal := NewLockSignal()
|
var j LockableExtJSON
|
||||||
messages := []Message{{node.ID, signal}}
|
err := json.Unmarshal(data, &j)
|
||||||
return signal.ID(), ctx.Send(node, messages)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ext *LockableExt) Load(ctx *Context, node *Node) error {
|
ctx.Log.Logf("db", "DB_LOADING_LOCKABLE_EXT_JSON: %+v", j)
|
||||||
ext.Locked = map[NodeID]any{}
|
|
||||||
ext.Unlocked = map[NodeID]any{}
|
|
||||||
|
|
||||||
for id, state := range(ext.Requirements) {
|
owner, err := RestoreNode(ctx, j.Owner)
|
||||||
if state == Unlocked {
|
if err != nil {
|
||||||
ext.Unlocked[id] = nil
|
return nil, err
|
||||||
} else if state == Locked {
|
|
||||||
ext.Locked[id] = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
requirements, err := RestoreNodeList(ctx, j.Requirements)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
dependencies, err := RestoreNodeList(ctx, j.Dependencies)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ext *LockableExt) Unload(ctx *Context, node *Node) {
|
locks_held, err := RestoreNodeMap(ctx, j.LocksHeld)
|
||||||
return
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle link signal by adding/removing the requested NodeID
|
|
||||||
// returns an error if the node is not unlocked
|
|
||||||
func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID, signal *LinkSignal) ([]Message, Changes) {
|
|
||||||
var messages []Message = nil
|
|
||||||
var changes Changes = nil
|
|
||||||
|
|
||||||
switch ext.State {
|
return NewLockableExt(owner, requirements, dependencies, locks_held), nil
|
||||||
case Unlocked:
|
|
||||||
switch signal.Action {
|
|
||||||
case "add":
|
|
||||||
_, exists := ext.Requirements[signal.NodeID]
|
|
||||||
if exists == true {
|
|
||||||
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_requirement")})
|
|
||||||
} else {
|
|
||||||
if ext.Requirements == nil {
|
|
||||||
ext.Requirements = map[NodeID]ReqState{}
|
|
||||||
}
|
}
|
||||||
ext.Requirements[signal.NodeID] = Unlocked
|
|
||||||
changes = append(changes, "requirements")
|
func (ext *LockableExt) Process(context *StateContext, node *Node, signal Signal) error {
|
||||||
messages = append(messages, Message{source, NewSuccessSignal(signal.ID())})
|
context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
switch signal.Direction() {
|
||||||
|
case Up:
|
||||||
|
err = UseStates(context, node,
|
||||||
|
NewACLInfo(node, []string{"dependencies", "owner"}), func(context *StateContext) error {
|
||||||
|
owner_sent := false
|
||||||
|
for _, dependency := range(ext.Dependencies) {
|
||||||
|
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", node.ID, dependency.ID)
|
||||||
|
SendSignal(context, dependency, node, signal)
|
||||||
|
if ext.Owner != nil {
|
||||||
|
if dependency.ID == ext.Owner.ID {
|
||||||
|
owner_sent = true
|
||||||
}
|
}
|
||||||
case "remove":
|
|
||||||
_, exists := ext.Requirements[signal.NodeID]
|
|
||||||
if exists == false {
|
|
||||||
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")})
|
|
||||||
} else {
|
|
||||||
delete(ext.Requirements, signal.NodeID)
|
|
||||||
changes = append(changes, "requirements")
|
|
||||||
messages = append(messages, Message{source, NewSuccessSignal(signal.ID())})
|
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "unknown_action")})
|
|
||||||
}
|
}
|
||||||
|
if ext.Owner != nil && owner_sent == false {
|
||||||
|
if ext.Owner.ID != node.ID {
|
||||||
|
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", node.ID, ext.Owner.ID)
|
||||||
|
return SendSignal(context, ext.Owner, node, signal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
case Down:
|
||||||
|
err = UseStates(context, node, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error {
|
||||||
|
for _, requirement := range(ext.Requirements) {
|
||||||
|
err := SendSignal(context, requirement, node, signal)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
case Direct:
|
||||||
|
err = nil
|
||||||
default:
|
default:
|
||||||
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_unlocked: %s", ext.State)})
|
err = fmt.Errorf("invalid signal direction %d", signal.Direction())
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
return messages, changes
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle an UnlockSignal by either transitioning to Unlocked state,
|
func (ext *LockableExt) RecordUnlock(node *Node) *Node {
|
||||||
// sending unlock signals to requirements, or returning an error signal
|
last_owner, exists := ext.LocksHeld[node.ID]
|
||||||
func (ext *LockableExt) HandleUnlockSignal(ctx *Context, node *Node, source NodeID, signal *UnlockSignal) ([]Message, Changes) {
|
if exists == false {
|
||||||
var messages []Message = nil
|
panic("Attempted to take a get the original lock holder of a lockable we don't own")
|
||||||
var changes Changes = nil
|
}
|
||||||
|
delete(ext.LocksHeld, node.ID)
|
||||||
|
return last_owner
|
||||||
|
}
|
||||||
|
|
||||||
switch ext.State {
|
func (ext *LockableExt) RecordLock(node *Node, last_owner *Node) {
|
||||||
case Locked:
|
_, exists := ext.LocksHeld[node.ID]
|
||||||
if source != *ext.Owner {
|
if exists == true {
|
||||||
messages = append(messages, Message{source, NewErrorSignal(signal.Id, "not_owner")})
|
panic("Attempted to lock a lockable we're already holding(lock cycle)")
|
||||||
} else {
|
}
|
||||||
if len(ext.Requirements) == 0 {
|
ext.LocksHeld[node.ID] = last_owner
|
||||||
changes = append(changes, "state", "owner", "pending_owner")
|
}
|
||||||
|
|
||||||
ext.Owner = nil
|
// Removes requirement as a requirement from lockable
|
||||||
|
func UnlinkLockables(context *StateContext, princ *Node, lockable *Node, requirement *Node) error {
|
||||||
|
lockable_ext, err := GetExt[*LockableExt](lockable)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
requirement_ext, err := GetExt[*LockableExt](requirement)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return UpdateStates(context, princ, ACLMap{
|
||||||
|
lockable.ID: ACLInfo{Node: lockable, Resources: []string{"requirements"}},
|
||||||
|
requirement.ID: ACLInfo{Node: requirement, Resources: []string{"dependencies"}},
|
||||||
|
}, func(context *StateContext) error {
|
||||||
|
var found *Node = nil
|
||||||
|
for _, req := range(lockable_ext.Requirements) {
|
||||||
|
if requirement.ID == req.ID {
|
||||||
|
found = req
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ext.PendingOwner = nil
|
if found == nil {
|
||||||
|
return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID, lockable.ID)
|
||||||
|
}
|
||||||
|
|
||||||
ext.State = Unlocked
|
delete(requirement_ext.Dependencies, lockable.ID)
|
||||||
|
delete(lockable_ext.Requirements, requirement.ID)
|
||||||
|
|
||||||
messages = append(messages, Message{source, NewSuccessSignal(signal.Id)})
|
return nil
|
||||||
} else {
|
})
|
||||||
changes = append(changes, "state", "waiting", "requirements", "pending_owner")
|
}
|
||||||
|
|
||||||
ext.PendingOwner = nil
|
// Link requirements as requirements to lockable
|
||||||
|
func LinkLockables(context *StateContext, princ *Node, lockable *Node, requirements []*Node) error {
|
||||||
|
if lockable == nil {
|
||||||
|
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements")
|
||||||
|
}
|
||||||
|
|
||||||
ext.ReqID = &signal.Id
|
if len(requirements) == 0 {
|
||||||
|
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link no lockables in call")
|
||||||
|
}
|
||||||
|
|
||||||
ext.State = Unlocking
|
lockable_ext, err := GetExt[*LockableExt](lockable)
|
||||||
for id := range(ext.Requirements) {
|
if err != nil {
|
||||||
unlock_signal := NewUnlockSignal()
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
ext.Waiting[unlock_signal.Id] = id
|
req_exts := map[NodeID]*LockableExt{}
|
||||||
ext.Requirements[id] = Unlocking
|
for _, requirement := range(requirements) {
|
||||||
|
if requirement == nil {
|
||||||
|
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link nil to a Lockable as a requirement")
|
||||||
|
}
|
||||||
|
|
||||||
messages = append(messages, Message{id, unlock_signal})
|
if lockable.ID == requirement.ID {
|
||||||
|
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s to itself", lockable.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, exists := req_exts[requirement.ID]
|
||||||
|
if exists == true {
|
||||||
|
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s twice", requirement.ID)
|
||||||
}
|
}
|
||||||
|
ext, err := GetExt[*LockableExt](requirement)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
default:
|
req_exts[requirement.ID] = ext
|
||||||
messages = append(messages, Message{source, NewErrorSignal(signal.Id, "not_locked")})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return messages, changes
|
return UpdateStates(context, princ, NewACLMap(
|
||||||
|
NewACLInfo(lockable, []string{"requirements"}),
|
||||||
|
ACLList(requirements, []string{"dependencies"}),
|
||||||
|
), func(context *StateContext) error {
|
||||||
|
// Check that all the requirements can be added
|
||||||
|
// If the lockable is already locked, need to lock this resource as well before we can add it
|
||||||
|
for _, requirement := range(requirements) {
|
||||||
|
requirement_ext := req_exts[requirement.ID]
|
||||||
|
for _, req := range(requirements) {
|
||||||
|
if req.ID == requirement.ID {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle a LockSignal by either transitioning to a locked state,
|
is_req, err := checkIfRequirement(context, req.ID, requirement_ext)
|
||||||
// sending lock signals to requirements, or returning an error signal
|
if err != nil {
|
||||||
func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) ([]Message, Changes) {
|
return err
|
||||||
var messages []Message = nil
|
} else if is_req {
|
||||||
var changes Changes = nil
|
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot add the same dependency", req.ID, requirement.ID)
|
||||||
|
|
||||||
switch ext.State {
|
}
|
||||||
case Unlocked:
|
}
|
||||||
if len(ext.Requirements) == 0 {
|
|
||||||
changes = append(changes, "state", "owner", "pending_owner")
|
|
||||||
|
|
||||||
ext.Owner = &source
|
is_req, err := checkIfRequirement(context, lockable.ID, requirement_ext)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else if is_req {
|
||||||
|
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID, lockable.ID)
|
||||||
|
}
|
||||||
|
|
||||||
ext.PendingOwner = &source
|
is_req, err = checkIfRequirement(context, requirement.ID, lockable_ext)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else if is_req {
|
||||||
|
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID, requirement.ID)
|
||||||
|
}
|
||||||
|
|
||||||
ext.State = Locked
|
if lockable_ext.Owner == nil {
|
||||||
messages = append(messages, Message{source, NewSuccessSignal(signal.Id)})
|
// If the new owner isn't locked, we can add the requirement
|
||||||
|
} else if requirement_ext.Owner == nil {
|
||||||
|
// if the new requirement isn't already locked but the owner is, the requirement needs to be locked first
|
||||||
|
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID, requirement.ID)
|
||||||
} else {
|
} else {
|
||||||
changes = append(changes, "state", "requirements", "waiting", "pending_owner")
|
// If the new requirement is already locked and the owner is already locked, their owners need to match
|
||||||
|
if requirement_ext.Owner.ID != lockable_ext.Owner.ID {
|
||||||
ext.PendingOwner = &source
|
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID, lockable.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Update the states of the requirements
|
||||||
|
for _, requirement := range(requirements) {
|
||||||
|
requirement_ext := req_exts[requirement.ID]
|
||||||
|
requirement_ext.Dependencies[lockable.ID] = lockable
|
||||||
|
lockable_ext.Requirements[lockable.ID] = requirement
|
||||||
|
context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID, lockable.ID)
|
||||||
|
}
|
||||||
|
|
||||||
ext.ReqID = &signal.Id
|
// Return no error
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
ext.State = Locking
|
func checkIfRequirement(context *StateContext, id NodeID, cur *LockableExt) (bool, error) {
|
||||||
for id := range(ext.Requirements) {
|
for _, req := range(cur.Requirements) {
|
||||||
lock_signal := NewLockSignal()
|
if req.ID == id {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
ext.Waiting[lock_signal.Id] = id
|
req_ext, err := GetExt[*LockableExt](req)
|
||||||
ext.Requirements[id] = Locking
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
messages = append(messages, Message{id, lock_signal})
|
var is_req bool
|
||||||
|
err = UpdateStates(context, req, NewACLInfo(req, []string{"requirements"}), func(context *StateContext) error {
|
||||||
|
is_req, err = checkIfRequirement(context, id, req_ext)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
}
|
}
|
||||||
|
if is_req == true {
|
||||||
|
return true, nil
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
messages = append(messages, Message{source, NewErrorSignal(signal.Id, "not_unlocked: %s", ext.State)})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return messages, changes
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle an error signal by aborting the lock, or retrying the unlock
|
// Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error
|
||||||
func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) ([]Message, Changes) {
|
// Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call
|
||||||
var messages []Message = nil
|
func LockLockables(context *StateContext, to_lock NodeMap, new_owner *Node) error {
|
||||||
var changes Changes = nil
|
if to_lock == nil {
|
||||||
|
return fmt.Errorf("LOCKABLE_LOCK_ERR: no map provided")
|
||||||
|
}
|
||||||
|
|
||||||
id, waiting := ext.Waiting[signal.ReqID]
|
req_exts := map[NodeID]*LockableExt{}
|
||||||
if waiting == true {
|
for _, l := range(to_lock) {
|
||||||
delete(ext.Waiting, signal.ReqID)
|
var err error
|
||||||
changes = append(changes, "waiting")
|
if l == nil {
|
||||||
|
return fmt.Errorf("LOCKABLE_LOCK_ERR: Can not lock nil")
|
||||||
|
}
|
||||||
|
|
||||||
switch ext.State {
|
req_exts[l.ID], err = GetExt[*LockableExt](l)
|
||||||
case Locking:
|
if err != nil {
|
||||||
changes = append(changes, "state", "requirements")
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ext.Requirements[id] = Unlocked
|
if new_owner == nil {
|
||||||
|
return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks")
|
||||||
|
}
|
||||||
|
|
||||||
unlocked := 0
|
new_owner_ext, err := GetExt[*LockableExt](new_owner)
|
||||||
for req_id, req_state := range(ext.Requirements) {
|
if err != nil {
|
||||||
// Unlock locked requirements, and count unlocked requirements
|
return err
|
||||||
switch req_state {
|
}
|
||||||
case Locked:
|
|
||||||
unlock_signal := NewUnlockSignal()
|
|
||||||
|
|
||||||
ext.Waiting[unlock_signal.Id] = req_id
|
// Called with no requirements to lock, success
|
||||||
ext.Requirements[req_id] = Unlocking
|
if len(to_lock) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
messages = append(messages, Message{req_id, unlock_signal})
|
return UpdateStates(context, new_owner, NewACLMap(
|
||||||
case Unlocked:
|
ACLListM(to_lock, []string{"lock"}),
|
||||||
unlocked += 1
|
NewACLInfo(new_owner, nil),
|
||||||
|
), func(context *StateContext) error {
|
||||||
|
// First loop is to check that the states can be locked, and locks all requirements
|
||||||
|
for _, req := range(to_lock) {
|
||||||
|
req_ext := req_exts[req.ID]
|
||||||
|
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID, new_owner.ID)
|
||||||
|
|
||||||
|
// If req is alreay locked, check that we can pass the lock
|
||||||
|
if req_ext.Owner != nil {
|
||||||
|
owner := req_ext.Owner
|
||||||
|
if owner.ID == new_owner.ID {
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
err := UpdateStates(context, new_owner, NewACLInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){
|
||||||
|
return LockLockables(context, req_ext.Requirements, req)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if unlocked == len(ext.Requirements) {
|
|
||||||
changes = append(changes, "owner", "state")
|
|
||||||
ext.State = Unlocked
|
|
||||||
ext.Owner = nil
|
|
||||||
} else {
|
} else {
|
||||||
changes = append(changes, "state")
|
err := LockLockables(context, req_ext.Requirements, req)
|
||||||
ext.State = AbortingLock
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point state modification will be started, so no errors can be returned
|
||||||
|
for _, req := range(to_lock) {
|
||||||
|
req_ext := req_exts[req.ID]
|
||||||
|
old_owner := req_ext.Owner
|
||||||
|
// If the lockable was previously unowned, update the state
|
||||||
|
if old_owner == nil {
|
||||||
|
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID, req.ID)
|
||||||
|
req_ext.Owner = new_owner
|
||||||
|
new_owner_ext.RecordLock(req, old_owner)
|
||||||
|
// Otherwise if the new owner already owns it, no need to update state
|
||||||
|
} else if old_owner.ID == new_owner.ID {
|
||||||
|
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID, req.ID)
|
||||||
|
// Otherwise update the state
|
||||||
|
} else {
|
||||||
|
req_ext.Owner = new_owner
|
||||||
|
new_owner_ext.RecordLock(req, old_owner)
|
||||||
|
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID, req.ID, old_owner.ID)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
case Unlocking:
|
|
||||||
unlock_signal := NewUnlockSignal()
|
|
||||||
ext.Waiting[unlock_signal.Id] = id
|
|
||||||
messages = append(messages, Message{id, unlock_signal})
|
|
||||||
|
|
||||||
case AbortingLock:
|
|
||||||
req_state := ext.Requirements[id]
|
|
||||||
// Mark failed lock as Unlocked, or retry unlock
|
|
||||||
switch req_state {
|
|
||||||
case Locking:
|
|
||||||
ext.Requirements[id] = Unlocked
|
|
||||||
|
|
||||||
// Check if all requirements unlocked now
|
|
||||||
unlocked := 0
|
|
||||||
for _, req_state := range(ext.Requirements) {
|
|
||||||
if req_state == Unlocked {
|
|
||||||
unlocked += 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func UnlockLockables(context *StateContext, to_unlock NodeMap, old_owner *Node) error {
|
||||||
|
if to_unlock == nil {
|
||||||
|
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
if unlocked == len(ext.Requirements) {
|
req_exts := map[NodeID]*LockableExt{}
|
||||||
changes = append(changes, "owner", "state")
|
for _, l := range(to_unlock) {
|
||||||
ext.State = Unlocked
|
if l == nil {
|
||||||
ext.Owner = nil
|
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: Can not unlock nil")
|
||||||
}
|
}
|
||||||
case Unlocking:
|
|
||||||
// Handle error for unlocking requirement while unlocking by retrying unlock
|
var err error
|
||||||
unlock_signal := NewUnlockSignal()
|
req_exts[l.ID], err = GetExt[*LockableExt](l)
|
||||||
ext.Waiting[unlock_signal.Id] = id
|
if err != nil {
|
||||||
messages = append(messages, Message{id, unlock_signal})
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if old_owner == nil {
|
||||||
|
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks")
|
||||||
}
|
}
|
||||||
|
|
||||||
return messages, changes
|
old_owner_ext, err := GetExt[*LockableExt](old_owner)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle a success signal by checking if all requirements have been locked/unlocked
|
|
||||||
func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source NodeID, signal *SuccessSignal) ([]Message, Changes) {
|
|
||||||
var messages []Message = nil
|
|
||||||
var changes Changes = nil
|
|
||||||
|
|
||||||
id, waiting := ext.Waiting[signal.ReqID]
|
// Called with no requirements to unlock, success
|
||||||
if waiting == true {
|
if len(to_unlock) == 0 {
|
||||||
delete(ext.Waiting, signal.ReqID)
|
return nil
|
||||||
changes = append(changes, "waiting")
|
}
|
||||||
|
|
||||||
switch ext.State {
|
return UpdateStates(context, old_owner, NewACLMap(
|
||||||
case Locking:
|
ACLListM(to_unlock, []string{"lock"}),
|
||||||
ext.Requirements[id] = Locked
|
NewACLInfo(old_owner, nil),
|
||||||
ext.Locked[id] = nil
|
), func(context *StateContext) error {
|
||||||
delete(ext.Unlocked, id)
|
// First loop is to check that the states can be locked, and locks all requirements
|
||||||
|
for _, req := range(to_unlock) {
|
||||||
|
req_ext := req_exts[req.ID]
|
||||||
|
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID, old_owner.ID)
|
||||||
|
|
||||||
if len(ext.Locked) == len(ext.Requirements) {
|
// Check if the owner is correct
|
||||||
ctx.Log.Logf("lockable", "%s FULL_LOCK: %d", node.ID, len(ext.Locked))
|
if req_ext.Owner != nil {
|
||||||
changes = append(changes, "state", "owner", "req_id")
|
if req_ext.Owner.ID != old_owner.ID {
|
||||||
ext.State = Locked
|
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID, old_owner.ID)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID)
|
||||||
|
}
|
||||||
|
|
||||||
ext.Owner = ext.PendingOwner
|
err := UnlockLockables(context, req_ext.Requirements, req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
messages = append(messages, Message{*ext.Owner, NewSuccessSignal(*ext.ReqID)})
|
// At this point state modification will be started, so no errors can be returned
|
||||||
ext.ReqID = nil
|
for _, req := range(to_unlock) {
|
||||||
|
req_ext := req_exts[req.ID]
|
||||||
|
new_owner := old_owner_ext.RecordUnlock(req)
|
||||||
|
req_ext.Owner = new_owner
|
||||||
|
if new_owner == nil {
|
||||||
|
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID, req.ID)
|
||||||
} else {
|
} else {
|
||||||
ctx.Log.Logf("lockable", "%s PARTIAL_LOCK: %d/%d", node.ID, len(ext.Locked), len(ext.Requirements))
|
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID, req.ID, new_owner.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func SaveNode(node *Node) string {
|
||||||
|
str := ""
|
||||||
|
if node != nil {
|
||||||
|
str = node.ID.String()
|
||||||
|
}
|
||||||
|
return str
|
||||||
}
|
}
|
||||||
case AbortingLock:
|
|
||||||
req_state := ext.Requirements[id]
|
|
||||||
switch req_state {
|
|
||||||
case Locking:
|
|
||||||
ext.Requirements[id] = Unlocking
|
|
||||||
unlock_signal := NewUnlockSignal()
|
|
||||||
ext.Waiting[unlock_signal.Id] = id
|
|
||||||
messages = append(messages, Message{id, unlock_signal})
|
|
||||||
case Unlocking:
|
|
||||||
ext.Requirements[id] = Unlocked
|
|
||||||
ext.Unlocked[id] = nil
|
|
||||||
delete(ext.Locked, id)
|
|
||||||
|
|
||||||
unlocked := 0
|
func RestoreNode(ctx *Context, id_str string) (*Node, error) {
|
||||||
for _, req_state := range(ext.Requirements) {
|
if id_str == "" {
|
||||||
switch req_state {
|
return nil, nil
|
||||||
case Unlocked:
|
|
||||||
unlocked += 1
|
|
||||||
}
|
}
|
||||||
|
id, err := ParseID(id_str)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if unlocked == len(ext.Requirements) {
|
return LoadNode(ctx, id)
|
||||||
changes = append(changes, "state", "pending_owner", "req_id")
|
}
|
||||||
|
|
||||||
messages = append(messages, Message{*ext.PendingOwner, NewErrorSignal(*ext.ReqID, "not_unlocked: %s", ext.State)})
|
func SaveNodeMap(nodes NodeMap) map[string]string {
|
||||||
ext.State = Unlocked
|
m := map[string]string{}
|
||||||
ext.ReqID = nil
|
for id, node := range(nodes) {
|
||||||
ext.PendingOwner = nil
|
m[id.String()] = SaveNode(node)
|
||||||
}
|
}
|
||||||
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RestoreNodeMap(ctx *Context, ids map[string]string) (NodeMap, error) {
|
||||||
|
nodes := NodeMap{}
|
||||||
|
for id_str_1, id_str_2 := range(ids) {
|
||||||
|
id_1, err := ParseID(id_str_1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
case Unlocking:
|
node_1, err := LoadNode(ctx, id_1)
|
||||||
ext.Requirements[id] = Unlocked
|
if err != nil {
|
||||||
ext.Unlocked[id] = Unlocked
|
return nil, err
|
||||||
delete(ext.Locked, id)
|
}
|
||||||
|
|
||||||
if len(ext.Unlocked) == len(ext.Requirements) {
|
|
||||||
changes = append(changes, "state", "owner", "req_id")
|
|
||||||
|
|
||||||
messages = append(messages, Message{*ext.Owner, NewSuccessSignal(*ext.ReqID)})
|
var node_2 *Node = nil
|
||||||
ext.State = Unlocked
|
if id_str_2 != "" {
|
||||||
ext.ReqID = nil
|
id_2, err := ParseID(id_str_2)
|
||||||
ext.Owner = nil
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
node_2, err = LoadNode(ctx, id_2)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return messages, changes
|
nodes[node_1.ID] = node_2
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]Message, Changes) {
|
return nodes, nil
|
||||||
var messages []Message = nil
|
}
|
||||||
var changes Changes = nil
|
|
||||||
|
|
||||||
switch sig := signal.(type) {
|
func SaveNodeList(nodes NodeMap) []string {
|
||||||
case *StatusSignal:
|
ids := make([]string, len(nodes))
|
||||||
// Forward StatusSignals up to the owner(unless that would be a cycle)
|
i := 0
|
||||||
if ext.Owner != nil {
|
for id, _ := range(nodes) {
|
||||||
if *ext.Owner != node.ID {
|
ids[i] = id.String()
|
||||||
messages = append(messages, Message{*ext.Owner, signal})
|
i += 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
func RestoreNodeList(ctx *Context, ids []string) (NodeMap, error) {
|
||||||
|
nodes := NodeMap{}
|
||||||
|
|
||||||
|
for _, id_str := range(ids) {
|
||||||
|
node, err := RestoreNode(ctx, id_str)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
case *LinkSignal:
|
nodes[node.ID] = node
|
||||||
messages, changes = ext.HandleLinkSignal(ctx, node, source, sig)
|
|
||||||
case *LockSignal:
|
|
||||||
messages, changes = ext.HandleLockSignal(ctx, node, source, sig)
|
|
||||||
case *UnlockSignal:
|
|
||||||
messages, changes = ext.HandleUnlockSignal(ctx, node, source, sig)
|
|
||||||
case *ErrorSignal:
|
|
||||||
messages, changes = ext.HandleErrorSignal(ctx, node, source, sig)
|
|
||||||
case *SuccessSignal:
|
|
||||||
messages, changes = ext.HandleSuccessSignal(ctx, node, source, sig)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return messages, changes
|
return nodes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,148 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLink(t *testing.T) {
|
|
||||||
ctx := logTestContext(t, []string{"lockable", "listener"})
|
|
||||||
|
|
||||||
|
|
||||||
l2_listener := NewListenerExt(10)
|
|
||||||
l2, err := ctx.NewNode(nil, "LockableNode", l2_listener, NewLockableExt(nil))
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
l1_lockable := NewLockableExt(nil)
|
|
||||||
l1_listener := NewListenerExt(10)
|
|
||||||
l1, err := ctx.NewNode(nil, "LockableNode", l1_listener, l1_lockable)
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
link_signal := NewLinkSignal("add", l2.ID)
|
|
||||||
msgs := []Message{{l1.ID, link_signal}}
|
|
||||||
err = ctx.Send(l1, msgs)
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
_, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, link_signal.ID())
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
state, exists := l1_lockable.Requirements[l2.ID]
|
|
||||||
if exists == false {
|
|
||||||
t.Fatal("l2 not in l1 requirements")
|
|
||||||
} else if state != Unlocked {
|
|
||||||
t.Fatalf("l2 in bad requirement state in l1: %+v", state)
|
|
||||||
}
|
|
||||||
|
|
||||||
unlink_signal := NewLinkSignal("remove", l2.ID)
|
|
||||||
msgs = []Message{{l1.ID, unlink_signal}}
|
|
||||||
err = ctx.Send(l1, msgs)
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
_, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, unlink_signal.ID())
|
|
||||||
fatalErr(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test10Lock(t *testing.T) {
|
|
||||||
testLockN(t, 10)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test100Lock(t *testing.T) {
|
|
||||||
testLockN(t, 100)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test1000Lock(t *testing.T) {
|
|
||||||
testLockN(t, 1000)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test10000Lock(t *testing.T) {
|
|
||||||
testLockN(t, 10000)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testLockN(t *testing.T, n int) {
|
|
||||||
ctx := logTestContext(t, []string{"test"})
|
|
||||||
|
|
||||||
NewLockable := func()(*Node) {
|
|
||||||
l, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil))
|
|
||||||
fatalErr(t, err)
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
reqs := make([]NodeID, n)
|
|
||||||
for i := range(reqs) {
|
|
||||||
new_lockable := NewLockable()
|
|
||||||
reqs[i] = new_lockable.ID
|
|
||||||
}
|
|
||||||
ctx.Log.Logf("test", "CREATED_%d", n)
|
|
||||||
|
|
||||||
listener := NewListenerExt(50000)
|
|
||||||
node, err := ctx.NewNode(nil, "LockableNode", listener, NewLockableExt(reqs))
|
|
||||||
fatalErr(t, err)
|
|
||||||
ctx.Log.Logf("test", "CREATED_LISTENER")
|
|
||||||
|
|
||||||
lock_id, err := LockLockable(ctx, node)
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
response, _, err := WaitForResponse(listener.Chan, time.Second*60, lock_id)
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
switch resp := response.(type) {
|
|
||||||
case *SuccessSignal:
|
|
||||||
default:
|
|
||||||
t.Fatalf("Unexpected response to lock - %s", resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "LOCKED_%d", n)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLock(t *testing.T) {
|
|
||||||
ctx := logTestContext(t, []string{"test", "lockable"})
|
|
||||||
|
|
||||||
NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) {
|
|
||||||
listener := NewListenerExt(10000)
|
|
||||||
l, err := ctx.NewNode(nil, "LockableNode", listener, NewLockableExt(reqs))
|
|
||||||
fatalErr(t, err)
|
|
||||||
return l, listener
|
|
||||||
}
|
|
||||||
|
|
||||||
l2, _ := NewLockable(nil)
|
|
||||||
l3, _ := NewLockable(nil)
|
|
||||||
l4, _ := NewLockable(nil)
|
|
||||||
l5, _ := NewLockable(nil)
|
|
||||||
l0, l0_listener := NewLockable([]NodeID{l5.ID})
|
|
||||||
l1, l1_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID})
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "l0: %s", l0.ID)
|
|
||||||
ctx.Log.Logf("test", "l1: %s", l1.ID)
|
|
||||||
ctx.Log.Logf("test", "l2: %s", l2.ID)
|
|
||||||
ctx.Log.Logf("test", "l3: %s", l3.ID)
|
|
||||||
ctx.Log.Logf("test", "l4: %s", l4.ID)
|
|
||||||
ctx.Log.Logf("test", "l5: %s", l5.ID)
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "locking l0")
|
|
||||||
id_1, err := LockLockable(ctx, l0)
|
|
||||||
fatalErr(t, err)
|
|
||||||
response, _, err := WaitForResponse(l0_listener.Chan, time.Millisecond*10, id_1)
|
|
||||||
fatalErr(t, err)
|
|
||||||
ctx.Log.Logf("test", "l0 lock: %+v", response)
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "locking l1")
|
|
||||||
id_2, err := LockLockable(ctx, l1)
|
|
||||||
fatalErr(t, err)
|
|
||||||
response, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10000, id_2)
|
|
||||||
fatalErr(t, err)
|
|
||||||
ctx.Log.Logf("test", "l1 lock: %+v", response)
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "unlocking l0")
|
|
||||||
id_3, err := UnlockLockable(ctx, l0)
|
|
||||||
fatalErr(t, err)
|
|
||||||
response, _, err = WaitForResponse(l0_listener.Chan, time.Millisecond*10, id_3)
|
|
||||||
fatalErr(t, err)
|
|
||||||
ctx.Log.Logf("test", "l0 unlock: %+v", response)
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "locking l1")
|
|
||||||
id_4, err := LockLockable(ctx, l1)
|
|
||||||
fatalErr(t, err)
|
|
||||||
response, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, id_4)
|
|
||||||
fatalErr(t, err)
|
|
||||||
ctx.Log.Logf("test", "l1 lock: %+v", response)
|
|
||||||
}
|
|
@ -1,68 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
Node NodeID
|
|
||||||
Signal Signal
|
|
||||||
}
|
|
||||||
|
|
||||||
type MessageQueue struct {
|
|
||||||
out chan<- Message
|
|
||||||
in <-chan Message
|
|
||||||
buffer []Message
|
|
||||||
write_cursor int
|
|
||||||
read_cursor int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (queue *MessageQueue) ProcessIncoming(message Message) {
|
|
||||||
if (queue.write_cursor + 1) == queue.read_cursor || ((queue.write_cursor + 1) == len(queue.buffer) && queue.read_cursor == 0) {
|
|
||||||
new_buffer := make([]Message, len(queue.buffer) * 2)
|
|
||||||
|
|
||||||
copy(new_buffer, queue.buffer[queue.read_cursor:])
|
|
||||||
first_chunk := len(queue.buffer) - queue.read_cursor
|
|
||||||
copy(new_buffer[first_chunk:], queue.buffer[0:queue.write_cursor])
|
|
||||||
|
|
||||||
queue.write_cursor = len(queue.buffer) - 1
|
|
||||||
queue.read_cursor = 0
|
|
||||||
queue.buffer = new_buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
queue.buffer[queue.write_cursor] = message
|
|
||||||
queue.write_cursor += 1
|
|
||||||
if queue.write_cursor >= len(queue.buffer) {
|
|
||||||
queue.write_cursor = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMessageQueue(initial int) (chan<- Message, <-chan Message) {
|
|
||||||
in := make(chan Message, 0)
|
|
||||||
out := make(chan Message, 0)
|
|
||||||
|
|
||||||
queue := MessageQueue{
|
|
||||||
out: out,
|
|
||||||
in: in,
|
|
||||||
buffer: make([]Message, initial),
|
|
||||||
write_cursor: 0,
|
|
||||||
read_cursor: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
go func(queue *MessageQueue) {
|
|
||||||
for true {
|
|
||||||
if queue.write_cursor != queue.read_cursor {
|
|
||||||
select {
|
|
||||||
case incoming := <-queue.in:
|
|
||||||
queue.ProcessIncoming(incoming)
|
|
||||||
case queue.out <- queue.buffer[queue.read_cursor]:
|
|
||||||
queue.read_cursor += 1
|
|
||||||
if queue.read_cursor >= len(queue.buffer) {
|
|
||||||
queue.read_cursor = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
message := <-queue.in
|
|
||||||
queue.ProcessIncoming(message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}(&queue)
|
|
||||||
|
|
||||||
return in, out
|
|
||||||
}
|
|
@ -1,35 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func sendBatch(start, end uint64, in chan<- Message) {
|
|
||||||
for i := start; i <= end; i++ {
|
|
||||||
var id NodeID
|
|
||||||
binary.BigEndian.PutUint64(id[:], i)
|
|
||||||
in <- Message{id, nil}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMessageQueue(t *testing.T) {
|
|
||||||
in, out := NewMessageQueue(10)
|
|
||||||
|
|
||||||
for i := uint64(0); i < 1000; i++ {
|
|
||||||
go sendBatch(1000*i, (1000*(i+1))-1, in)
|
|
||||||
}
|
|
||||||
|
|
||||||
seen := map[NodeID]any{}
|
|
||||||
for i := uint64(0); i < 1000*1000; i++ {
|
|
||||||
read := <-out
|
|
||||||
_, already_seen := seen[read.Node]
|
|
||||||
if already_seen {
|
|
||||||
t.Fatalf("Signal %d had duplicate NodeID %s", i, read.Node)
|
|
||||||
} else {
|
|
||||||
seen[read.Node] = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Logf("Processed 1M signals through queue")
|
|
||||||
}
|
|
@ -0,0 +1,396 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Policy interface {
|
||||||
|
Serializable[PolicyType]
|
||||||
|
Allows(context *StateContext, principal *Node, action string, node *Node) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
const RequirementOfPolicyType = PolicyType("REQUIREMENT_OF")
|
||||||
|
type RequirementOfPolicy struct {
|
||||||
|
PerNodePolicy
|
||||||
|
}
|
||||||
|
func (policy *RequirementOfPolicy) Type() PolicyType {
|
||||||
|
return RequirementOfPolicyType
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRequirementOfPolicy(nodes NodeActions) RequirementOfPolicy {
|
||||||
|
return RequirementOfPolicy{
|
||||||
|
PerNodePolicy: NewPerNodePolicy(nodes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if any of principals dependencies are in the policy
|
||||||
|
func (policy *RequirementOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
|
||||||
|
lockable_ext, err := GetExt[*LockableExt](principal)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for dep_id, _ := range(lockable_ext.Dependencies) {
|
||||||
|
for node_id, actions := range(policy.NodeActions) {
|
||||||
|
if node_id == dep_id {
|
||||||
|
if actions.Allows(action) == true {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
const ChildOfPolicyType = PolicyType("CHILD_OF")
|
||||||
|
type ChildOfPolicy struct {
|
||||||
|
PerNodePolicy
|
||||||
|
}
|
||||||
|
func (policy *ChildOfPolicy) Type() PolicyType {
|
||||||
|
return ChildOfPolicyType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (policy *ChildOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
|
||||||
|
context.Graph.Log.Logf("policy", "CHILD_OF_POLICY: %+v", policy)
|
||||||
|
thread_ext, err := GetExt[*ThreadExt](principal)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
parent := thread_ext.Parent
|
||||||
|
if parent != nil {
|
||||||
|
actions, exists := policy.NodeActions[parent.ID]
|
||||||
|
if exists == false {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, a := range(actions) {
|
||||||
|
if a == action {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type Actions []string
|
||||||
|
|
||||||
|
func (actions Actions) Allows(action string) bool {
|
||||||
|
for _, a := range(actions) {
|
||||||
|
if a == action {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type NodeActions map[NodeID]Actions
|
||||||
|
|
||||||
|
func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context, []byte)(Policy, error) {
|
||||||
|
return func(ctx *Context, data []byte)(Policy, error){
|
||||||
|
var j PerNodePolicyJSON
|
||||||
|
err := json.Unmarshal(data, &j)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
node_actions := NodeActions{}
|
||||||
|
for id_str, actions := range(j.NodeActions) {
|
||||||
|
id, err := ParseID(id_str)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = LoadNode(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
node_actions[id] = actions
|
||||||
|
}
|
||||||
|
|
||||||
|
return init_fn(node_actions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChildOfPolicy(node_actions NodeActions) ChildOfPolicy {
|
||||||
|
return ChildOfPolicy{
|
||||||
|
PerNodePolicy: NewPerNodePolicy(node_actions),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const ParentOfPolicyType = PolicyType("PARENT_OF")
|
||||||
|
type ParentOfPolicy struct {
|
||||||
|
PerNodePolicy
|
||||||
|
}
|
||||||
|
func (policy *ParentOfPolicy) Type() PolicyType {
|
||||||
|
return ParentOfPolicyType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (policy *ParentOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
|
||||||
|
context.Graph.Log.Logf("policy", "PARENT_OF_POLICY: %+v", policy)
|
||||||
|
for id, actions := range(policy.NodeActions) {
|
||||||
|
thread_ext, err := GetExt[*ThreadExt](context.Graph.Nodes[id])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
context.Graph.Log.Logf("policy", "PARENT_OF_PARENT: %s %+v", id, thread_ext.Parent)
|
||||||
|
if thread_ext.Parent != nil {
|
||||||
|
if thread_ext.Parent.ID == principal.ID {
|
||||||
|
for _, a := range(actions) {
|
||||||
|
if a == action {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewParentOfPolicy(node_actions NodeActions) ParentOfPolicy {
|
||||||
|
return ParentOfPolicy{
|
||||||
|
PerNodePolicy: NewPerNodePolicy(node_actions),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy {
|
||||||
|
if node_actions == nil {
|
||||||
|
node_actions = NodeActions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return PerNodePolicy{
|
||||||
|
NodeActions: node_actions,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type PerNodePolicy struct {
|
||||||
|
NodeActions NodeActions
|
||||||
|
}
|
||||||
|
|
||||||
|
type PerNodePolicyJSON struct {
|
||||||
|
NodeActions map[string][]string `json:"node_actions"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const PerNodePolicyType = PolicyType("PER_NODE")
|
||||||
|
func (policy *PerNodePolicy) Type() PolicyType {
|
||||||
|
return PerNodePolicyType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (policy *PerNodePolicy) Serialize() ([]byte, error) {
|
||||||
|
node_actions := map[string][]string{}
|
||||||
|
for id, actions := range(policy.NodeActions) {
|
||||||
|
node_actions[id.String()] = actions
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.MarshalIndent(&PerNodePolicyJSON{
|
||||||
|
NodeActions: node_actions,
|
||||||
|
}, "", " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (policy *PerNodePolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
|
||||||
|
for id, actions := range(policy.NodeActions) {
|
||||||
|
if id != principal.ID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, a := range(actions) {
|
||||||
|
if a == action {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Extension to allow a node to hold ACL policies
|
||||||
|
type ACLPolicyExt struct {
|
||||||
|
Policies map[PolicyType]Policy
|
||||||
|
}
|
||||||
|
|
||||||
|
// The ACL extension stores a map of nodes to delegate ACL to, and a list of policies
|
||||||
|
type ACLExt struct {
|
||||||
|
Delegations NodeMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ACLExt) Process(context *StateContext, node *Node, signal Signal) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadACLExt(ctx *Context, data []byte) (Extension, error) {
|
||||||
|
var j struct {
|
||||||
|
Delegations []string `json:"delegation"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err := json.Unmarshal(data, &j)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
delegations, err := RestoreNodeList(ctx, j.Delegations)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ACLExt{
|
||||||
|
Delegations: delegations,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NodeList(nodes ...*Node) NodeMap {
|
||||||
|
m := NodeMap{}
|
||||||
|
for _, node := range(nodes) {
|
||||||
|
m[node.ID] = node
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewACLExt(delegations NodeMap) *ACLExt {
|
||||||
|
if delegations == nil {
|
||||||
|
delegations = NodeMap{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ACLExt{
|
||||||
|
Delegations: delegations,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ACLExt) Serialize() ([]byte, error) {
|
||||||
|
delegations := make([]string, len(ext.Delegations))
|
||||||
|
i := 0
|
||||||
|
for id, _ := range(ext.Delegations) {
|
||||||
|
delegations[i] = id.String()
|
||||||
|
i += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.MarshalIndent(&struct{
|
||||||
|
Delegations []string `json:"delegations"`
|
||||||
|
}{
|
||||||
|
Delegations: delegations,
|
||||||
|
}, "", " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
const ACLExtType = ExtType("ACL")
|
||||||
|
func (ext *ACLExt) Type() ExtType {
|
||||||
|
return ACLExtType
|
||||||
|
}
|
||||||
|
|
||||||
|
type PolicyLoadFunc func(*Context, []byte) (Policy, error)
|
||||||
|
type PolicyInfo struct {
|
||||||
|
Load PolicyLoadFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
type ACLPolicyExtContext struct {
|
||||||
|
Types map[PolicyType]PolicyInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewACLPolicyExtContext() *ACLPolicyExtContext {
|
||||||
|
return &ACLPolicyExtContext{
|
||||||
|
Types: map[PolicyType]PolicyInfo{
|
||||||
|
PerNodePolicyType: PolicyInfo{
|
||||||
|
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
|
||||||
|
policy := NewPerNodePolicy(nodes)
|
||||||
|
return &policy, nil
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
ParentOfPolicyType: PolicyInfo{
|
||||||
|
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
|
||||||
|
policy := NewParentOfPolicy(nodes)
|
||||||
|
return &policy, nil
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
ChildOfPolicyType: PolicyInfo{
|
||||||
|
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
|
||||||
|
policy := NewChildOfPolicy(nodes)
|
||||||
|
return &policy, nil
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
RequirementOfPolicyType: PolicyInfo{
|
||||||
|
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
|
||||||
|
policy := NewRequirementOfPolicy(nodes)
|
||||||
|
return &policy, nil
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ACLPolicyExt) Serialize() ([]byte, error) {
|
||||||
|
policies := map[string][]byte{}
|
||||||
|
for name, policy := range(ext.Policies) {
|
||||||
|
ser, err := policy.Serialize()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
policies[string(name)] = ser
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.MarshalIndent(&struct{
|
||||||
|
Policies map[string][]byte `json:"policies"`
|
||||||
|
}{
|
||||||
|
Policies: policies,
|
||||||
|
}, "", " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ACLPolicyExt) Process(context *StateContext, node *Node, signal Signal) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewACLPolicyExt(policies map[PolicyType]Policy) *ACLPolicyExt {
|
||||||
|
if policies == nil {
|
||||||
|
policies = map[PolicyType]Policy{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ACLPolicyExt{
|
||||||
|
Policies: policies,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadACLPolicyExt(ctx *Context, data []byte) (Extension, error) {
|
||||||
|
var j struct {
|
||||||
|
Policies map[string][]byte `json:"policies"`
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(data, &j)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
policies := map[PolicyType]Policy{}
|
||||||
|
acl_ctx := ctx.ExtByType(ACLPolicyExtType).Data.(*ACLPolicyExtContext)
|
||||||
|
for name, ser := range(j.Policies) {
|
||||||
|
policy_def, exists := acl_ctx.Types[PolicyType(name)]
|
||||||
|
if exists == false {
|
||||||
|
return nil, fmt.Errorf("%s is not a known policy type", name)
|
||||||
|
}
|
||||||
|
policy, err := policy_def.Load(ctx, ser)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
policies[PolicyType(name)] = policy
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewACLPolicyExt(policies), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const ACLPolicyExtType = ExtType("ACL_POLICIES")
|
||||||
|
func (ext *ACLPolicyExt) Type() ExtType {
|
||||||
|
return ACLPolicyExtType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the extension allows the principal to perform action on node
|
||||||
|
func (ext *ACLPolicyExt) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
|
||||||
|
context.Graph.Log.Logf("policy", "POLICY_EXT_ALLOWED: %+v", ext)
|
||||||
|
for _, policy := range(ext.Policies) {
|
||||||
|
context.Graph.Log.Logf("policy", "POLICY_CHECK_POLICY: %+v", policy)
|
||||||
|
if policy.Allows(context, principal, action, node) == true {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
@ -1,744 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/sha512"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"math"
|
|
||||||
)
|
|
||||||
|
|
||||||
type SerializedType uint64
|
|
||||||
|
|
||||||
func (t SerializedType) String() string {
|
|
||||||
return fmt.Sprintf("0x%x", uint64(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
type ExtType SerializedType
|
|
||||||
|
|
||||||
func (t ExtType) String() string {
|
|
||||||
return fmt.Sprintf("0x%x", uint64(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
type NodeType SerializedType
|
|
||||||
|
|
||||||
func (t NodeType) String() string {
|
|
||||||
return fmt.Sprintf("0x%x", uint64(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
type SignalType SerializedType
|
|
||||||
|
|
||||||
func (t SignalType) String() string {
|
|
||||||
return fmt.Sprintf("0x%x", uint64(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
type FieldTag SerializedType
|
|
||||||
|
|
||||||
func (t FieldTag) String() string {
|
|
||||||
return fmt.Sprintf("0x%x", uint64(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func NodeTypeFor(name string) NodeType {
|
|
||||||
digest := []byte("GRAPHVENT_NODE - " + name)
|
|
||||||
|
|
||||||
hash := sha512.Sum512(digest)
|
|
||||||
return NodeType(binary.BigEndian.Uint64(hash[0:8]))
|
|
||||||
}
|
|
||||||
|
|
||||||
func SerializeType(t fmt.Stringer) SerializedType {
|
|
||||||
digest := []byte(t.String())
|
|
||||||
hash := sha512.Sum512(digest)
|
|
||||||
return SerializedType(binary.BigEndian.Uint64(hash[0:8]))
|
|
||||||
}
|
|
||||||
|
|
||||||
func SerializedTypeFor[T any]() SerializedType {
|
|
||||||
return SerializeType(reflect.TypeFor[T]())
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExtTypeFor[E any, T interface { *E; Extension}]() ExtType {
|
|
||||||
return ExtType(SerializedTypeFor[E]())
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExtTypeOf(t reflect.Type) ExtType {
|
|
||||||
return ExtType(SerializeType(t.Elem()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func SignalTypeFor[S Signal]() SignalType {
|
|
||||||
return SignalType(SerializedTypeFor[S]())
|
|
||||||
}
|
|
||||||
|
|
||||||
func Hash(base, data string) SerializedType {
|
|
||||||
digest := []byte(base + ":" + data)
|
|
||||||
hash := sha512.Sum512(digest)
|
|
||||||
return SerializedType(binary.BigEndian.Uint64(hash[0:8]))
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetFieldTag(tag string) FieldTag {
|
|
||||||
return FieldTag(Hash("GRAPHVENT_FIELD_TAG", tag))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TypeStack(ctx *Context, t reflect.Type, data []byte) (int, error) {
|
|
||||||
info, registered := ctx.Types[t]
|
|
||||||
if registered {
|
|
||||||
binary.BigEndian.PutUint64(data, uint64(info.Serialized))
|
|
||||||
return 8, nil
|
|
||||||
} else {
|
|
||||||
switch t.Kind() {
|
|
||||||
case reflect.Map:
|
|
||||||
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Map)))
|
|
||||||
|
|
||||||
key_written, err := TypeStack(ctx, t.Key(), data[8:])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
elem_written, err := TypeStack(ctx, t.Elem(), data[8 + key_written:])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return 8 + key_written + elem_written, nil
|
|
||||||
|
|
||||||
case reflect.Pointer:
|
|
||||||
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Pointer)))
|
|
||||||
|
|
||||||
elem_written, err := TypeStack(ctx, t.Elem(), data[8:])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return 8 + elem_written, nil
|
|
||||||
case reflect.Slice:
|
|
||||||
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Slice)))
|
|
||||||
|
|
||||||
elem_written, err := TypeStack(ctx, t.Elem(), data[8:])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return 8 + elem_written, nil
|
|
||||||
case reflect.Array:
|
|
||||||
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Array)))
|
|
||||||
binary.BigEndian.PutUint64(data[8:], uint64(t.Len()))
|
|
||||||
|
|
||||||
elem_written, err := TypeStack(ctx, t.Elem(), data[16:])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return 16 + elem_written, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("Hit %s, which is not a registered type", t.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func UnwrapStack(ctx *Context, stack []byte) (reflect.Type, []byte, error) {
|
|
||||||
first_bytes, left := split(stack, 8)
|
|
||||||
first := SerializedType(binary.BigEndian.Uint64(first_bytes))
|
|
||||||
|
|
||||||
info, registered := ctx.TypesReverse[first]
|
|
||||||
if registered {
|
|
||||||
return info.Reflect, left, nil
|
|
||||||
} else {
|
|
||||||
switch first {
|
|
||||||
case SerializeType(reflect.Map):
|
|
||||||
key_type, after_key, err := UnwrapStack(ctx, left)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
elem_type, after_elem, err := UnwrapStack(ctx, after_key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return reflect.MapOf(key_type, elem_type), after_elem, nil
|
|
||||||
case SerializeType(reflect.Pointer):
|
|
||||||
elem_type, rest, err := UnwrapStack(ctx, left)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
return reflect.PointerTo(elem_type), rest, nil
|
|
||||||
case SerializeType(reflect.Slice):
|
|
||||||
elem_type, rest, err := UnwrapStack(ctx, left)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
return reflect.SliceOf(elem_type), rest, nil
|
|
||||||
case SerializeType(reflect.Array):
|
|
||||||
length_bytes, left := split(left, 8)
|
|
||||||
length := int(binary.BigEndian.Uint64(length_bytes))
|
|
||||||
|
|
||||||
elem_type, rest, err := UnwrapStack(ctx, left)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return reflect.ArrayOf(length, elem_type), rest, nil
|
|
||||||
default:
|
|
||||||
return nil, nil, fmt.Errorf("Type stack %+v not recognized", stack)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Serialize[T any](ctx *Context, value T, data []byte) (int, error) {
|
|
||||||
return SerializeValue(ctx, reflect.ValueOf(&value).Elem(), data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Deserialize[T any](ctx *Context, data []byte) (T, error) {
|
|
||||||
reflect_type := reflect.TypeFor[T]()
|
|
||||||
var zero T
|
|
||||||
value, left, err := DeserializeValue(ctx, data, reflect_type)
|
|
||||||
if err != nil {
|
|
||||||
return zero, err
|
|
||||||
} else if len(left) != 0 {
|
|
||||||
return zero, fmt.Errorf("%d/%d bytes left after deserializing %+v", len(left), len(data), value)
|
|
||||||
} else if value.Type() != reflect_type {
|
|
||||||
return zero, fmt.Errorf("Deserialized type %s does not match %s", value.Type(), reflect_type)
|
|
||||||
}
|
|
||||||
|
|
||||||
return value.Interface().(T), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func SerializedSize(ctx *Context, value reflect.Value) (int, error) {
|
|
||||||
var sizefn SerializedSizeFn = nil
|
|
||||||
|
|
||||||
info, registered := ctx.Types[value.Type()]
|
|
||||||
if registered {
|
|
||||||
sizefn = info.SerializedSize
|
|
||||||
}
|
|
||||||
|
|
||||||
if sizefn == nil {
|
|
||||||
switch value.Type().Kind() {
|
|
||||||
case reflect.Bool:
|
|
||||||
return 1, nil
|
|
||||||
|
|
||||||
case reflect.Int8:
|
|
||||||
return 1, nil
|
|
||||||
case reflect.Int16:
|
|
||||||
return 2, nil
|
|
||||||
case reflect.Int32:
|
|
||||||
return 4, nil
|
|
||||||
case reflect.Int64:
|
|
||||||
fallthrough
|
|
||||||
case reflect.Int:
|
|
||||||
return 8, nil
|
|
||||||
|
|
||||||
case reflect.Uint8:
|
|
||||||
return 1, nil
|
|
||||||
case reflect.Uint16:
|
|
||||||
return 2, nil
|
|
||||||
case reflect.Uint32:
|
|
||||||
return 4, nil
|
|
||||||
case reflect.Uint64:
|
|
||||||
fallthrough
|
|
||||||
case reflect.Uint:
|
|
||||||
return 8, nil
|
|
||||||
|
|
||||||
case reflect.Float32:
|
|
||||||
return 4, nil
|
|
||||||
case reflect.Float64:
|
|
||||||
return 8, nil
|
|
||||||
|
|
||||||
case reflect.String:
|
|
||||||
return 8 + value.Len(), nil
|
|
||||||
|
|
||||||
case reflect.Pointer:
|
|
||||||
if value.IsNil() {
|
|
||||||
return 1, nil
|
|
||||||
} else {
|
|
||||||
elem_len, err := SerializedSize(ctx, value.Elem())
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
} else {
|
|
||||||
return 1 + elem_len, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Slice:
|
|
||||||
if value.IsNil() {
|
|
||||||
return 1, nil
|
|
||||||
} else {
|
|
||||||
elem_total := 0
|
|
||||||
for i := 0; i < value.Len(); i++ {
|
|
||||||
elem_len, err := SerializedSize(ctx, value.Index(i))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
elem_total += elem_len
|
|
||||||
}
|
|
||||||
return 9 + elem_total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Array:
|
|
||||||
total := 0
|
|
||||||
for i := 0; i < value.Len(); i++ {
|
|
||||||
elem_len, err := SerializedSize(ctx, value.Index(i))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
total += elem_len
|
|
||||||
}
|
|
||||||
return total, nil
|
|
||||||
|
|
||||||
case reflect.Map:
|
|
||||||
if value.IsNil() {
|
|
||||||
return 1, nil
|
|
||||||
} else {
|
|
||||||
key := reflect.New(value.Type().Key()).Elem()
|
|
||||||
val := reflect.New(value.Type().Elem()).Elem()
|
|
||||||
iter := value.MapRange()
|
|
||||||
|
|
||||||
total := 0
|
|
||||||
for iter.Next() {
|
|
||||||
key.SetIterKey(iter)
|
|
||||||
k, err := SerializedSize(ctx, key)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
total += k
|
|
||||||
|
|
||||||
val.SetIterValue(iter)
|
|
||||||
v, err := SerializedSize(ctx, val)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
total += v
|
|
||||||
}
|
|
||||||
|
|
||||||
return 9 + total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Struct:
|
|
||||||
if registered == false {
|
|
||||||
return 0, fmt.Errorf("Can't serialize unregistered struct %s", value.Type())
|
|
||||||
} else {
|
|
||||||
field_total := 0
|
|
||||||
for _, field_info := range(info.Fields) {
|
|
||||||
field_size, err := SerializedSize(ctx, value.FieldByIndex(field_info.Index))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
field_total += 8
|
|
||||||
field_total += field_size
|
|
||||||
}
|
|
||||||
|
|
||||||
return 8 + field_total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Interface:
|
|
||||||
// TODO get size of TypeStack instead of just using 128
|
|
||||||
elem_size, err := SerializedSize(ctx, value.Elem())
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return 128 + elem_size, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("Don't know how to serialize %s", value.Type())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return sizefn(ctx, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func SerializeValue(ctx *Context, value reflect.Value, data []byte) (int, error) {
|
|
||||||
var serialize SerializeFn = nil
|
|
||||||
|
|
||||||
info, registered := ctx.Types[value.Type()]
|
|
||||||
if registered {
|
|
||||||
serialize = info.Serialize
|
|
||||||
}
|
|
||||||
|
|
||||||
if serialize == nil {
|
|
||||||
switch value.Type().Kind() {
|
|
||||||
case reflect.Bool:
|
|
||||||
if value.Bool() {
|
|
||||||
data[0] = 0xFF
|
|
||||||
} else {
|
|
||||||
data[0] = 0x00
|
|
||||||
}
|
|
||||||
return 1, nil
|
|
||||||
|
|
||||||
case reflect.Int8:
|
|
||||||
data[0] = byte(value.Int())
|
|
||||||
return 1, nil
|
|
||||||
case reflect.Int16:
|
|
||||||
binary.BigEndian.PutUint16(data, uint16(value.Int()))
|
|
||||||
return 2, nil
|
|
||||||
case reflect.Int32:
|
|
||||||
binary.BigEndian.PutUint32(data, uint32(value.Int()))
|
|
||||||
return 4, nil
|
|
||||||
case reflect.Int64:
|
|
||||||
fallthrough
|
|
||||||
case reflect.Int:
|
|
||||||
binary.BigEndian.PutUint64(data, uint64(value.Int()))
|
|
||||||
return 8, nil
|
|
||||||
|
|
||||||
case reflect.Uint8:
|
|
||||||
data[0] = byte(value.Uint())
|
|
||||||
return 1, nil
|
|
||||||
case reflect.Uint16:
|
|
||||||
binary.BigEndian.PutUint16(data, uint16(value.Uint()))
|
|
||||||
return 2, nil
|
|
||||||
case reflect.Uint32:
|
|
||||||
binary.BigEndian.PutUint32(data, uint32(value.Uint()))
|
|
||||||
return 4, nil
|
|
||||||
case reflect.Uint64:
|
|
||||||
fallthrough
|
|
||||||
case reflect.Uint:
|
|
||||||
binary.BigEndian.PutUint64(data, value.Uint())
|
|
||||||
return 8, nil
|
|
||||||
|
|
||||||
case reflect.Float32:
|
|
||||||
binary.BigEndian.PutUint32(data, math.Float32bits(float32(value.Float())))
|
|
||||||
return 4, nil
|
|
||||||
case reflect.Float64:
|
|
||||||
binary.BigEndian.PutUint64(data, math.Float64bits(value.Float()))
|
|
||||||
return 8, nil
|
|
||||||
|
|
||||||
case reflect.String:
|
|
||||||
binary.BigEndian.PutUint64(data, uint64(value.Len()))
|
|
||||||
copy(data[8:], []byte(value.String()))
|
|
||||||
return 8 + value.Len(), nil
|
|
||||||
|
|
||||||
case reflect.Pointer:
|
|
||||||
if value.IsNil() {
|
|
||||||
data[0] = 0x00
|
|
||||||
return 1, nil
|
|
||||||
} else {
|
|
||||||
data[0] = 0x01
|
|
||||||
written, err := SerializeValue(ctx, value.Elem(), data[1:])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return 1 + written, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Slice:
|
|
||||||
if value.IsNil() {
|
|
||||||
data[0] = 0x00
|
|
||||||
return 8, nil
|
|
||||||
} else {
|
|
||||||
data[0] = 0x01
|
|
||||||
binary.BigEndian.PutUint64(data[1:], uint64(value.Len()))
|
|
||||||
total_written := 0
|
|
||||||
for i := 0; i < value.Len(); i++ {
|
|
||||||
written, err := SerializeValue(ctx, value.Index(i), data[9+total_written:])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
total_written += written
|
|
||||||
}
|
|
||||||
return 9 + total_written, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Array:
|
|
||||||
total_written := 0
|
|
||||||
for i := 0; i < value.Len(); i++ {
|
|
||||||
written, err := SerializeValue(ctx, value.Index(i), data[total_written:])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
total_written += written
|
|
||||||
}
|
|
||||||
return total_written, nil
|
|
||||||
|
|
||||||
case reflect.Map:
|
|
||||||
if value.IsNil() {
|
|
||||||
data[0] = 0x00
|
|
||||||
return 1, nil
|
|
||||||
} else {
|
|
||||||
data[0] = 0x01
|
|
||||||
binary.BigEndian.PutUint64(data[1:], uint64(value.Len()))
|
|
||||||
|
|
||||||
key := reflect.New(value.Type().Key()).Elem()
|
|
||||||
val := reflect.New(value.Type().Elem()).Elem()
|
|
||||||
iter := value.MapRange()
|
|
||||||
total_written := 0
|
|
||||||
for iter.Next() {
|
|
||||||
key.SetIterKey(iter)
|
|
||||||
val.SetIterValue(iter)
|
|
||||||
|
|
||||||
k, err := SerializeValue(ctx, key, data[9+total_written:])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
total_written += k
|
|
||||||
|
|
||||||
v, err := SerializeValue(ctx, val, data[9+total_written:])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
total_written += v
|
|
||||||
}
|
|
||||||
return 9 + total_written, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Struct:
|
|
||||||
if registered == false {
|
|
||||||
return 0, fmt.Errorf("Cannot serialize unregistered struct %s", value.Type())
|
|
||||||
} else {
|
|
||||||
binary.BigEndian.PutUint64(data, uint64(len(info.Fields)))
|
|
||||||
|
|
||||||
total_written := 0
|
|
||||||
for field_tag, field_info := range(info.Fields) {
|
|
||||||
binary.BigEndian.PutUint64(data[8+total_written:], uint64(field_tag))
|
|
||||||
total_written += 8
|
|
||||||
written, err := SerializeValue(ctx, value.FieldByIndex(field_info.Index), data[8+total_written:])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
total_written += written
|
|
||||||
}
|
|
||||||
return 8 + total_written, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Interface:
|
|
||||||
type_written, err := TypeStack(ctx, value.Elem().Type(), data)
|
|
||||||
|
|
||||||
elem_written, err := SerializeValue(ctx, value.Elem(), data[type_written:])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return type_written + elem_written, nil
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("Don't know how to serialize %s", value.Type())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return serialize(ctx, value, data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func split(data []byte, n int) ([]byte, []byte) {
|
|
||||||
return data[:n], data[n:]
|
|
||||||
}
|
|
||||||
|
|
||||||
func DeserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, []byte, error) {
|
|
||||||
var deserialize DeserializeFn = nil
|
|
||||||
|
|
||||||
info, registered := ctx.Types[t]
|
|
||||||
if registered {
|
|
||||||
deserialize = info.Deserialize
|
|
||||||
}
|
|
||||||
|
|
||||||
if deserialize == nil {
|
|
||||||
switch t.Kind() {
|
|
||||||
case reflect.Bool:
|
|
||||||
used, left := split(data, 1)
|
|
||||||
value := reflect.New(t).Elem()
|
|
||||||
value.SetBool(used[0] != 0x00)
|
|
||||||
return value, left, nil
|
|
||||||
|
|
||||||
case reflect.Int8:
|
|
||||||
used, left := split(data, 1)
|
|
||||||
value := reflect.New(t).Elem()
|
|
||||||
value.SetInt(int64(used[0]))
|
|
||||||
return value, left, nil
|
|
||||||
case reflect.Int16:
|
|
||||||
used, left := split(data, 2)
|
|
||||||
value := reflect.New(t).Elem()
|
|
||||||
value.SetInt(int64(binary.BigEndian.Uint16(used)))
|
|
||||||
return value, left, nil
|
|
||||||
case reflect.Int32:
|
|
||||||
used, left := split(data, 4)
|
|
||||||
value := reflect.New(t).Elem()
|
|
||||||
value.SetInt(int64(binary.BigEndian.Uint32(used)))
|
|
||||||
return value, left, nil
|
|
||||||
case reflect.Int64:
|
|
||||||
fallthrough
|
|
||||||
case reflect.Int:
|
|
||||||
used, left := split(data, 8)
|
|
||||||
value := reflect.New(t).Elem()
|
|
||||||
value.SetInt(int64(binary.BigEndian.Uint64(used)))
|
|
||||||
return value, left, nil
|
|
||||||
|
|
||||||
case reflect.Uint8:
|
|
||||||
used, left := split(data, 1)
|
|
||||||
value := reflect.New(t).Elem()
|
|
||||||
value.SetUint(uint64(used[0]))
|
|
||||||
return value, left, nil
|
|
||||||
case reflect.Uint16:
|
|
||||||
used, left := split(data, 2)
|
|
||||||
value := reflect.New(t).Elem()
|
|
||||||
value.SetUint(uint64(binary.BigEndian.Uint16(used)))
|
|
||||||
return value, left, nil
|
|
||||||
case reflect.Uint32:
|
|
||||||
used, left := split(data, 4)
|
|
||||||
value := reflect.New(t).Elem()
|
|
||||||
value.SetUint(uint64(binary.BigEndian.Uint32(used)))
|
|
||||||
return value, left, nil
|
|
||||||
case reflect.Uint64:
|
|
||||||
fallthrough
|
|
||||||
case reflect.Uint:
|
|
||||||
used, left := split(data, 8)
|
|
||||||
value := reflect.New(t).Elem()
|
|
||||||
value.SetUint(binary.BigEndian.Uint64(used))
|
|
||||||
return value, left, nil
|
|
||||||
|
|
||||||
case reflect.Float32:
|
|
||||||
used, left := split(data, 4)
|
|
||||||
value := reflect.New(t).Elem()
|
|
||||||
value.SetFloat(float64(math.Float32frombits(binary.BigEndian.Uint32(used))))
|
|
||||||
return value, left, nil
|
|
||||||
case reflect.Float64:
|
|
||||||
used, left := split(data, 8)
|
|
||||||
value := reflect.New(t).Elem()
|
|
||||||
value.SetFloat(math.Float64frombits(binary.BigEndian.Uint64(used)))
|
|
||||||
return value, left, nil
|
|
||||||
|
|
||||||
case reflect.String:
|
|
||||||
length, after_len := split(data, 8)
|
|
||||||
used, left := split(after_len, int(binary.BigEndian.Uint64(length)))
|
|
||||||
value := reflect.New(t).Elem()
|
|
||||||
value.SetString(string(used))
|
|
||||||
return value, left, nil
|
|
||||||
|
|
||||||
case reflect.Pointer:
|
|
||||||
flags, after_flags := split(data, 1)
|
|
||||||
value := reflect.New(t).Elem()
|
|
||||||
if flags[0] == 0x00 {
|
|
||||||
value.SetZero()
|
|
||||||
return value, after_flags, nil
|
|
||||||
} else {
|
|
||||||
elem_value, after_elem, err := DeserializeValue(ctx, after_flags, t.Elem())
|
|
||||||
if err != nil {
|
|
||||||
return reflect.Value{}, nil, err
|
|
||||||
}
|
|
||||||
value.Set(elem_value.Addr())
|
|
||||||
return value, after_elem, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Slice:
|
|
||||||
nil_byte := data[0]
|
|
||||||
data = data[1:]
|
|
||||||
if nil_byte == 0x00 {
|
|
||||||
return reflect.New(t).Elem(), data, nil
|
|
||||||
} else {
|
|
||||||
len_bytes, left := split(data, 8)
|
|
||||||
length := int(binary.BigEndian.Uint64(len_bytes))
|
|
||||||
value := reflect.MakeSlice(t, length, length)
|
|
||||||
for i := 0; i < length; i++ {
|
|
||||||
var elem_value reflect.Value
|
|
||||||
var err error
|
|
||||||
elem_value, left, err = DeserializeValue(ctx, left, t.Elem())
|
|
||||||
if err != nil {
|
|
||||||
return reflect.Value{}, nil, err
|
|
||||||
}
|
|
||||||
value.Index(i).Set(elem_value)
|
|
||||||
}
|
|
||||||
return value, left, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Array:
|
|
||||||
value := reflect.New(t).Elem()
|
|
||||||
left := data
|
|
||||||
for i := 0; i < t.Len(); i++ {
|
|
||||||
var elem_value reflect.Value
|
|
||||||
var err error
|
|
||||||
elem_value, left, err = DeserializeValue(ctx, left, t.Elem())
|
|
||||||
if err != nil {
|
|
||||||
return reflect.Value{}, nil, err
|
|
||||||
}
|
|
||||||
value.Index(i).Set(elem_value)
|
|
||||||
}
|
|
||||||
return value, left, nil
|
|
||||||
|
|
||||||
case reflect.Map:
|
|
||||||
flags, after_flags := split(data, 1)
|
|
||||||
if flags[0] == 0x00 {
|
|
||||||
return reflect.New(t).Elem(), after_flags, nil
|
|
||||||
} else {
|
|
||||||
len_bytes, left := split(after_flags, 8)
|
|
||||||
length := int(binary.BigEndian.Uint64(len_bytes))
|
|
||||||
|
|
||||||
value := reflect.MakeMapWithSize(t, length)
|
|
||||||
|
|
||||||
for i := 0; i < length; i++ {
|
|
||||||
var key_value reflect.Value
|
|
||||||
var val_value reflect.Value
|
|
||||||
var err error
|
|
||||||
|
|
||||||
key_value, left, err = DeserializeValue(ctx, left, t.Key())
|
|
||||||
if err != nil {
|
|
||||||
return reflect.Value{}, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
val_value, left, err = DeserializeValue(ctx, left, t.Elem())
|
|
||||||
if err != nil {
|
|
||||||
return reflect.Value{}, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
value.SetMapIndex(key_value, val_value)
|
|
||||||
}
|
|
||||||
|
|
||||||
return value, left, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Struct:
|
|
||||||
info, mapped := ctx.Types[t]
|
|
||||||
if mapped {
|
|
||||||
value := reflect.New(t).Elem()
|
|
||||||
|
|
||||||
num_field_bytes, left := split(data, 8)
|
|
||||||
num_fields := int(binary.BigEndian.Uint64(num_field_bytes))
|
|
||||||
|
|
||||||
for i := 0; i < num_fields; i++ {
|
|
||||||
var tag_bytes []byte
|
|
||||||
|
|
||||||
tag_bytes, left = split(left, 8)
|
|
||||||
field_tag := FieldTag(binary.BigEndian.Uint64(tag_bytes))
|
|
||||||
|
|
||||||
field_info, mapped := info.Fields[field_tag]
|
|
||||||
if mapped {
|
|
||||||
var field_val reflect.Value
|
|
||||||
var err error
|
|
||||||
field_val, left, err = DeserializeValue(ctx, left, field_info.Type)
|
|
||||||
if err != nil {
|
|
||||||
return reflect.Value{}, nil, err
|
|
||||||
}
|
|
||||||
value.FieldByIndex(field_info.Index).Set(field_val)
|
|
||||||
} else {
|
|
||||||
return reflect.Value{}, nil, fmt.Errorf("Unknown field %s on struct %s", field_tag, t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if info.PostDeserializeIndex != -1 {
|
|
||||||
post_deserialize_method := value.Addr().Method(info.PostDeserializeIndex)
|
|
||||||
post_deserialize_method.Call([]reflect.Value{reflect.ValueOf(ctx)})
|
|
||||||
}
|
|
||||||
return value, left, nil
|
|
||||||
} else {
|
|
||||||
return reflect.Value{}, nil, fmt.Errorf("Cannot deserialize unregistered struct %s", t)
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Interface:
|
|
||||||
elem_type, rest, err := UnwrapStack(ctx, data)
|
|
||||||
if err != nil {
|
|
||||||
return reflect.Value{}, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
elem_val, left, err := DeserializeValue(ctx, rest, elem_type)
|
|
||||||
if err != nil {
|
|
||||||
return reflect.Value{}, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
val := reflect.New(t).Elem()
|
|
||||||
val.Set(elem_val)
|
|
||||||
|
|
||||||
return val, left, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
return reflect.Value{}, nil, fmt.Errorf("Don't know how to deserialize %s", t)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return deserialize(ctx, data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,176 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"reflect"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
func testTypeStack[T any](t *testing.T, ctx *Context) {
|
|
||||||
buffer := [1024]byte{}
|
|
||||||
reflect_type := reflect.TypeFor[T]()
|
|
||||||
written, err := TypeStack(ctx, reflect_type, buffer[:])
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
stack := buffer[:written]
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "TypeStack[%s]: %+v", reflect_type, stack)
|
|
||||||
|
|
||||||
unwrapped_type, rest, err := UnwrapStack(ctx, stack)
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
if len(rest) != 0 {
|
|
||||||
t.Errorf("Types remaining after unwrapping %s: %+v", unwrapped_type, stack)
|
|
||||||
}
|
|
||||||
|
|
||||||
if unwrapped_type != reflect_type {
|
|
||||||
t.Errorf("Unwrapped type[%+v] doesn't match original[%+v]", unwrapped_type, reflect_type)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "Unwrapped type[%s]: %s", reflect_type, reflect_type)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSerializeTypes(t *testing.T) {
|
|
||||||
ctx := logTestContext(t, []string{"test"})
|
|
||||||
|
|
||||||
testTypeStack[int](t, ctx)
|
|
||||||
testTypeStack[map[int]string](t, ctx)
|
|
||||||
testTypeStack[string](t, ctx)
|
|
||||||
testTypeStack[*string](t, ctx)
|
|
||||||
testTypeStack[*map[string]*map[*string]int](t, ctx)
|
|
||||||
testTypeStack[[5]int](t, ctx)
|
|
||||||
testTypeStack[uuid.UUID](t, ctx)
|
|
||||||
testTypeStack[NodeID](t, ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testSerializeCompare[T comparable](t *testing.T, ctx *Context, value T) {
|
|
||||||
buffer := [1024]byte{}
|
|
||||||
written, err := Serialize(ctx, value, buffer[:])
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
serialized := buffer[:written]
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
|
|
||||||
|
|
||||||
deserialized, err := Deserialize[T](ctx, serialized)
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
if value != deserialized {
|
|
||||||
t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testSerializeList[L []T, T comparable](t *testing.T, ctx *Context, value L) {
|
|
||||||
buffer := [1024]byte{}
|
|
||||||
written, err := Serialize(ctx, value, buffer[:])
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
serialized := buffer[:written]
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[L](), value, serialized)
|
|
||||||
|
|
||||||
deserialized, err := Deserialize[L](ctx, serialized)
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
for i, item := range(value) {
|
|
||||||
if item != deserialized[i] {
|
|
||||||
t.Fatalf("Deserialized list %+v does not match original %+v", value, deserialized)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testSerializePointer[P interface {*T}, T comparable](t *testing.T, ctx *Context, value P) {
|
|
||||||
buffer := [1024]byte{}
|
|
||||||
|
|
||||||
written, err := Serialize(ctx, value, buffer[:])
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
serialized := buffer[:written]
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[P](), value, serialized)
|
|
||||||
|
|
||||||
deserialized, err := Deserialize[P](ctx, serialized)
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
if value == nil && deserialized == nil {
|
|
||||||
ctx.Log.Logf("test", "Deserialized nil")
|
|
||||||
} else if value == nil {
|
|
||||||
t.Fatalf("Non-nil value[%+v] returned for nil value", deserialized)
|
|
||||||
} else if deserialized == nil {
|
|
||||||
t.Fatalf("Nil value returned for non-nil value[%+v]", value)
|
|
||||||
} else if *deserialized != *value {
|
|
||||||
t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized)
|
|
||||||
} else {
|
|
||||||
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", *value, *deserialized)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testSerialize[T any](t *testing.T, ctx *Context, value T) {
|
|
||||||
buffer := [1024]byte{}
|
|
||||||
written, err := Serialize(ctx, value, buffer[:])
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
serialized := buffer[:written]
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
|
|
||||||
|
|
||||||
deserialized, err := Deserialize[T](ctx, serialized)
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSerializeValues(t *testing.T) {
|
|
||||||
ctx := logTestContext(t, []string{"test"})
|
|
||||||
|
|
||||||
testSerialize(t, ctx, Extension(NewLockableExt(nil)))
|
|
||||||
|
|
||||||
testSerializeCompare[int8](t, ctx, -64)
|
|
||||||
testSerializeCompare[int16](t, ctx, -64)
|
|
||||||
testSerializeCompare[int32](t, ctx, -64)
|
|
||||||
testSerializeCompare[int64](t, ctx, -64)
|
|
||||||
testSerializeCompare[int](t, ctx, -64)
|
|
||||||
|
|
||||||
testSerializeCompare[uint8](t, ctx, 64)
|
|
||||||
testSerializeCompare[uint16](t, ctx, 64)
|
|
||||||
testSerializeCompare[uint32](t, ctx, 64)
|
|
||||||
testSerializeCompare[uint64](t, ctx, 64)
|
|
||||||
testSerializeCompare[uint](t, ctx, 64)
|
|
||||||
|
|
||||||
testSerializeCompare[string](t, ctx, "test")
|
|
||||||
|
|
||||||
a := 12
|
|
||||||
testSerializePointer[*int](t, ctx, &a)
|
|
||||||
|
|
||||||
b := "test"
|
|
||||||
testSerializePointer[*string](t, ctx, nil)
|
|
||||||
testSerializePointer[*string](t, ctx, &b)
|
|
||||||
|
|
||||||
testSerializeList(t, ctx, []int{1, 2, 3, 4, 5})
|
|
||||||
|
|
||||||
testSerializeCompare[bool](t, ctx, true)
|
|
||||||
testSerializeCompare[bool](t, ctx, false)
|
|
||||||
testSerializeCompare[int](t, ctx, -1)
|
|
||||||
testSerializeCompare[uint](t, ctx, 1)
|
|
||||||
testSerializeCompare[NodeID](t, ctx, RandID())
|
|
||||||
testSerializeCompare[*int](t, ctx, nil)
|
|
||||||
testSerializeCompare(t, ctx, "string")
|
|
||||||
|
|
||||||
testSerialize(t, ctx, map[string]string{
|
|
||||||
"Test": "Test",
|
|
||||||
"key": "String",
|
|
||||||
"": "",
|
|
||||||
})
|
|
||||||
|
|
||||||
testSerialize[map[string]string](t, ctx, nil)
|
|
||||||
|
|
||||||
testSerialize(t, ctx, NewListenerExt(10))
|
|
||||||
|
|
||||||
node, err := ctx.NewNode(nil, "Node")
|
|
||||||
fatalErr(t, err)
|
|
||||||
testSerialize(t, ctx, node)
|
|
||||||
}
|
|
@ -0,0 +1,736 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
"sync"
|
||||||
|
"errors"
|
||||||
|
"encoding/json"
|
||||||
|
"crypto/sha512"
|
||||||
|
"encoding/binary"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ThreadAction func(*Context, *Node, *ThreadExt)(string, error)
|
||||||
|
type ThreadActions map[string]ThreadAction
|
||||||
|
type ThreadHandler func(*Context, *Node, *ThreadExt, Signal)(string, error)
|
||||||
|
type ThreadHandlers map[SignalType]ThreadHandler
|
||||||
|
|
||||||
|
type InfoType string
|
||||||
|
func (t InfoType) String() string {
|
||||||
|
return string(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Info interface {
|
||||||
|
Serializable[InfoType]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Data required by a parent thread to restore it's children
|
||||||
|
type ParentInfo struct {
|
||||||
|
Start bool `json:"start"`
|
||||||
|
StartAction string `json:"start_action"`
|
||||||
|
RestoreAction string `json:"restore_action"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const ParentInfoType = InfoType("PARENT")
|
||||||
|
func (info *ParentInfo) Type() InfoType {
|
||||||
|
return ParentInfoType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (info *ParentInfo) Serialize() ([]byte, error) {
|
||||||
|
return json.MarshalIndent(info, "", " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueuedAction struct {
|
||||||
|
Timeout time.Time `json:"time"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ThreadType string
|
||||||
|
func (thread ThreadType) Hash() uint64 {
|
||||||
|
hash := sha512.Sum512([]byte(fmt.Sprintf("THREAD: %s", string(thread))))
|
||||||
|
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
|
||||||
|
}
|
||||||
|
|
||||||
|
type ThreadInfo struct {
|
||||||
|
Actions ThreadActions
|
||||||
|
Handlers ThreadHandlers
|
||||||
|
}
|
||||||
|
|
||||||
|
type InfoLoadFunc func([]byte)(Info, error)
|
||||||
|
type ThreadExtContext struct {
|
||||||
|
Types map[ThreadType]ThreadInfo
|
||||||
|
Loads map[InfoType]InfoLoadFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
const BaseThreadType = ThreadType("BASE")
|
||||||
|
func NewThreadExtContext() *ThreadExtContext {
|
||||||
|
return &ThreadExtContext{
|
||||||
|
Types: map[ThreadType]ThreadInfo{
|
||||||
|
BaseThreadType: ThreadInfo{
|
||||||
|
Actions: BaseThreadActions,
|
||||||
|
Handlers: BaseThreadHandlers,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Loads: map[InfoType]InfoLoadFunc{
|
||||||
|
ParentInfoType: func(data []byte) (Info, error) {
|
||||||
|
var info ParentInfo
|
||||||
|
err := json.Unmarshal(data, &info)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &info, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx *ThreadExtContext) RegisterThreadType(thread_type ThreadType, actions ThreadActions, handlers ThreadHandlers) error {
|
||||||
|
if actions == nil || handlers == nil {
|
||||||
|
return fmt.Errorf("Cannot register ThreadType %s with nil actions or handlers", thread_type)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, exists := ctx.Types[thread_type]
|
||||||
|
if exists == true {
|
||||||
|
return fmt.Errorf("ThreadType %s already registered in ThreadExtContext, cannot register again", thread_type)
|
||||||
|
}
|
||||||
|
ctx.Types[thread_type] = ThreadInfo{
|
||||||
|
Actions: actions,
|
||||||
|
Handlers: handlers,
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx *ThreadExtContext) RegisterInfoType(info_type InfoType, load_fn InfoLoadFunc) error {
|
||||||
|
if load_fn == nil {
|
||||||
|
return fmt.Errorf("Cannot register %s with nil load_fn", info_type)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, exists := ctx.Loads[info_type]
|
||||||
|
if exists == true {
|
||||||
|
return fmt.Errorf("InfoType %s is already registered in ThreadExtContext, cannot register again", info_type)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Loads[info_type] = load_fn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ThreadExt struct {
|
||||||
|
Actions ThreadActions
|
||||||
|
Handlers ThreadHandlers
|
||||||
|
|
||||||
|
ThreadType ThreadType
|
||||||
|
|
||||||
|
SignalChan chan Signal
|
||||||
|
TimeoutChan <-chan time.Time
|
||||||
|
|
||||||
|
ChildWaits sync.WaitGroup
|
||||||
|
|
||||||
|
ActiveLock sync.Mutex
|
||||||
|
Active bool
|
||||||
|
State string
|
||||||
|
|
||||||
|
Parent *Node
|
||||||
|
Children map[NodeID]ChildInfo
|
||||||
|
|
||||||
|
ActionQueue []QueuedAction
|
||||||
|
NextAction *QueuedAction
|
||||||
|
}
|
||||||
|
|
||||||
|
type ThreadExtJSON struct {
|
||||||
|
State string `json:"state"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Parent string `json:"parent"`
|
||||||
|
Children map[string]map[string][]byte `json:"children"`
|
||||||
|
ActionQueue []QueuedAction
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ThreadExt) Serialize() ([]byte, error) {
|
||||||
|
children := map[string]map[string][]byte{}
|
||||||
|
for id, child := range(ext.Children) {
|
||||||
|
id_str := id.String()
|
||||||
|
children[id_str] = map[string][]byte{}
|
||||||
|
for info_type, info := range(child.Infos) {
|
||||||
|
var err error
|
||||||
|
children[id_str][string(info_type)], err = info.Serialize()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.MarshalIndent(&ThreadExtJSON{
|
||||||
|
State: ext.State,
|
||||||
|
Type: string(ext.ThreadType),
|
||||||
|
Parent: SaveNode(ext.Parent),
|
||||||
|
Children: children,
|
||||||
|
ActionQueue: ext.ActionQueue,
|
||||||
|
}, "", " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewThreadExt(ctx*Context, thread_type ThreadType, parent *Node, children map[NodeID]ChildInfo, state string, action_queue []QueuedAction) (*ThreadExt, error) {
|
||||||
|
if children == nil {
|
||||||
|
children = map[NodeID]ChildInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if action_queue == nil {
|
||||||
|
action_queue = []QueuedAction{}
|
||||||
|
}
|
||||||
|
|
||||||
|
thread_ctx, err := GetCtx[*ThreadExt, *ThreadExtContext](ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
type_info, exists := thread_ctx.Types[thread_type]
|
||||||
|
if exists == false {
|
||||||
|
return nil, fmt.Errorf("Tried to load thread type %s which is not in context", thread_type)
|
||||||
|
}
|
||||||
|
next_action, timeout_chan := SoonestAction(action_queue)
|
||||||
|
|
||||||
|
return &ThreadExt{
|
||||||
|
ThreadType: thread_type,
|
||||||
|
Actions: type_info.Actions,
|
||||||
|
Handlers: type_info.Handlers,
|
||||||
|
SignalChan: make(chan Signal, THREAD_BUFFER_SIZE),
|
||||||
|
TimeoutChan: timeout_chan,
|
||||||
|
Active: false,
|
||||||
|
State: state,
|
||||||
|
Parent: parent,
|
||||||
|
Children: children,
|
||||||
|
ActionQueue: action_queue,
|
||||||
|
NextAction: next_action,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const THREAD_BUFFER_SIZE int = 1024
|
||||||
|
func LoadThreadExt(ctx *Context, data []byte) (Extension, error) {
|
||||||
|
var j ThreadExtJSON
|
||||||
|
err := json.Unmarshal(data, &j)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ctx.Log.Logf("db", "DB_LOAD_THREAD_EXT_JSON: %+v", j)
|
||||||
|
|
||||||
|
parent, err := RestoreNode(ctx, j.Parent)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
thread_ctx, err := GetCtx[*ThreadExt, *ThreadExtContext](ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
children := map[NodeID]ChildInfo{}
|
||||||
|
for id_str, infos := range(j.Children) {
|
||||||
|
child_node, err := RestoreNode(ctx, id_str)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
child_infos := map[InfoType]Info{}
|
||||||
|
for info_type_str, info_data := range(infos) {
|
||||||
|
info_type := InfoType(info_type_str)
|
||||||
|
info_load, exists := thread_ctx.Loads[info_type]
|
||||||
|
if exists == false {
|
||||||
|
return nil, fmt.Errorf("%s is not a known InfoType in ThreacExrContxt", info_type)
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := info_load(info_data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
child_infos[info_type] = info
|
||||||
|
}
|
||||||
|
|
||||||
|
children[child_node.ID] = ChildInfo{
|
||||||
|
Child: child_node,
|
||||||
|
Infos: child_infos,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewThreadExt(ctx, ThreadType(j.Type), parent, children, j.State, j.ActionQueue)
|
||||||
|
}
|
||||||
|
|
||||||
|
const ThreadExtType = ExtType("THREAD")
|
||||||
|
func (ext *ThreadExt) Type() ExtType {
|
||||||
|
return ThreadExtType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ThreadExt) QueueAction(end time.Time, action string) {
|
||||||
|
ext.ActionQueue = append(ext.ActionQueue, QueuedAction{end, action})
|
||||||
|
ext.NextAction, ext.TimeoutChan = SoonestAction(ext.ActionQueue)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ThreadExt) ClearActionQueue() {
|
||||||
|
ext.ActionQueue = []QueuedAction{}
|
||||||
|
ext.NextAction = nil
|
||||||
|
ext.TimeoutChan = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SoonestAction(actions []QueuedAction) (*QueuedAction, <-chan time.Time) {
|
||||||
|
var soonest_action *QueuedAction
|
||||||
|
var soonest_time time.Time
|
||||||
|
for _, action := range(actions) {
|
||||||
|
if action.Timeout.Compare(soonest_time) == -1 || soonest_action == nil {
|
||||||
|
soonest_action = &action
|
||||||
|
soonest_time = action.Timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if soonest_action != nil {
|
||||||
|
return soonest_action, time.After(time.Until(soonest_action.Timeout))
|
||||||
|
} else {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ThreadExt) ChildList() []*Node {
|
||||||
|
ret := make([]*Node, len(ext.Children))
|
||||||
|
i := 0
|
||||||
|
for _, info := range(ext.Children) {
|
||||||
|
ret[i] = info.Child
|
||||||
|
i += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assumed that thread is already locked for signal
|
||||||
|
func (ext *ThreadExt) Process(context *StateContext, node *Node, signal Signal) error {
|
||||||
|
context.Graph.Log.Logf("signal", "THREAD_PROCESS: %s", node.ID)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
switch signal.Direction() {
|
||||||
|
case Up:
|
||||||
|
err = UseStates(context, node, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error {
|
||||||
|
if ext.Parent != nil {
|
||||||
|
if ext.Parent.ID != node.ID {
|
||||||
|
return SendSignal(context, ext.Parent, node, signal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
case Down:
|
||||||
|
err = UseStates(context, node, NewACLInfo(node, []string{"children"}), func(context *StateContext) error {
|
||||||
|
for _, info := range(ext.Children) {
|
||||||
|
err := SendSignal(context, info.Child, node, signal)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
case Direct:
|
||||||
|
err = nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("Invalid signal direction %d", signal.Direction())
|
||||||
|
}
|
||||||
|
ext.SignalChan <- signal
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnlinkThreads(context *StateContext, principal *Node, thread *Node, child *Node) error {
|
||||||
|
thread_ext, err := GetExt[*ThreadExt](thread)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
child_ext, err := GetExt[*ThreadExt](child)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return UpdateStates(context, principal, ACLMap{
|
||||||
|
thread.ID: ACLInfo{thread, []string{"children"}},
|
||||||
|
child.ID: ACLInfo{child, []string{"parent"}},
|
||||||
|
}, func(context *StateContext) error {
|
||||||
|
_, is_child := thread_ext.Children[child.ID]
|
||||||
|
if is_child == false {
|
||||||
|
return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID, thread.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(thread_ext.Children, child.ID)
|
||||||
|
child_ext.Parent = nil
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkIfChild(context *StateContext, id NodeID, cur *ThreadExt) (bool, error) {
|
||||||
|
for _, info := range(cur.Children) {
|
||||||
|
child := info.Child
|
||||||
|
if child.ID == id {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
child_ext, err := GetExt[*ThreadExt](child)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var is_child bool
|
||||||
|
err = UpdateStates(context, child, NewACLInfo(child, []string{"children"}), func(context *StateContext) error {
|
||||||
|
is_child, err = checkIfChild(context, id, child_ext)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if is_child {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Links child to parent with info as the associated info
|
||||||
|
// Continues the write context with princ, getting children for thread and parent for child
|
||||||
|
func LinkThreads(context *StateContext, principal *Node, thread *Node, info ChildInfo) error {
|
||||||
|
if context == nil || principal == nil || thread == nil || info.Child == nil {
|
||||||
|
return fmt.Errorf("invalid input")
|
||||||
|
}
|
||||||
|
|
||||||
|
child := info.Child
|
||||||
|
if thread.ID == child.ID {
|
||||||
|
return fmt.Errorf("Will not link %s as a child of itself", thread.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
thread_ext, err := GetExt[*ThreadExt](thread)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
child_ext, err := GetExt[*ThreadExt](child)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return UpdateStates(context, principal, ACLMap{
|
||||||
|
child.ID: ACLInfo{Node: child, Resources: []string{"parent"}},
|
||||||
|
thread.ID: ACLInfo{Node: thread, Resources: []string{"children"}},
|
||||||
|
}, func(context *StateContext) error {
|
||||||
|
if child_ext.Parent != nil {
|
||||||
|
return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
is_child, err := checkIfChild(context, thread.ID, child_ext)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else if is_child == true {
|
||||||
|
return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID, child.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
is_child, err = checkIfChild(context, child.ID, thread_ext)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
return err
|
||||||
|
} else if is_child == true {
|
||||||
|
return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID, child.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO check for info types
|
||||||
|
|
||||||
|
thread_ext.Children[child.ID] = info
|
||||||
|
child_ext.Parent = thread
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChildInfo struct {
|
||||||
|
Child *Node
|
||||||
|
Infos map[InfoType]Info
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChildInfo(child *Node, infos map[InfoType]Info) ChildInfo {
|
||||||
|
if infos == nil {
|
||||||
|
infos = map[InfoType]Info{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ChildInfo{
|
||||||
|
Child: child,
|
||||||
|
Infos: infos,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ThreadExt) SetActive(active bool) error {
|
||||||
|
ext.ActiveLock.Lock()
|
||||||
|
defer ext.ActiveLock.Unlock()
|
||||||
|
if ext.Active == true && active == true {
|
||||||
|
return fmt.Errorf("alreday active, cannot set active")
|
||||||
|
} else if ext.Active == false && active == false {
|
||||||
|
return fmt.Errorf("already inactive, canot set inactive")
|
||||||
|
}
|
||||||
|
ext.Active = active
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ThreadExt) SetState(state string) error {
|
||||||
|
ext.State = state
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Requires the read permission of threads children
|
||||||
|
func FindChild(context *StateContext, principal *Node, thread *Node, id NodeID) (*Node, error) {
|
||||||
|
if thread == nil {
|
||||||
|
panic("cannot recurse through nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if id == thread.ID {
|
||||||
|
return thread, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
thread_ext, err := GetExt[*ThreadExt](thread)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var found *Node = nil
|
||||||
|
err = UseStates(context, principal, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
|
||||||
|
for _, info := range(thread_ext.Children) {
|
||||||
|
found, err = FindChild(context, principal, info.Child, id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if found != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return found, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func ChildGo(ctx * Context, thread_ext *ThreadExt, child *Node, first_action string) {
|
||||||
|
thread_ext.ChildWaits.Add(1)
|
||||||
|
go func(child *Node) {
|
||||||
|
defer thread_ext.ChildWaits.Done()
|
||||||
|
err := ThreadLoop(ctx, child, first_action)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %s", child.ID, err)
|
||||||
|
} else {
|
||||||
|
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_DONE: %s", child.ID)
|
||||||
|
}
|
||||||
|
}(child)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main Loop for Threads, starts a write context, so cannot be called from a write or read context
|
||||||
|
func ThreadLoop(ctx * Context, thread *Node, first_action string) error {
|
||||||
|
thread_ext, err := GetExt[*ThreadExt](thread)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID, first_action)
|
||||||
|
|
||||||
|
err = thread_ext.SetActive(true)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Log.Logf("thread", "THREAD_LOOP_START_ERR: %e", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
next_action := first_action
|
||||||
|
for next_action != "" {
|
||||||
|
action, exists := thread_ext.Actions[next_action]
|
||||||
|
if exists == false {
|
||||||
|
error_str := fmt.Sprintf("%s is not a valid action", next_action)
|
||||||
|
return errors.New(error_str)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Log.Logf("thread", "THREAD_ACTION: %s - %s", thread.ID, next_action)
|
||||||
|
next_action, err = action(ctx, thread, thread_ext)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = thread_ext.SetActive(false)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Log.Logf("thread", "THREAD_LOOP_STOP_ERR: %e", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Log.Logf("thread", "THREAD_LOOP_DONE: %s", thread.ID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ThreadChildLinked(ctx *Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
|
||||||
|
ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %s - %+v", thread.ID, signal)
|
||||||
|
context := NewWriteContext(ctx)
|
||||||
|
err := UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
|
||||||
|
sig, ok := signal.(IDSignal)
|
||||||
|
if ok == false {
|
||||||
|
ctx.Log.Logf("thread", "THREAD_NODE_LINKED_BAD_CAST")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
info, exists := thread_ext.Children[sig.ID]
|
||||||
|
if exists == false {
|
||||||
|
ctx.Log.Logf("thread", "THREAD_NODE_LINKED: %s is not a child of %s", sig.ID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
parent_info, exists := info.Infos["parent"].(*ParentInfo)
|
||||||
|
if exists == false {
|
||||||
|
panic("ran ThreadChildLinked from a thread that doesn't require 'parent' child info. library party foul")
|
||||||
|
}
|
||||||
|
|
||||||
|
if parent_info.Start == true {
|
||||||
|
ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
} else {
|
||||||
|
|
||||||
|
}
|
||||||
|
return "wait", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to start a child from a thread during a signal handler
|
||||||
|
// Starts a write context, so cannot be called from either a write or read context
|
||||||
|
func ThreadStartChild(ctx *Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
|
||||||
|
sig, ok := signal.(StartChildSignal)
|
||||||
|
if ok == false {
|
||||||
|
return "wait", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
context := NewWriteContext(ctx)
|
||||||
|
return "wait", UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
|
||||||
|
info, exists:= thread_ext.Children[sig.ID]
|
||||||
|
if exists == false {
|
||||||
|
return fmt.Errorf("%s is not a child of %s", sig.ID, thread.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
parent_info, exists := info.Infos["parent"].(*ParentInfo)
|
||||||
|
if exists == false {
|
||||||
|
return fmt.Errorf("Called ThreadStartChild from a thread that doesn't require parent child info")
|
||||||
|
}
|
||||||
|
parent_info.Start = true
|
||||||
|
ChildGo(ctx, thread_ext, info.Child, sig.Action)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to restore threads that should be running from a parents restore action
|
||||||
|
// Starts a write context, so cannot be called from either a write or read context
|
||||||
|
func ThreadRestore(ctx * Context, thread *Node, thread_ext *ThreadExt, start bool) error {
|
||||||
|
context := NewWriteContext(ctx)
|
||||||
|
return UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
|
||||||
|
return UpdateStates(context, thread, ACLList(thread_ext.ChildList(), []string{"state"}), func(context *StateContext) error {
|
||||||
|
for _, info := range(thread_ext.Children) {
|
||||||
|
child_ext, err := GetExt[*ThreadExt](info.Child)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
parent_info := info.Infos[ParentInfoType].(*ParentInfo)
|
||||||
|
if parent_info.Start == true && child_ext.State != "finished" {
|
||||||
|
ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID, info.Child.ID)
|
||||||
|
if start == true {
|
||||||
|
ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction)
|
||||||
|
} else {
|
||||||
|
ChildGo(ctx, thread_ext, info.Child, parent_info.RestoreAction)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to be called during a threads start action, sets the thread state to started
|
||||||
|
// Starts a write context, so cannot be called from either a write or read context
|
||||||
|
// Returns "wait", nil on success, so the first return value can be ignored safely
|
||||||
|
func ThreadStart(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
|
||||||
|
context := NewWriteContext(ctx)
|
||||||
|
err := UpdateStates(context, thread, NewACLInfo(thread, []string{"state"}), func(context *StateContext) error {
|
||||||
|
err := LockLockables(context, map[NodeID]*Node{thread.ID: thread}, thread)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return thread_ext.SetState("started")
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
context = NewReadContext(ctx)
|
||||||
|
return "wait", SendSignal(context, thread, thread, NewStatusSignal("started", thread.ID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func ThreadWait(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
|
||||||
|
ctx.Log.Logf("thread", "THREAD_WAIT: %s - %+v", thread.ID, thread_ext.ActionQueue)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case signal := <- thread_ext.SignalChan:
|
||||||
|
ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID, signal)
|
||||||
|
signal_fn, exists := thread_ext.Handlers[signal.Type()]
|
||||||
|
if exists == true {
|
||||||
|
ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID, signal.Type())
|
||||||
|
return signal_fn(ctx, thread, thread_ext, signal)
|
||||||
|
} else {
|
||||||
|
ctx.Log.Logf("thread", "THREAD_NOHANDLER: %s - %s", thread.ID, signal.Type())
|
||||||
|
}
|
||||||
|
case <- thread_ext.TimeoutChan:
|
||||||
|
timeout_action := ""
|
||||||
|
context := NewWriteContext(ctx)
|
||||||
|
err := UpdateStates(context, thread, NewACLMap(NewACLInfo(thread, []string{"timeout"})), func(context *StateContext) error {
|
||||||
|
timeout_action = thread_ext.NextAction.Action
|
||||||
|
thread_ext.NextAction, thread_ext.TimeoutChan = SoonestAction(thread_ext.ActionQueue)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
ctx.Log.Logf("thread", "THREAD_TIMEOUT_ERR: %s - %e", thread.ID, err)
|
||||||
|
}
|
||||||
|
ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID, timeout_action)
|
||||||
|
return timeout_action, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ThreadFinish(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) {
|
||||||
|
context := NewWriteContext(ctx)
|
||||||
|
return "", UpdateStates(context, thread, NewACLInfo(thread, []string{"state"}), func(context *StateContext) error {
|
||||||
|
err := thread_ext.SetState("finished")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return UnlockLockables(context, map[NodeID]*Node{thread.ID: thread}, thread)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
var ThreadAbortedError = errors.New("Thread aborted by signal")
|
||||||
|
|
||||||
|
// Default thread action function for "abort", sends a signal and returns a ThreadAbortedError
|
||||||
|
func ThreadAbort(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
|
||||||
|
context := NewReadContext(ctx)
|
||||||
|
err := SendSignal(context, thread, thread, NewStatusSignal("aborted", thread.ID))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return "", ThreadAbortedError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default thread action for "stop", sends a signal and returns no error
|
||||||
|
func ThreadStop(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
|
||||||
|
context := NewReadContext(ctx)
|
||||||
|
err := SendSignal(context, thread, thread, NewStatusSignal("stopped", thread.ID))
|
||||||
|
return "finish", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default thread actions
|
||||||
|
var BaseThreadActions = ThreadActions{
|
||||||
|
"wait": ThreadWait,
|
||||||
|
"start": ThreadStart,
|
||||||
|
"finish": ThreadFinish,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default thread signal handlers
|
||||||
|
var BaseThreadHandlers = ThreadHandlers{
|
||||||
|
"abort": ThreadAbort,
|
||||||
|
"stop": ThreadStop,
|
||||||
|
}
|
@ -0,0 +1,120 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
"fmt"
|
||||||
|
"encoding/json"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/x509"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ECDHExt struct {
|
||||||
|
Granted time.Time
|
||||||
|
Pubkey *ecdsa.PublicKey
|
||||||
|
Shared []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type ECDHExtJSON struct {
|
||||||
|
Granted time.Time `json:"granted"`
|
||||||
|
Pubkey []byte `json:"pubkey"`
|
||||||
|
Shared []byte `json:"shared"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ECDHExt) Process(context *StateContext, node *Node, signal Signal) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const ECDHExtType = ExtType("ECDH")
|
||||||
|
func (ext *ECDHExt) Type() ExtType {
|
||||||
|
return ECDHExtType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ECDHExt) Serialize() ([]byte, error) {
|
||||||
|
pubkey, err := x509.MarshalPKIXPublicKey(ext.Pubkey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.MarshalIndent(&ECDHExtJSON{
|
||||||
|
Granted: ext.Granted,
|
||||||
|
Pubkey: pubkey,
|
||||||
|
Shared: ext.Shared,
|
||||||
|
}, "", " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadECDHExt(ctx *Context, data []byte) (Extension, error) {
|
||||||
|
var j ECDHExtJSON
|
||||||
|
err := json.Unmarshal(data, &j)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pub, err := x509.ParsePKIXPublicKey(j.Pubkey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var pubkey *ecdsa.PublicKey
|
||||||
|
switch pub.(type) {
|
||||||
|
case *ecdsa.PublicKey:
|
||||||
|
pubkey = pub.(*ecdsa.PublicKey)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("Invalid key type: %+v", pub)
|
||||||
|
}
|
||||||
|
|
||||||
|
extension := ECDHExt{
|
||||||
|
Granted: j.Granted,
|
||||||
|
Pubkey: pubkey,
|
||||||
|
Shared: j.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &extension, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type GroupExt struct {
|
||||||
|
Members NodeMap
|
||||||
|
}
|
||||||
|
|
||||||
|
const GroupExtType = ExtType("GROUP")
|
||||||
|
func (ext *GroupExt) Type() ExtType {
|
||||||
|
return GroupExtType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *GroupExt) Serialize() ([]byte, error) {
|
||||||
|
return json.MarshalIndent(&struct{
|
||||||
|
Members []string `json:"members"`
|
||||||
|
}{
|
||||||
|
Members: SaveNodeList(ext.Members),
|
||||||
|
}, "", " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGroupExt(members NodeMap) *GroupExt {
|
||||||
|
if members == nil {
|
||||||
|
members = NodeMap{}
|
||||||
|
}
|
||||||
|
return &GroupExt{
|
||||||
|
Members: members,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadGroupExt(ctx *Context, data []byte) (Extension, error) {
|
||||||
|
var j struct {
|
||||||
|
Members []string `json:"members"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err := json.Unmarshal(data, &j)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
members, err := RestoreNodeList(ctx, j.Members)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewGroupExt(members), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *GroupExt) Process(context *StateContext, node *Node, signal Signal) error {
|
||||||
|
return nil
|
||||||
|
}
|
Loading…
Reference in New Issue