Compare commits
242 Commits
graph-rewo
...
master
Author | SHA1 | Date |
---|---|---|
noah metz | 8cb97d2350 | |
noah metz | c29981da20 | |
noah metz | 810e17990c | |
noah metz | d0d07d464d | |
noah metz | 11e7df2bde | |
noah metz | 3eee736f97 | |
noah metz | 7e157068d7 | |
noah metz | b2d84b2453 | |
noah metz | 66d5e3f260 | |
noah metz | 1eff534e1a | |
noah metz | 3d28c703db | |
noah metz | a4115a4f99 | |
noah metz | ab76f09923 | |
noah metz | 6850031e80 | |
noah metz | 0b93c90aa9 | |
noah metz | 2db4655670 | |
noah metz | d7b07df798 | |
noah metz | 0bced58fd1 | |
noah metz | 8f9a759b26 | |
noah metz | c0407b094c | |
noah metz | c591fa5ace | |
noah metz | f8dad12fdb | |
noah metz | eef8451566 | |
noah metz | 7e143c9d93 | |
noah metz | 7314c74087 | |
noah metz | 1eb6479169 | |
noah metz | e16bec3997 | |
noah metz | 6942dc02db | |
noah metz | e5776e0a14 | |
noah metz | 8927077167 | |
noah metz | faab7eb52c | |
noah metz | 61565fa18c | |
noah metz | 9c534a1d33 | |
noah metz | c33f37e4cd | |
noah metz | b42753b575 | |
noah metz | df9707309f | |
noah metz | 646e6592f0 | |
noah metz | c5f95d431d | |
noah metz | 59424cecc1 | |
noah metz | 786f374b5f | |
noah metz | c54101f4a3 | |
noah metz | 520219f7b0 | |
noah metz | 8b91d0af0c | |
noah metz | 8bb1dacf23 | |
noah metz | 6580725241 | |
noah metz | 58675377fd | |
noah metz | dee7b917f7 | |
noah metz | 8d1e273331 | |
noah metz | 42e4a8f7ea | |
noah metz | c773ea2b14 | |
noah metz | ddb3854d00 | |
noah metz | f41160da68 | |
noah metz | 3052f8099f | |
noah metz | f5a08bbc48 | |
noah metz | 57156251cd | |
noah metz | 8e3510129c | |
noah metz | 3a6e562390 | |
noah metz | a061d6850c | |
noah metz | 2081771135 | |
noah metz | dbe819fd05 | |
noah metz | 39d3288094 | |
noah metz | 0e8590e22a | |
noah metz | 84aee24a21 | |
noah metz | 8a973c38b5 | |
noah metz | 0eab243659 | |
noah metz | a568adc156 | |
noah metz | 266e353c5f | |
noah metz | 95939fb020 | |
noah metz | 92bb4bf976 | |
noah metz | 193952e84d | |
noah metz | d930d78351 | |
noah metz | e299e77e78 | |
noah metz | b9bf61cf68 | |
noah metz | 4b7bc93914 | |
noah metz | 0159d0dd5a | |
noah metz | 76e1e9a17a | |
noah metz | b3bbf71c22 | |
noah metz | 08c36e0505 | |
noah metz | c4df57a932 | |
noah metz | 8c80ec9dd6 | |
noah metz | 187ffb1324 | |
noah metz | 16e25c009f | |
noah metz | c63ad91252 | |
noah metz | b32f264879 | |
noah metz | ae289705bb | |
noah metz | 190824e710 | |
noah metz | c4de49099b | |
noah metz | 0a936f50f8 | |
noah metz | 0941c6c64e | |
noah metz | 92d8dfd006 | |
noah metz | f82bbabc66 | |
noah metz | 542c5c18af | |
noah metz | 6381713972 | |
noah metz | 34162023cb | |
noah metz | 56f3cce415 | |
noah metz | 7234b11643 | |
noah metz | 302f0f42fe | |
noah metz | 7451e8e960 | |
noah metz | 9eadb00397 | |
noah metz | e042384b3f | |
noah metz | 3ef0a98a17 | |
noah metz | bb28d9bc32 | |
noah metz | 96408259d1 | |
noah metz | 5a86334d5a | |
noah metz | e93fe50b5f | |
noah metz | e013edc656 | |
noah metz | d4e0d855c7 | |
noah metz | 0fc6215448 | |
noah metz | b09e150c46 | |
noah metz | d86d424cd7 | |
noah metz | ff7046badf | |
noah metz | ab5b922a7d | |
noah metz | d34304f6ad | |
noah metz | 9ffa9d6cb2 | |
noah metz | 21224e8837 | |
noah metz | 6bfe339854 | |
noah metz | ecaf35f05d | |
noah metz | eb30b477d5 | |
noah metz | 5c70d1b18d | |
noah metz | de1a229db6 | |
noah metz | de54c87e43 | |
noah metz | dac0f1f273 | |
noah metz | 07ce005365 | |
noah metz | c4e5054e07 | |
noah metz | b47c95c5ad | |
noah metz | 045304f9f6 | |
noah metz | d0f0fb1b82 | |
noah metz | fa5facc5fc | |
noah metz | b766aadef9 | |
noah metz | 47107dec1c | |
noah metz | 15793e1415 | |
noah metz | e2f34150ef | |
noah metz | 06513a5ad6 | |
noah metz | 799b6404dd | |
noah metz | 1888cf428d | |
noah metz | 857f04efe3 | |
noah metz | 4daec4d601 | |
noah metz | 7bed89701d | |
noah metz | ba344bddcf | |
noah metz | 4ce2a642c5 | |
noah metz | f398c9659e | |
noah metz | 98f05d57f9 | |
noah metz | 98c0b7e807 | |
noah metz | b446c9078a | |
noah metz | d663314def | |
noah metz | e26ddcae37 | |
noah metz | 20c7a38044 | |
noah metz | f31beade29 | |
noah metz | 96e842decf | |
noah metz | 8770d6f433 | |
noah metz | 1d91854f6f | |
noah metz | 7d0af0eb5b | |
noah metz | 0f7a0debd6 | |
noah metz | 147f44e5ff | |
noah metz | 42cd8f4188 | |
noah metz | b9a2cceaf1 | |
noah metz | df09433b88 | |
noah metz | 47151905a0 | |
noah metz | 3a53163f36 | |
noah metz | c515128743 | |
noah metz | 5f2b97a75b | |
noah metz | 5bef8d96ba | |
noah metz | ed9c353b95 | |
noah metz | 3bc427f2a9 | |
noah metz | 6895aa7c8e | |
noah metz | 0cc7174667 | |
noah metz | 059c36663b | |
noah metz | b06c741ee3 | |
noah metz | 42597057af | |
noah metz | 09c25b1e48 | |
noah metz | e7d94414d5 | |
noah metz | d8355ab786 | |
noah metz | 887a976263 | |
noah metz | 771bf53356 | |
noah metz | 0313d6a33f | |
noah metz | 064dc72820 | |
noah metz | 0424a3970f | |
noah metz | fde2f3ddd4 | |
noah metz | 1af94520a8 | |
noah metz | 4e31a6763d | |
noah metz | b4e6123d4c | |
noah metz | e34d3ec981 | |
noah metz | 31f6c12f14 | |
noah metz | ec9a29573a | |
noah metz | 395a75fcb8 | |
noah metz | 9f9e65cf54 | |
noah metz | da58b04774 | |
noah metz | 7810859bca | |
noah metz | ad6ea0cc59 | |
noah metz | 96c2b84b6f | |
noah metz | 79e40bf3f3 | |
noah metz | f56f92a58b | |
noah metz | e92b2e508d | |
noah metz | dca4de183e | |
noah metz | fad8d8123c | |
noah metz | 891e69c775 | |
noah metz | d6a35247b0 | |
noah metz | 1a3a07336a | |
noah metz | 2dfa10b1f6 | |
noah metz | c60393d390 | |
noah metz | 641bd8febe | |
noah metz | f87571edcf | |
noah metz | 5fb1cb6d17 | |
noah metz | b92cebbe74 | |
noah metz | 5678c79798 | |
noah metz | 5f409def03 | |
noah metz | 61de2669e2 | |
noah metz | 08288f88af | |
noah metz | fc69bc3d0d | |
noah metz | b3de3144cc | |
noah metz | 27687add1b | |
noah metz | a16cf6bb38 | |
noah metz | fb7e6d02f4 | |
noah metz | d40e561728 | |
noah metz | f314b46415 | |
noah metz | 027c3d4c96 | |
noah metz | c763725a34 | |
noah metz | 027bb74887 | |
noah metz | a1ce4238cc | |
noah metz | a44b00bc97 | |
noah metz | 200e19eea7 | |
noah metz | 98893de442 | |
noah metz | 78c29d2f74 | |
noah metz | 7ebb519cd0 | |
noah metz | 9d31394707 | |
noah metz | 6b375245df | |
noah metz | 7965f8fbe6 | |
noah metz | d729698523 | |
noah metz | 26d122e3c5 | |
noah metz | 3ad969a5ca | |
noah metz | 7a7a9c95a3 | |
noah metz | c62ef57fe7 | |
noah metz | 544264f06b | |
noah metz | c34d717b52 | |
noah metz | 95a2f46d28 | |
noah metz | 70baca9e9c | |
noah metz | 4fa88dc056 | |
noah metz | 81c2e11304 | |
noah metz | cc807b3982 | |
noah metz | 186123ce01 | |
noah metz | 494d212051 | |
noah metz | 34082630b2 |
@ -0,0 +1,3 @@
|
|||||||
|
[submodule "graphql"]
|
||||||
|
path = graphql
|
||||||
|
url = https://github.com/graphql-go/graphql
|
@ -0,0 +1,48 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
badger "github.com/dgraph-io/badger/v3"
|
||||||
|
gv "github.com/mekkanized/graphvent"
|
||||||
|
)
|
||||||
|
|
||||||
|
func check(err error) {
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true))
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
ctx, err := gv.NewContext(&gv.BadgerDB{
|
||||||
|
DB: db,
|
||||||
|
}, gv.NewConsoleLogger([]string{"test"}))
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
gql_ext, err := gv.NewGQLExt(ctx, ":8080", nil, nil)
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
listener_ext := gv.NewListenerExt(1000)
|
||||||
|
|
||||||
|
n1, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt(nil))
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
n2, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt([]gv.NodeID{n1.ID}))
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
n3, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt(nil))
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
_, err = gv.NewNode(ctx, nil, "LockableNode", 1000, gql_ext, listener_ext, gv.NewLockableExt([]gv.NodeID{n2.ID, n3.ID}))
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
for true {
|
||||||
|
select {
|
||||||
|
case message := <- listener_ext.Chan:
|
||||||
|
fmt.Printf("Listener Message: %+v\n", message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,284 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
badger "github.com/dgraph-io/badger/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Database interface {
|
||||||
|
WriteNodeInit(*Context, *Node) error
|
||||||
|
WriteNodeChanges(*Context, *Node, map[ExtType]Changes) error
|
||||||
|
LoadNode(*Context, NodeID) (*Node, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
const WRITE_BUFFER_SIZE = 1000000
|
||||||
|
type BadgerDB struct {
|
||||||
|
*badger.DB
|
||||||
|
sync.Mutex
|
||||||
|
buffer [WRITE_BUFFER_SIZE]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *BadgerDB) WriteNodeInit(ctx *Context, node *Node) error {
|
||||||
|
if node == nil {
|
||||||
|
return fmt.Errorf("Cannot serialize nil *Node")
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.Update(func(tx *badger.Txn) error {
|
||||||
|
db.Lock()
|
||||||
|
defer db.Unlock()
|
||||||
|
|
||||||
|
// Get the base key bytes
|
||||||
|
id_ser, err := node.ID.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cur := 0
|
||||||
|
|
||||||
|
// Write Node value
|
||||||
|
written, err := Serialize(ctx, node, db.buffer[cur:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Set(id_ser, db.buffer[cur:cur+written])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cur += written
|
||||||
|
|
||||||
|
// Write empty signal queue
|
||||||
|
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
|
||||||
|
written, err = Serialize(ctx, node.SignalQueue, db.buffer[cur:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Set(sigqueue_id, db.buffer[cur:cur+written])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cur += written
|
||||||
|
|
||||||
|
// Write node extension list
|
||||||
|
ext_list := []ExtType{}
|
||||||
|
for ext_type := range(node.Extensions) {
|
||||||
|
ext_list = append(ext_list, ext_type)
|
||||||
|
}
|
||||||
|
written, err = Serialize(ctx, ext_list, db.buffer[cur:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
|
||||||
|
err = tx.Set(ext_list_id, db.buffer[cur:cur+written])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cur += written
|
||||||
|
|
||||||
|
// For each extension:
|
||||||
|
for ext_type, ext := range(node.Extensions) {
|
||||||
|
ext_info, exists := ctx.Extensions[ext_type]
|
||||||
|
if exists == false {
|
||||||
|
return fmt.Errorf("Cannot serialize node with unknown extension %s", reflect.TypeOf(ext))
|
||||||
|
}
|
||||||
|
|
||||||
|
ext_value := reflect.ValueOf(ext).Elem()
|
||||||
|
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
|
||||||
|
|
||||||
|
// Write each field to a seperate key
|
||||||
|
for field_tag, field_info := range(ext_info.Fields) {
|
||||||
|
field_value := ext_value.FieldByIndex(field_info.Index)
|
||||||
|
|
||||||
|
field_id := make([]byte, len(ext_id) + 8)
|
||||||
|
tmp := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(field_tag))))
|
||||||
|
copy(field_id, tmp)
|
||||||
|
|
||||||
|
written, err := SerializeValue(ctx, field_value, db.buffer[cur:])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Set(field_id, db.buffer[cur:cur+written])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err)
|
||||||
|
}
|
||||||
|
cur += written
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *BadgerDB) WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error {
|
||||||
|
return db.Update(func(tx *badger.Txn) error {
|
||||||
|
db.Lock()
|
||||||
|
defer db.Unlock()
|
||||||
|
|
||||||
|
// Get the base key bytes
|
||||||
|
id_bytes := ([16]byte)(node.ID)
|
||||||
|
|
||||||
|
cur := 0
|
||||||
|
|
||||||
|
// Write the signal queue if it needs to be written
|
||||||
|
if node.writeSignalQueue {
|
||||||
|
node.writeSignalQueue = false
|
||||||
|
|
||||||
|
sigqueue_id := append(id_bytes[:], []byte(" - SIGQUEUE")...)
|
||||||
|
written, err := Serialize(ctx, node.SignalQueue, db.buffer[cur:])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("SignalQueue Serialize Error: %+v, %w", node.SignalQueue, err)
|
||||||
|
}
|
||||||
|
err = tx.Set(sigqueue_id, db.buffer[cur:cur+written])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("SignalQueue set error: %+v, %w", node.SignalQueue, err)
|
||||||
|
}
|
||||||
|
cur += written
|
||||||
|
}
|
||||||
|
|
||||||
|
// For each ext in changes
|
||||||
|
for ext_type, changes := range(changes) {
|
||||||
|
ext_info, exists := ctx.Extensions[ext_type]
|
||||||
|
if exists == false {
|
||||||
|
return fmt.Errorf("%s is not an extension in ctx", ext_type)
|
||||||
|
}
|
||||||
|
|
||||||
|
ext, exists := node.Extensions[ext_type]
|
||||||
|
if exists == false {
|
||||||
|
return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID)
|
||||||
|
}
|
||||||
|
ext_id := binary.BigEndian.AppendUint64(id_bytes[:], uint64(ext_type))
|
||||||
|
ext_value := reflect.ValueOf(ext)
|
||||||
|
|
||||||
|
// Write each field
|
||||||
|
for _, tag := range(changes) {
|
||||||
|
field_info, exists := ext_info.Fields[tag]
|
||||||
|
if exists == false {
|
||||||
|
return fmt.Errorf("Cannot serialize field %s of extension %s, does not exist", tag, ext_type)
|
||||||
|
}
|
||||||
|
|
||||||
|
field_value := ext_value.FieldByIndex(field_info.Index)
|
||||||
|
|
||||||
|
field_id := make([]byte, len(ext_id) + 8)
|
||||||
|
tmp := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(tag))))
|
||||||
|
copy(field_id, tmp)
|
||||||
|
|
||||||
|
written, err := SerializeValue(ctx, field_value, db.buffer[cur:])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Set(field_id, db.buffer[cur:cur+written])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err)
|
||||||
|
}
|
||||||
|
cur += written
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *BadgerDB) LoadNode(ctx *Context, id NodeID) (*Node, error) {
|
||||||
|
var node *Node = nil
|
||||||
|
|
||||||
|
err := db.View(func(tx *badger.Txn) error {
|
||||||
|
// Get the base key bytes
|
||||||
|
id_ser, err := id.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to serialize node_id: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the node value
|
||||||
|
node_item, err := tx.Get(id_ser)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to get node_item: %w", NodeNotFoundError)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = node_item.Value(func(val []byte) error {
|
||||||
|
ctx.Log.Logf("db", "DESERIALIZE_NODE(%d bytes): %+v", len(val), val)
|
||||||
|
node, err = Deserialize[*Node](ctx, val)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to deserialize Node %s - %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the signal queue
|
||||||
|
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
|
||||||
|
sigqueue_item, err := tx.Get(sigqueue_id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to get sigqueue_id: %w", err)
|
||||||
|
}
|
||||||
|
err = sigqueue_item.Value(func(val []byte) error {
|
||||||
|
node.SignalQueue, err = Deserialize[[]QueuedSignal](ctx, val)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to deserialize []QueuedSignal for %s: %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the extension list
|
||||||
|
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
|
||||||
|
ext_list_item, err := tx.Get(ext_list_id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var ext_list []ExtType
|
||||||
|
ext_list_item.Value(func(val []byte) error {
|
||||||
|
ext_list, err = Deserialize[[]ExtType](ctx, val)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
// Get the extensions
|
||||||
|
for _, ext_type := range(ext_list) {
|
||||||
|
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
|
||||||
|
ext_info, exists := ctx.Extensions[ext_type]
|
||||||
|
if exists == false {
|
||||||
|
return fmt.Errorf("Extension %s not in context", ext_type)
|
||||||
|
}
|
||||||
|
|
||||||
|
ext := reflect.New(ext_info.Type)
|
||||||
|
for field_tag, field_info := range(ext_info.Fields) {
|
||||||
|
field_id := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(field_tag))))
|
||||||
|
field_item, err := tx.Get(field_id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to find key for %s:%s(%x) - %w", ext_type, field_tag, field_id, err)
|
||||||
|
}
|
||||||
|
err = field_item.Value(func(val []byte) error {
|
||||||
|
value, _, err := DeserializeValue(ctx, val, field_info.Type)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ext.Elem().FieldByIndex(field_info.Index).Set(value)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
node.Extensions[ext_type] = ext.Interface().(Extension)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if node == nil {
|
||||||
|
return nil, fmt.Errorf("Tried to return nil *Node from BadgerDB.LoadNode without error")
|
||||||
|
}
|
||||||
|
|
||||||
|
return node, nil
|
||||||
|
}
|
@ -0,0 +1,16 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
type Tag string
|
||||||
|
type Changes []Tag
|
||||||
|
|
||||||
|
// Extensions are data attached to nodes that process signals
|
||||||
|
type Extension interface {
|
||||||
|
// Called to process incoming signals, returning changes and messages to send
|
||||||
|
Process(*Context, *Node, NodeID, Signal) ([]Message, Changes)
|
||||||
|
|
||||||
|
// Called when the node is loaded into a context(creation or move), so extension data can be initialized
|
||||||
|
Load(*Context, *Node) error
|
||||||
|
|
||||||
|
// Called when the node is unloaded from a context(deletion or move), so extension data can be cleaned up
|
||||||
|
Unload(*Context, *Node)
|
||||||
|
}
|
@ -1,172 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/graphql-go/graphql"
|
|
||||||
"reflect"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewField(init func()*graphql.Field) *graphql.Field {
|
|
||||||
return init()
|
|
||||||
}
|
|
||||||
|
|
||||||
type Singleton[K graphql.Type] struct {
|
|
||||||
Type K
|
|
||||||
List *graphql.List
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSingleton[K graphql.Type](init func() K, post_init func(K, *graphql.List)) *Singleton[K] {
|
|
||||||
val := init()
|
|
||||||
list := graphql.NewList(val)
|
|
||||||
if post_init != nil {
|
|
||||||
post_init(val, list)
|
|
||||||
}
|
|
||||||
return &Singleton[K]{
|
|
||||||
Type: val,
|
|
||||||
List: list,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func AddNodeInterfaceFields(i *graphql.Interface) {
|
|
||||||
i.AddFieldConfig("ID", &graphql.Field{
|
|
||||||
Type: graphql.String,
|
|
||||||
})
|
|
||||||
|
|
||||||
i.AddFieldConfig("TypeHash", &graphql.Field{
|
|
||||||
Type: graphql.String,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func PrepTypeResolve(p graphql.ResolveTypeParams) (*ResolveContext, error) {
|
|
||||||
resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
|
|
||||||
if ok == false {
|
|
||||||
return nil, fmt.Errorf("Bad resolve in params context")
|
|
||||||
}
|
|
||||||
return resolve_context, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var GQLInterfaceNode = NewSingleton(func() *graphql.Interface {
|
|
||||||
i := graphql.NewInterface(graphql.InterfaceConfig{
|
|
||||||
Name: "Node",
|
|
||||||
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
|
|
||||||
ctx, err := PrepTypeResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
valid_nodes := ctx.GQLContext.ValidNodes
|
|
||||||
p_type := reflect.TypeOf(p.Value)
|
|
||||||
|
|
||||||
for key, value := range(valid_nodes) {
|
|
||||||
if p_type == key {
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, ok := p.Value.(Node)
|
|
||||||
if ok == true {
|
|
||||||
return ctx.GQLContext.BaseNodeType
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
Fields: graphql.Fields{},
|
|
||||||
})
|
|
||||||
|
|
||||||
AddNodeInterfaceFields(i)
|
|
||||||
|
|
||||||
return i
|
|
||||||
}, nil)
|
|
||||||
|
|
||||||
var GQLInterfaceLockable = NewSingleton(func() *graphql.Interface {
|
|
||||||
gql_interface_lockable := graphql.NewInterface(graphql.InterfaceConfig{
|
|
||||||
Name: "Lockable",
|
|
||||||
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
|
|
||||||
ctx, err := PrepTypeResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
valid_lockables := ctx.GQLContext.ValidLockables
|
|
||||||
p_type := reflect.TypeOf(p.Value)
|
|
||||||
|
|
||||||
for key, value := range(valid_lockables) {
|
|
||||||
if p_type == key {
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, ok := p.Value.(*Node)
|
|
||||||
if ok == false {
|
|
||||||
return ctx.GQLContext.BaseLockableType
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
Fields: graphql.Fields{},
|
|
||||||
})
|
|
||||||
|
|
||||||
return gql_interface_lockable
|
|
||||||
}, func(lockable *graphql.Interface, lockable_list *graphql.List) {
|
|
||||||
lockable.AddFieldConfig("Requirements", &graphql.Field{
|
|
||||||
Type: lockable_list,
|
|
||||||
})
|
|
||||||
|
|
||||||
lockable.AddFieldConfig("Dependencies", &graphql.Field{
|
|
||||||
Type: lockable_list,
|
|
||||||
})
|
|
||||||
|
|
||||||
lockable.AddFieldConfig("Owner", &graphql.Field{
|
|
||||||
Type: lockable,
|
|
||||||
})
|
|
||||||
AddNodeInterfaceFields(lockable)
|
|
||||||
})
|
|
||||||
|
|
||||||
var GQLInterfaceThread = NewSingleton(func() *graphql.Interface {
|
|
||||||
gql_interface_thread := graphql.NewInterface(graphql.InterfaceConfig{
|
|
||||||
Name: "Thread",
|
|
||||||
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
|
|
||||||
ctx, err := PrepTypeResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
valid_threads := ctx.GQLContext.ValidThreads
|
|
||||||
p_type := reflect.TypeOf(p.Value)
|
|
||||||
|
|
||||||
for key, value := range(valid_threads) {
|
|
||||||
if p_type == key {
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
node, ok := p.Value.(*Node)
|
|
||||||
if ok == false {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = GetExt[*ThreadExt](node)
|
|
||||||
if err == nil {
|
|
||||||
return ctx.GQLContext.BaseThreadType
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
Fields: graphql.Fields{},
|
|
||||||
})
|
|
||||||
|
|
||||||
return gql_interface_thread
|
|
||||||
}, func(thread *graphql.Interface, thread_list *graphql.List) {
|
|
||||||
thread.AddFieldConfig("Children", &graphql.Field{
|
|
||||||
Type: thread_list,
|
|
||||||
})
|
|
||||||
|
|
||||||
thread.AddFieldConfig("Parent", &graphql.Field{
|
|
||||||
Type: thread,
|
|
||||||
})
|
|
||||||
|
|
||||||
thread.AddFieldConfig("State", &graphql.Field{
|
|
||||||
Type: graphql.String,
|
|
||||||
})
|
|
||||||
|
|
||||||
AddNodeInterfaceFields(thread)
|
|
||||||
})
|
|
@ -1,114 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"github.com/graphql-go/graphql"
|
|
||||||
)
|
|
||||||
|
|
||||||
var GQLMutationAbort = NewField(func()*graphql.Field {
|
|
||||||
gql_mutation_abort := &graphql.Field{
|
|
||||||
Type: GQLTypeSignal.Type,
|
|
||||||
Args: graphql.FieldConfigArgument{
|
|
||||||
"id": &graphql.ArgumentConfig{
|
|
||||||
Type: graphql.String,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
_, ctx, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := ExtractID(p, "id")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var node *Node = nil
|
|
||||||
context := NewReadContext(ctx.Context)
|
|
||||||
err = UseStates(context, ctx.User, NewACLMap(
|
|
||||||
NewACLInfo(ctx.Server, []string{"children"}),
|
|
||||||
), func(context *StateContext) (error){
|
|
||||||
node, err = FindChild(context, ctx.User, ctx.Server, id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if node == nil {
|
|
||||||
return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
|
|
||||||
}
|
|
||||||
return SendSignal(context, node, ctx.User, AbortSignal)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return AbortSignal, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return gql_mutation_abort
|
|
||||||
})
|
|
||||||
|
|
||||||
var GQLMutationStartChild = NewField(func()*graphql.Field{
|
|
||||||
gql_mutation_start_child := &graphql.Field{
|
|
||||||
Type: GQLTypeSignal.Type,
|
|
||||||
Args: graphql.FieldConfigArgument{
|
|
||||||
"parent_id": &graphql.ArgumentConfig{
|
|
||||||
Type: graphql.String,
|
|
||||||
},
|
|
||||||
"child_id": &graphql.ArgumentConfig{
|
|
||||||
Type: graphql.String,
|
|
||||||
},
|
|
||||||
"action": &graphql.ArgumentConfig{
|
|
||||||
Type: graphql.String,
|
|
||||||
DefaultValue: "start",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
_, ctx, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
parent_id, err := ExtractID(p, "parent_id")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
child_id, err := ExtractID(p, "child_id")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
action, err := ExtractParam[string](p, "action")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var signal Signal
|
|
||||||
context := NewWriteContext(ctx.Context)
|
|
||||||
err = UseStates(context, ctx.User, NewACLMap(
|
|
||||||
NewACLInfo(ctx.Server, []string{"children"}),
|
|
||||||
), func(context *StateContext) error {
|
|
||||||
parent, err := FindChild(context, ctx.User, ctx.Server, parent_id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if parent == nil {
|
|
||||||
return fmt.Errorf("%s is not a child of %s", parent_id, ctx.Server.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
signal = NewStartChildSignal(child_id, action)
|
|
||||||
return SendSignal(context, ctx.User, parent, signal)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: wait for the result of the signal to send back instead of just the signal
|
|
||||||
return signal, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return gql_mutation_start_child
|
|
||||||
})
|
|
||||||
|
|
@ -0,0 +1,147 @@
|
|||||||
|
package graphvent
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
"github.com/graphql-go/graphql"
|
||||||
|
"github.com/graphql-go/graphql/language/ast"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
node, ok := p.Source.(NodeResult)
|
||||||
|
if ok == false {
|
||||||
|
return nil, fmt.Errorf("Can't get NodeID from %+v", reflect.TypeOf(p.Source))
|
||||||
|
}
|
||||||
|
|
||||||
|
return node.NodeID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResolveNodeType(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
node, ok := p.Source.(NodeResult)
|
||||||
|
if ok == false {
|
||||||
|
return nil, fmt.Errorf("Can't get TypeHash from %+v", reflect.TypeOf(p.Source))
|
||||||
|
}
|
||||||
|
|
||||||
|
return uint64(node.NodeType), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type FieldIndex struct {
|
||||||
|
Extension ExtType
|
||||||
|
Tag string
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetFields(selection_set *ast.SelectionSet) []string {
|
||||||
|
names := []string{}
|
||||||
|
if selection_set == nil {
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sel := range(selection_set.Selections) {
|
||||||
|
switch field := sel.(type) {
|
||||||
|
case *ast.Field:
|
||||||
|
if field.Name.Value == "ID" || field.Name.Value == "Type" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
names = append(names, field.Name.Value)
|
||||||
|
case *ast.InlineFragment:
|
||||||
|
names = append(names, GetFields(field.SelectionSet)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the fields that need to be resolved
|
||||||
|
func GetResolveFields(p graphql.ResolveParams) []string {
|
||||||
|
fields := []string{}
|
||||||
|
for _, field := range(p.Info.FieldASTs) {
|
||||||
|
fields = append(fields, GetFields(field.SelectionSet)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResolveNode(id NodeID, p graphql.ResolveParams) (NodeResult, error) {
|
||||||
|
ctx, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return NodeResult{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch source := p.Source.(type) {
|
||||||
|
case *StatusSignal:
|
||||||
|
cached_node, cached := ctx.NodeCache[source.Source]
|
||||||
|
if cached {
|
||||||
|
for _, field_name := range(source.Fields) {
|
||||||
|
_, cached := cached_node.Data[field_name]
|
||||||
|
if cached {
|
||||||
|
delete(cached_node.Data, field_name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx.NodeCache[source.Source] = cached_node
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cache, node_cached := ctx.NodeCache[id]
|
||||||
|
fields := GetResolveFields(p)
|
||||||
|
var not_cached []string
|
||||||
|
if node_cached {
|
||||||
|
not_cached = []string{}
|
||||||
|
for _, field := range(fields) {
|
||||||
|
if node_cached {
|
||||||
|
_, field_cached := cache.Data[field]
|
||||||
|
if field_cached {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
not_cached = append(not_cached, field)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
not_cached = fields
|
||||||
|
}
|
||||||
|
|
||||||
|
if (len(not_cached) == 0) && (node_cached == true) {
|
||||||
|
ctx.Context.Log.Logf("gql", "No new fields to resolve for %s", id)
|
||||||
|
return cache, nil
|
||||||
|
} else {
|
||||||
|
ctx.Context.Log.Logf("gql", "Resolving fields %+v on node %s", not_cached, id)
|
||||||
|
|
||||||
|
signal := NewReadSignal(not_cached)
|
||||||
|
response_chan := ctx.Ext.GetResponseChannel(signal.ID())
|
||||||
|
// TODO: TIMEOUT DURATION
|
||||||
|
err = ctx.Context.Send(ctx.Server, []Message{{
|
||||||
|
Node: id,
|
||||||
|
Signal: signal,
|
||||||
|
}})
|
||||||
|
if err != nil {
|
||||||
|
ctx.Ext.FreeResponseChannel(signal.ID())
|
||||||
|
return NodeResult{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID())
|
||||||
|
ctx.Ext.FreeResponseChannel(signal.ID())
|
||||||
|
if err != nil {
|
||||||
|
return NodeResult{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch response := response.(type) {
|
||||||
|
case *ReadResultSignal:
|
||||||
|
if node_cached == false {
|
||||||
|
cache = NodeResult{
|
||||||
|
NodeID: id,
|
||||||
|
NodeType: response.NodeType,
|
||||||
|
Data: response.Fields,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for field_name, field_value := range(response.Fields) {
|
||||||
|
cache.Data[field_name] = field_value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.NodeCache[id] = cache
|
||||||
|
return ctx.NodeCache[id], nil
|
||||||
|
default:
|
||||||
|
return NodeResult{}, fmt.Errorf("Bad read response: %+v", response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,28 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
import (
|
|
||||||
"github.com/graphql-go/graphql"
|
|
||||||
)
|
|
||||||
|
|
||||||
var GQLQuerySelf = &graphql.Field{
|
|
||||||
Type: GQLTypeBaseThread.Type,
|
|
||||||
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
_, ctx, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return ctx.Server, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var GQLQueryUser = &graphql.Field{
|
|
||||||
Type: GQLTypeBaseNode.Type,
|
|
||||||
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
_, ctx, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return ctx.User, nil
|
|
||||||
},
|
|
||||||
}
|
|
@ -1,338 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"github.com/graphql-go/graphql"
|
|
||||||
)
|
|
||||||
|
|
||||||
func PrepResolve(p graphql.ResolveParams) (*Node, *ResolveContext, error) {
|
|
||||||
resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
|
|
||||||
if ok == false {
|
|
||||||
return nil, nil, fmt.Errorf("Bad resolve in params context")
|
|
||||||
}
|
|
||||||
|
|
||||||
node, ok := p.Source.(*Node)
|
|
||||||
if ok == false {
|
|
||||||
return nil, nil, fmt.Errorf("Source is not a *Node in PrepResolve")
|
|
||||||
}
|
|
||||||
|
|
||||||
return node, resolve_context, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Make composabe by checkinf if K is a slice, then recursing in the same way that ExtractList does
|
|
||||||
func ExtractParam[K interface{}](p graphql.ResolveParams, name string) (K, error) {
|
|
||||||
var zero K
|
|
||||||
arg_if, ok := p.Args[name]
|
|
||||||
if ok == false {
|
|
||||||
return zero, fmt.Errorf("No Arg of name %s", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
arg, ok := arg_if.(K)
|
|
||||||
if ok == false {
|
|
||||||
return zero, fmt.Errorf("Failed to cast arg %s(%+v) to %+v", name, arg_if, reflect.TypeOf(zero))
|
|
||||||
}
|
|
||||||
|
|
||||||
return arg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExtractList[K interface{}](p graphql.ResolveParams, name string) ([]K, error) {
|
|
||||||
var zero K
|
|
||||||
|
|
||||||
arg_list, err := ExtractParam[[]interface{}](p, name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ret := make([]K, len(arg_list))
|
|
||||||
for i, val := range(arg_list) {
|
|
||||||
val_conv, ok := arg_list[i].(K)
|
|
||||||
if ok == false {
|
|
||||||
return nil, fmt.Errorf("Failed to cast arg %s[%d](%+v) to %+v", name, i, val, reflect.TypeOf(zero))
|
|
||||||
}
|
|
||||||
ret[i] = val_conv
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) {
|
|
||||||
id_str, err := ExtractParam[string](p, name)
|
|
||||||
if err != nil {
|
|
||||||
return ZeroID, err
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := ParseID(id_str)
|
|
||||||
if err != nil {
|
|
||||||
return ZeroID, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return id, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: think about what permissions should be needed to read ID, and if there's ever a case where they haven't already been granted
|
|
||||||
func GQLNodeID(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
node, _, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return node.ID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GQLNodeTypeHash(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
node, _, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(node.Type), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
node, ctx, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
gql_ext, err := GetExt[*GQLExt](node)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
listen := ""
|
|
||||||
context := NewReadContext(ctx.Context)
|
|
||||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"listen"}), func(context *StateContext) error {
|
|
||||||
listen = gql_ext.Listen
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return listen, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
node, ctx, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
thread_ext, err := GetExt[*ThreadExt](node)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var parent *Node = nil
|
|
||||||
context := NewReadContext(ctx.Context)
|
|
||||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error {
|
|
||||||
parent = thread_ext.Parent
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return parent, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GQLThreadState(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
node, ctx, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
thread_ext, err := GetExt[*ThreadExt](node)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var state string
|
|
||||||
context := NewReadContext(ctx.Context)
|
|
||||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"state"}), func(context *StateContext) error {
|
|
||||||
state = thread_ext.State
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return state, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
node, ctx, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
thread_ext, err := GetExt[*ThreadExt](node)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var children []*Node = nil
|
|
||||||
context := NewReadContext(ctx.Context)
|
|
||||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"children"}), func(context *StateContext) error {
|
|
||||||
children = thread_ext.ChildList()
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return children, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
node, ctx, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
lockable_ext, err := GetExt[*LockableExt](node)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var requirements []*Node = nil
|
|
||||||
context := NewReadContext(ctx.Context)
|
|
||||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error {
|
|
||||||
requirements = make([]*Node, len(lockable_ext.Requirements))
|
|
||||||
i := 0
|
|
||||||
for _, req := range(lockable_ext.Requirements) {
|
|
||||||
requirements[i] = req
|
|
||||||
i += 1
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return requirements, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
node, ctx, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
lockable_ext, err := GetExt[*LockableExt](node)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var dependencies []*Node = nil
|
|
||||||
context := NewReadContext(ctx.Context)
|
|
||||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"dependencies"}), func(context *StateContext) error {
|
|
||||||
dependencies = make([]*Node, len(lockable_ext.Dependencies))
|
|
||||||
i := 0
|
|
||||||
for _, dep := range(lockable_ext.Dependencies) {
|
|
||||||
dependencies[i] = dep
|
|
||||||
i += 1
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return dependencies, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
node, ctx, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
lockable_ext, err := GetExt[*LockableExt](node)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var owner *Node = nil
|
|
||||||
context := NewReadContext(ctx.Context)
|
|
||||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"owner"}), func(context *StateContext) error {
|
|
||||||
owner = lockable_ext.Owner
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return owner, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GQLGroupMembers(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
node, ctx, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
group_ext, err := GetExt[*GroupExt](node)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var members []*Node
|
|
||||||
context := NewReadContext(ctx.Context)
|
|
||||||
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"users"}), func(context *StateContext) error {
|
|
||||||
members = make([]*Node, len(group_ext.Members))
|
|
||||||
i := 0
|
|
||||||
for _, member := range(group_ext.Members) {
|
|
||||||
members[i] = member
|
|
||||||
i += 1
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return members, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GQLSignalFn(p graphql.ResolveParams, fn func(Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) {
|
|
||||||
if signal, ok := p.Source.(Signal); ok {
|
|
||||||
return fn(signal, p)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("Failed to cast source to event")
|
|
||||||
}
|
|
||||||
|
|
||||||
func GQLSignalType(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
|
|
||||||
return signal.Type(), nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func GQLSignalDirection(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
|
|
||||||
direction := signal.Direction()
|
|
||||||
if direction == Up {
|
|
||||||
return "up", nil
|
|
||||||
} else if direction == Down {
|
|
||||||
return "down", nil
|
|
||||||
} else if direction == Direct {
|
|
||||||
return "direct", nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("Invalid direction: %+v", direction)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func GQLSignalString(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
|
|
||||||
ser, err := signal.Serialize()
|
|
||||||
return string(ser), err
|
|
||||||
})
|
|
||||||
}
|
|
@ -1,69 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
import (
|
|
||||||
"github.com/graphql-go/graphql"
|
|
||||||
)
|
|
||||||
|
|
||||||
func GQLSubscribeSignal(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
return GQLSubscribeFn(p, false, func(ctx *Context, server *Node, ext *GQLExt, signal Signal, p graphql.ResolveParams)(interface{}, error) {
|
|
||||||
return signal, nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func GQLSubscribeSelf(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
return GQLSubscribeFn(p, true, func(ctx *Context, server *Node, ext *GQLExt, signal Signal, p graphql.ResolveParams)(interface{}, error) {
|
|
||||||
return server, nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func GQLSubscribeFn(p graphql.ResolveParams, send_nil bool, fn func(*Context, *Node, *GQLExt, Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) {
|
|
||||||
_, ctx, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c := make(chan interface{})
|
|
||||||
go func(c chan interface{}, ext *GQLExt, server *Node) {
|
|
||||||
ctx.Context.Log.Logf("gqlws", "GQL_SUBSCRIBE_THREAD_START")
|
|
||||||
sig_c := ext.NewSubscriptionChannel(1)
|
|
||||||
if send_nil == true {
|
|
||||||
sig_c <- nil
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
val, ok := <- sig_c
|
|
||||||
if ok == false {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ret, err := fn(ctx.Context, server, ext, val, p)
|
|
||||||
if err != nil {
|
|
||||||
ctx.Context.Log.Logf("gqlws", "type convertor error %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c <- ret
|
|
||||||
}
|
|
||||||
}(c, ctx.Ext, ctx.Server)
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var GQLSubscriptionSelf = NewField(func()*graphql.Field{
|
|
||||||
gql_subscription_self := &graphql.Field{
|
|
||||||
Type: GQLTypeBaseThread.Type,
|
|
||||||
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
return p.Source, nil
|
|
||||||
},
|
|
||||||
Subscribe: GQLSubscribeSelf,
|
|
||||||
}
|
|
||||||
|
|
||||||
return gql_subscription_self
|
|
||||||
})
|
|
||||||
|
|
||||||
var GQLSubscriptionUpdate = NewField(func()*graphql.Field{
|
|
||||||
gql_subscription_update := &graphql.Field{
|
|
||||||
Type: GQLTypeSignal.Type,
|
|
||||||
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
return p.Source, nil
|
|
||||||
},
|
|
||||||
Subscribe: GQLSubscribeSignal,
|
|
||||||
}
|
|
||||||
return gql_subscription_update
|
|
||||||
})
|
|
||||||
|
|
@ -1,152 +1,223 @@
|
|||||||
package graphvent
|
package graphvent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
"errors"
|
|
||||||
"crypto/rand"
|
"github.com/google/uuid"
|
||||||
"crypto/ecdh"
|
"golang.org/x/net/websocket"
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/elliptic"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGQL(t * testing.T) {
|
func TestGQLSubscribe(t *testing.T) {
|
||||||
ctx := logTestContext(t, []string{"test", "db"})
|
ctx := logTestContext(t, []string{"test", "gql"})
|
||||||
|
|
||||||
TestUserNodeType := NodeType("TEST_USER")
|
n1, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil))
|
||||||
err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{ACLExtType, ACLPolicyExtType})
|
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
|
|
||||||
u1 := NewNode(ctx, RandID(), TestUserNodeType)
|
listener_ext := NewListenerExt(10)
|
||||||
u1_policy := NewPerNodePolicy(NodeActions{
|
|
||||||
u1.ID: Actions{"users.write", "children.write", "parent.write", "dependencies.write", "requirements.write"},
|
|
||||||
})
|
|
||||||
u1.Extensions[ACLExtType] = NewACLExt(nil)
|
|
||||||
u1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
|
|
||||||
PerNodePolicyType: &u1_policy,
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "U1_ID: %s", u1.ID)
|
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
ListenerNodeType := NodeType("LISTENER")
|
gql, err := ctx.NewNode(nil, "LockableNode", NewLockableExt([]NodeID{n1.ID}), gql_ext, listener_ext)
|
||||||
err = ctx.RegisterNodeType(ListenerNodeType, []ExtType{ACLExtType, ACLPolicyExtType, ListenerExtType, LockableExtType})
|
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
|
|
||||||
l1 := NewNode(ctx, RandID(), ListenerNodeType)
|
query := "subscription { Self { ID, Type ... on Lockable { LockableState } } }"
|
||||||
l1_policy := NewRequirementOfPolicy(NodeActions{
|
|
||||||
l1.ID: Actions{"signal.status"},
|
|
||||||
})
|
|
||||||
|
|
||||||
l1.Extensions[ACLExtType] = NewACLExt(NodeList(u1))
|
ctx.Log.Logf("test", "GQL: %s", gql.ID)
|
||||||
listener_ext := NewListenerExt(10)
|
ctx.Log.Logf("test", "Node: %s", n1.ID)
|
||||||
l1.Extensions[ListenerExtType] = listener_ext
|
ctx.Log.Logf("test", "Query: %s", query)
|
||||||
l1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
|
|
||||||
RequirementOfPolicyType: &l1_policy,
|
|
||||||
})
|
|
||||||
l1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil)
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "L1_ID: %s", l1.ID)
|
sub_1 := GQLPayload{
|
||||||
|
Query: query,
|
||||||
|
}
|
||||||
|
|
||||||
|
port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port
|
||||||
|
url := fmt.Sprintf("http://localhost:%d/gql", port)
|
||||||
|
ws_url := fmt.Sprintf("ws://127.0.0.1:%d/gqlws", port)
|
||||||
|
|
||||||
TestThreadNodeType := NodeType("TEST_THREAD")
|
SubGQL := func(payload GQLPayload) {
|
||||||
err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLExtType, ACLPolicyExtType, ThreadExtType, LockableExtType})
|
config, err := websocket.NewConfig(ws_url, url)
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
|
config.Protocol = append(config.Protocol, "graphql-ws")
|
||||||
|
config.TlsConfig = &tls.Config{InsecureSkipVerify: true}
|
||||||
|
|
||||||
|
ws, err := websocket.DialConfig(config)
|
||||||
|
|
||||||
t1 := NewNode(ctx, RandID(), TestThreadNodeType)
|
|
||||||
t1_policy := NewParentOfPolicy(NodeActions{
|
|
||||||
t1.ID: Actions{"signal.abort", "state.write"},
|
|
||||||
})
|
|
||||||
t1.Extensions[ACLExtType] = NewACLExt(NodeList(u1))
|
|
||||||
t1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
|
|
||||||
ParentOfPolicyType: &t1_policy,
|
|
||||||
})
|
|
||||||
t1.Extensions[ThreadExtType], err = NewThreadExt(ctx, BaseThreadType, nil, nil, "init", nil)
|
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
t1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil)
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "T1_ID: %s", t1.ID)
|
type payload_struct struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
init := struct{
|
||||||
|
ID uuid.UUID `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
}{
|
||||||
|
uuid.New(),
|
||||||
|
"connection_init",
|
||||||
|
}
|
||||||
|
|
||||||
TestGQLNodeType := NodeType("TEST_GQL")
|
ser, err := json.Marshal(&init)
|
||||||
err = ctx.RegisterNodeType(TestGQLNodeType, []ExtType{ACLExtType, ACLPolicyExtType, GroupExtType, GQLExtType, ThreadExtType, LockableExtType})
|
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
|
|
||||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
_, err = ws.Write(ser)
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
|
|
||||||
gql := NewNode(ctx, RandID(), TestGQLNodeType)
|
resp := make([]byte, 1024)
|
||||||
gql_policy := NewChildOfPolicy(NodeActions{
|
n, err := ws.Read(resp)
|
||||||
gql.ID: Actions{"signal.status"},
|
|
||||||
})
|
var init_resp GQLWSMsg
|
||||||
gql.Extensions[ACLExtType] = NewACLExt(NodeList(u1))
|
err = json.Unmarshal(resp[:n], &init_resp)
|
||||||
gql.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
|
|
||||||
ChildOfPolicyType: &gql_policy,
|
|
||||||
})
|
|
||||||
gql.Extensions[GroupExtType] = NewGroupExt(nil)
|
|
||||||
gql.Extensions[GQLExtType] = NewGQLExt(":0", ecdh.P256(), key, nil, nil)
|
|
||||||
gql.Extensions[ThreadExtType], err = NewThreadExt(ctx, GQLThreadType, nil, nil, "ini", nil)
|
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
gql.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil)
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
|
if init_resp.Type != "connection_ack" {
|
||||||
info := ParentInfo{true, "start", "restore"}
|
t.Fatal("Didn't receive connection_ack")
|
||||||
context := NewWriteContext(ctx)
|
}
|
||||||
err = UpdateStates(context, u1, NewACLInfo(gql, []string{"users"}), func(context *StateContext) error {
|
|
||||||
err := LinkThreads(context, u1, gql, ChildInfo{t1, map[InfoType]Info{
|
sub := GQLWSMsg{
|
||||||
ParentInfoType: &info,
|
ID: uuid.New().String(),
|
||||||
}})
|
Type: "subscribe",
|
||||||
if err != nil {
|
Payload: sub_1,
|
||||||
return err
|
}
|
||||||
|
|
||||||
|
ser, err = json.Marshal(&sub)
|
||||||
|
fatalErr(t, err)
|
||||||
|
_, err = ws.Write(ser)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
n, err = ws.Read(resp)
|
||||||
|
fatalErr(t, err)
|
||||||
|
ctx.Log.Logf("test", "SUB1: %s", resp[:n])
|
||||||
|
|
||||||
|
lock_id, err := LockLockable(ctx, gql)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
response, _, err := WaitForResponse(listener_ext.Chan, 100*time.Millisecond, lock_id)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
switch response.(type) {
|
||||||
|
case *SuccessSignal:
|
||||||
|
ctx.Log.Logf("test", "Locked %s", gql.ID)
|
||||||
|
default:
|
||||||
|
t.Errorf("Unexpected lock response: %s", response)
|
||||||
}
|
}
|
||||||
return LinkLockables(context, u1, l1, []*Node{gql})
|
|
||||||
})
|
n, err = ws.Read(resp)
|
||||||
|
fatalErr(t, err)
|
||||||
|
ctx.Log.Logf("test", "SUB2: %s", resp[:n])
|
||||||
|
|
||||||
|
n, err = ws.Read(resp)
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
|
ctx.Log.Logf("test", "SUB3: %s", resp[:n])
|
||||||
|
|
||||||
|
// TODO: check that there are no more messages sent to ws within a timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
SubGQL(sub_1)
|
||||||
|
}
|
||||||
|
|
||||||
context = NewReadContext(ctx)
|
func TestGQLQuery(t *testing.T) {
|
||||||
err = SendSignal(context, gql, gql, NewStatusSignal("child_linked", t1.ID))
|
ctx := logTestContext(t, []string{"test", "lockable"})
|
||||||
|
|
||||||
|
n1_listener := NewListenerExt(10)
|
||||||
|
n1, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil), n1_listener)
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
context = NewReadContext(ctx)
|
|
||||||
err = SendSignal(context, gql, gql, AbortSignal)
|
gql_listener := NewListenerExt(10)
|
||||||
|
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
|
|
||||||
err = ThreadLoop(ctx, gql, "start")
|
gql, err := ctx.NewNode(nil, "LockableNode", NewLockableExt([]NodeID{n1.ID}), gql_ext, gql_listener)
|
||||||
if errors.Is(err, ThreadAbortedError) == false {
|
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "GQL: %s", gql.ID)
|
||||||
|
ctx.Log.Logf("test", "NODE: %s", n1.ID)
|
||||||
|
|
||||||
|
skipVerifyTransport := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
}
|
||||||
|
client := &http.Client{Transport: skipVerifyTransport}
|
||||||
|
port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port
|
||||||
|
url := fmt.Sprintf("http://localhost:%d/gql", port)
|
||||||
|
|
||||||
|
req_1 := GQLPayload{
|
||||||
|
Query: "query Node($id:graphvent_NodeID) { Node(id:$id) { ID, Type, ... on Lockable { LockableState } } }",
|
||||||
|
Variables: map[string]interface{}{
|
||||||
|
"id": n1.ID.String(),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
(*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "aborted", 100*time.Millisecond, "Didn't receive aborted on listener")
|
req_2 := GQLPayload{
|
||||||
|
Query: "query Self { Self { ID, Type, ... on Lockable { LockableState, Requirements { Key { ID ... on Lockable { LockableState } } } } } }",
|
||||||
context = NewReadContext(ctx)
|
|
||||||
err = UseStates(context, gql, ACLList([]*Node{gql, u1}, nil), func(context *StateContext) error {
|
|
||||||
ser1, err := gql.Serialize()
|
|
||||||
ser2, err := u1.Serialize()
|
|
||||||
ctx.Log.Logf("test", "\n%s\n\n", ser1)
|
|
||||||
ctx.Log.Logf("test", "\n%s\n\n", ser2)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
|
|
||||||
// Clear all loaded nodes from the context so it loads them from the database
|
|
||||||
ctx.Nodes = NodeMap{}
|
|
||||||
gql_loaded, err := LoadNode(ctx, gql.ID)
|
|
||||||
fatalErr(t, err)
|
|
||||||
context = NewReadContext(ctx)
|
|
||||||
err = UseStates(context, gql_loaded, NewACLInfo(gql_loaded, []string{"users", "children", "requirements"}), func(context *StateContext) error {
|
|
||||||
ser, err := gql_loaded.Serialize()
|
|
||||||
lockable_ext, err := GetExt[*LockableExt](gql_loaded)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
ctx.Log.Logf("test", "\n%s\n\n", ser)
|
|
||||||
dependency := lockable_ext.Dependencies[l1.ID]
|
SendGQL := func(payload GQLPayload) []byte {
|
||||||
listener_ext, err = GetExt[*ListenerExt](dependency)
|
ser, err := json.MarshalIndent(&payload, "", " ")
|
||||||
if err != nil {
|
fatalErr(t, err)
|
||||||
return err
|
|
||||||
|
req_data := bytes.NewBuffer(ser)
|
||||||
|
req, err := http.NewRequest("GET", url, req_data)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
resp.Body.Close()
|
||||||
|
return body
|
||||||
}
|
}
|
||||||
SendSignal(context, gql_loaded, gql_loaded, StopSignal)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
|
|
||||||
err = ThreadLoop(ctx, gql_loaded, "start")
|
resp_1 := SendGQL(req_1)
|
||||||
|
ctx.Log.Logf("test", "RESP_1: %s", resp_1)
|
||||||
|
resp_2 := SendGQL(req_2)
|
||||||
|
ctx.Log.Logf("test", "RESP_2: %s", resp_2)
|
||||||
|
|
||||||
|
lock_id, err := LockLockable(ctx, n1)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
response, _, err := WaitForResponse(n1_listener.Chan, 100*time.Millisecond, lock_id)
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
(*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2")
|
switch response := response.(type) {
|
||||||
|
case *SuccessSignal:
|
||||||
|
default:
|
||||||
|
t.Fatalf("Wrong response: %s", reflect.TypeOf(response))
|
||||||
|
}
|
||||||
|
|
||||||
|
resp_3 := SendGQL(req_1)
|
||||||
|
ctx.Log.Logf("test", "RESP_3: %s", resp_3)
|
||||||
|
|
||||||
|
resp_4 := SendGQL(req_2)
|
||||||
|
ctx.Log.Logf("test", "RESP_4: %s", resp_4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGQLDB(t *testing.T) {
|
||||||
|
ctx := logTestContext(t, []string{"test", "db", "node", "serialize"})
|
||||||
|
|
||||||
|
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
|
||||||
|
fatalErr(t, err)
|
||||||
|
listener_ext := NewListenerExt(10)
|
||||||
|
|
||||||
|
gql, err := ctx.NewNode(nil, "Node", gql_ext, listener_ext)
|
||||||
|
fatalErr(t, err)
|
||||||
|
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
|
||||||
|
|
||||||
|
err = ctx.Stop()
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
gql_loaded, err := ctx.GetNode(gql.ID)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
listener_ext, err = GetExt[ListenerExt](gql_loaded)
|
||||||
|
fatalErr(t, err)
|
||||||
|
}
|
||||||
|
@ -1,148 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/graphql-go/graphql"
|
|
||||||
)
|
|
||||||
|
|
||||||
func AddNodeFields(obj *graphql.Object) {
|
|
||||||
obj.AddFieldConfig("ID", &graphql.Field{
|
|
||||||
Type: graphql.String,
|
|
||||||
Resolve: GQLNodeID,
|
|
||||||
})
|
|
||||||
|
|
||||||
obj.AddFieldConfig("TypeHash", &graphql.Field{
|
|
||||||
Type: graphql.String,
|
|
||||||
Resolve: GQLNodeTypeHash,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func AddLockableFields(obj *graphql.Object) {
|
|
||||||
AddNodeFields(obj)
|
|
||||||
|
|
||||||
obj.AddFieldConfig("Requirements", &graphql.Field{
|
|
||||||
Type: GQLInterfaceLockable.List,
|
|
||||||
Resolve: GQLLockableRequirements,
|
|
||||||
})
|
|
||||||
|
|
||||||
obj.AddFieldConfig("Owner", &graphql.Field{
|
|
||||||
Type: GQLInterfaceLockable.Type,
|
|
||||||
Resolve: GQLLockableOwner,
|
|
||||||
})
|
|
||||||
|
|
||||||
obj.AddFieldConfig("Dependencies", &graphql.Field{
|
|
||||||
Type: GQLInterfaceLockable.List,
|
|
||||||
Resolve: GQLLockableDependencies,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func AddThreadFields(obj *graphql.Object) {
|
|
||||||
AddNodeFields(obj)
|
|
||||||
|
|
||||||
obj.AddFieldConfig("State", &graphql.Field{
|
|
||||||
Type: graphql.String,
|
|
||||||
Resolve: GQLThreadState,
|
|
||||||
})
|
|
||||||
|
|
||||||
obj.AddFieldConfig("Children", &graphql.Field{
|
|
||||||
Type: GQLInterfaceThread.List,
|
|
||||||
Resolve: GQLThreadChildren,
|
|
||||||
})
|
|
||||||
|
|
||||||
obj.AddFieldConfig("Parent", &graphql.Field{
|
|
||||||
Type: GQLInterfaceThread.Type,
|
|
||||||
Resolve: GQLThreadParent,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
var GQLTypeBaseThread = NewSingleton(func() *graphql.Object {
|
|
||||||
gql_type_simple_thread := graphql.NewObject(graphql.ObjectConfig{
|
|
||||||
Name: "SimpleThread",
|
|
||||||
Interfaces: []*graphql.Interface{
|
|
||||||
GQLInterfaceNode.Type,
|
|
||||||
GQLInterfaceThread.Type,
|
|
||||||
GQLInterfaceLockable.Type,
|
|
||||||
},
|
|
||||||
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
|
|
||||||
node, ok := p.Value.(*Node)
|
|
||||||
if ok == false {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := GetExt[*ThreadExt](node)
|
|
||||||
return err == nil
|
|
||||||
},
|
|
||||||
Fields: graphql.Fields{},
|
|
||||||
})
|
|
||||||
|
|
||||||
AddThreadFields(gql_type_simple_thread)
|
|
||||||
|
|
||||||
return gql_type_simple_thread
|
|
||||||
}, nil)
|
|
||||||
|
|
||||||
var GQLTypeBaseLockable = NewSingleton(func() *graphql.Object {
|
|
||||||
gql_type_simple_lockable := graphql.NewObject(graphql.ObjectConfig{
|
|
||||||
Name: "SimpleLockable",
|
|
||||||
Interfaces: []*graphql.Interface{
|
|
||||||
GQLInterfaceNode.Type,
|
|
||||||
GQLInterfaceLockable.Type,
|
|
||||||
},
|
|
||||||
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
|
|
||||||
node, ok := p.Value.(*Node)
|
|
||||||
if ok == false {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := GetExt[*LockableExt](node)
|
|
||||||
return err == nil
|
|
||||||
},
|
|
||||||
Fields: graphql.Fields{},
|
|
||||||
})
|
|
||||||
|
|
||||||
AddLockableFields(gql_type_simple_lockable)
|
|
||||||
|
|
||||||
return gql_type_simple_lockable
|
|
||||||
}, nil)
|
|
||||||
|
|
||||||
var GQLTypeBaseNode = NewSingleton(func() *graphql.Object {
|
|
||||||
object := graphql.NewObject(graphql.ObjectConfig{
|
|
||||||
Name: "SimpleNode",
|
|
||||||
Interfaces: []*graphql.Interface{
|
|
||||||
GQLInterfaceNode.Type,
|
|
||||||
},
|
|
||||||
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
|
|
||||||
_, ok := p.Value.(*Node)
|
|
||||||
return ok
|
|
||||||
},
|
|
||||||
Fields: graphql.Fields{},
|
|
||||||
})
|
|
||||||
|
|
||||||
AddNodeFields(object)
|
|
||||||
|
|
||||||
return object
|
|
||||||
}, nil)
|
|
||||||
|
|
||||||
var GQLTypeSignal = NewSingleton(func() *graphql.Object {
|
|
||||||
gql_type_signal := graphql.NewObject(graphql.ObjectConfig{
|
|
||||||
Name: "Signal",
|
|
||||||
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
|
|
||||||
_, ok := p.Value.(Signal)
|
|
||||||
return ok
|
|
||||||
},
|
|
||||||
Fields: graphql.Fields{},
|
|
||||||
})
|
|
||||||
|
|
||||||
gql_type_signal.AddFieldConfig("Type", &graphql.Field{
|
|
||||||
Type: graphql.String,
|
|
||||||
Resolve: GQLSignalType,
|
|
||||||
})
|
|
||||||
gql_type_signal.AddFieldConfig("Direction", &graphql.Field{
|
|
||||||
Type: graphql.String,
|
|
||||||
Resolve: GQLSignalDirection,
|
|
||||||
})
|
|
||||||
gql_type_signal.AddFieldConfig("String", &graphql.Field{
|
|
||||||
Type: graphql.String,
|
|
||||||
Resolve: GQLSignalString,
|
|
||||||
})
|
|
||||||
return gql_type_signal
|
|
||||||
}, nil)
|
|
||||||
|
|
@ -0,0 +1,66 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Listener extension provides a channel that can receive signals on a different thread
|
||||||
|
type ListenerExt struct {
|
||||||
|
Buffer int `gv:"buffer"`
|
||||||
|
Chan chan Signal
|
||||||
|
}
|
||||||
|
|
||||||
|
type LoadedSignal struct {
|
||||||
|
SignalHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLoadedSignal() *LoadedSignal {
|
||||||
|
return &LoadedSignal{
|
||||||
|
SignalHeader: NewSignalHeader(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type UnloadedSignal struct {
|
||||||
|
SignalHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUnloadedSignal() *UnloadedSignal {
|
||||||
|
return &UnloadedSignal{
|
||||||
|
SignalHeader: NewSignalHeader(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ListenerExt) Load(ctx *Context, node *Node) error {
|
||||||
|
ext.Chan = make(chan Signal, ext.Buffer)
|
||||||
|
ext.Chan <- NewLoadedSignal()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ListenerExt) Unload(ctx *Context, node *Node) {
|
||||||
|
ext.Chan <- NewUnloadedSignal()
|
||||||
|
close(ext.Chan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new listener extension with a given buffer size
|
||||||
|
func NewListenerExt(buffer int) *ListenerExt {
|
||||||
|
return &ListenerExt{
|
||||||
|
Buffer: buffer,
|
||||||
|
Chan: make(chan Signal, buffer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the signal to the channel, logging an overflow if it occurs
|
||||||
|
func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]Message, Changes) {
|
||||||
|
ctx.Log.Logf("listener", "%s - %+v", node.ID, reflect.TypeOf(signal))
|
||||||
|
ctx.Log.Logf("listener_debug", "%s->%s - %+v", source, node.ID, signal)
|
||||||
|
select {
|
||||||
|
case ext.Chan <- signal:
|
||||||
|
default:
|
||||||
|
ctx.Log.Logf("listener", "LISTENER_OVERFLOW: %s", node.ID)
|
||||||
|
}
|
||||||
|
switch sig := signal.(type) {
|
||||||
|
case *StatusSignal:
|
||||||
|
ctx.Log.Logf("listener_status", "%s - %+v", sig.Source, sig.Fields)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
@ -1,596 +1,411 @@
|
|||||||
package graphvent
|
package graphvent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"github.com/google/uuid"
|
||||||
"encoding/json"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ListenerExt struct {
|
type ReqState byte
|
||||||
Buffer int
|
const (
|
||||||
Chan chan Signal
|
Unlocked = ReqState(0)
|
||||||
}
|
Unlocking = ReqState(1)
|
||||||
|
Locked = ReqState(2)
|
||||||
func NewListenerExt(buffer int) *ListenerExt {
|
Locking = ReqState(3)
|
||||||
return &ListenerExt{
|
AbortingLock = ReqState(4)
|
||||||
Buffer: buffer,
|
)
|
||||||
Chan: make(chan Signal, buffer),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadListenerExt(ctx *Context, data []byte) (Extension, error) {
|
|
||||||
var j int
|
|
||||||
err := json.Unmarshal(data, &j)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewListenerExt(j), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const ListenerExtType = ExtType("LISTENER")
|
var ReqStateStrings = map[ReqState]string {
|
||||||
func (listener *ListenerExt) Type() ExtType {
|
Unlocked: "Unlocked",
|
||||||
return ListenerExtType
|
Unlocking: "Unlocking",
|
||||||
|
Locked: "Locked",
|
||||||
|
Locking: "Locking",
|
||||||
|
AbortingLock: "AbortingLock",
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ext *ListenerExt) Process(context *StateContext, node *Node, signal Signal) error {
|
func (state ReqState) String() string {
|
||||||
select {
|
str, mapped := ReqStateStrings[state]
|
||||||
case ext.Chan <- signal:
|
if mapped == false {
|
||||||
default:
|
return "UNKNOWN_REQSTATE"
|
||||||
return fmt.Errorf("LISTENER_OVERFLOW - %+v", signal)
|
} else {
|
||||||
|
return str
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ext *ListenerExt) Serialize() ([]byte, error) {
|
|
||||||
return json.MarshalIndent(ext.Buffer, "", " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
type LockableExt struct{
|
type LockableExt struct{
|
||||||
Owner *Node
|
State ReqState `gv:"state"`
|
||||||
Requirements map[NodeID]*Node
|
ReqID *uuid.UUID `gv:"req_id"`
|
||||||
Dependencies map[NodeID]*Node
|
Owner *NodeID `gv:"owner"`
|
||||||
LocksHeld map[NodeID]*Node
|
PendingOwner *NodeID `gv:"pending_owner"`
|
||||||
}
|
Requirements map[NodeID]ReqState `gv:"requirements" node:"Lockable:"`
|
||||||
|
|
||||||
const LockableExtType = ExtType("LOCKABLE")
|
Locked map[NodeID]any
|
||||||
func (ext *LockableExt) Type() ExtType {
|
Unlocked map[NodeID]any
|
||||||
return LockableExtType
|
|
||||||
}
|
|
||||||
|
|
||||||
type LockableExtJSON struct {
|
Waiting WaitMap `gv:"waiting_locks" node:":Lockable"`
|
||||||
Owner string `json:"owner"`
|
|
||||||
Requirements []string `json:"requirements"`
|
|
||||||
Dependencies []string `json:"dependencies"`
|
|
||||||
LocksHeld map[string]string `json:"locks_held"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ext *LockableExt) Serialize() ([]byte, error) {
|
func NewLockableExt(requirements []NodeID) *LockableExt {
|
||||||
return json.MarshalIndent(&LockableExtJSON{
|
var reqs map[NodeID]ReqState = nil
|
||||||
Owner: SaveNode(ext.Owner),
|
var unlocked map[NodeID]any = map[NodeID]any{}
|
||||||
Requirements: SaveNodeList(ext.Requirements),
|
|
||||||
Dependencies: SaveNodeList(ext.Dependencies),
|
|
||||||
LocksHeld: SaveNodeMap(ext.LocksHeld),
|
|
||||||
}, "", " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLockableExt(owner *Node, requirements NodeMap, dependencies NodeMap, locks_held NodeMap) *LockableExt {
|
|
||||||
if requirements == nil {
|
|
||||||
requirements = NodeMap{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if dependencies == nil {
|
if len(requirements) != 0 {
|
||||||
dependencies = NodeMap{}
|
reqs = map[NodeID]ReqState{}
|
||||||
|
for _, req := range(requirements) {
|
||||||
|
reqs[req] = Unlocked
|
||||||
|
unlocked[req] = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if locks_held == nil {
|
|
||||||
locks_held = NodeMap{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &LockableExt{
|
return &LockableExt{
|
||||||
Owner: owner,
|
State: Unlocked,
|
||||||
Requirements: requirements,
|
Owner: nil,
|
||||||
Dependencies: dependencies,
|
PendingOwner: nil,
|
||||||
LocksHeld: locks_held,
|
Requirements: reqs,
|
||||||
|
Waiting: WaitMap{},
|
||||||
|
|
||||||
|
Locked: map[NodeID]any{},
|
||||||
|
Unlocked: unlocked,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadLockableExt(ctx *Context, data []byte) (Extension, error) {
|
func UnlockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
|
||||||
var j LockableExtJSON
|
signal := NewUnlockSignal()
|
||||||
err := json.Unmarshal(data, &j)
|
messages := []Message{{node.ID, signal}}
|
||||||
if err != nil {
|
return signal.ID(), ctx.Send(node, messages)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.Log.Logf("db", "DB_LOADING_LOCKABLE_EXT_JSON: %+v", j)
|
func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
|
||||||
|
signal := NewLockSignal()
|
||||||
owner, err := RestoreNode(ctx, j.Owner)
|
messages := []Message{{node.ID, signal}}
|
||||||
if err != nil {
|
return signal.ID(), ctx.Send(node, messages)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
requirements, err := RestoreNodeList(ctx, j.Requirements)
|
func (ext *LockableExt) Load(ctx *Context, node *Node) error {
|
||||||
if err != nil {
|
ext.Locked = map[NodeID]any{}
|
||||||
return nil, err
|
ext.Unlocked = map[NodeID]any{}
|
||||||
}
|
|
||||||
|
|
||||||
dependencies, err := RestoreNodeList(ctx, j.Dependencies)
|
for id, state := range(ext.Requirements) {
|
||||||
if err != nil {
|
if state == Unlocked {
|
||||||
return nil, err
|
ext.Unlocked[id] = nil
|
||||||
|
} else if state == Locked {
|
||||||
|
ext.Locked[id] = nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
locks_held, err := RestoreNodeMap(ctx, j.LocksHeld)
|
return nil
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ext *LockableExt) Unload(ctx *Context, node *Node) {
|
||||||
return NewLockableExt(owner, requirements, dependencies, locks_held), nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ext *LockableExt) Process(context *StateContext, node *Node, signal Signal) error {
|
// Handle link signal by adding/removing the requested NodeID
|
||||||
context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID)
|
// returns an error if the node is not unlocked
|
||||||
|
func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID, signal *LinkSignal) ([]Message, Changes) {
|
||||||
|
var messages []Message = nil
|
||||||
|
var changes Changes = nil
|
||||||
|
|
||||||
var err error
|
switch ext.State {
|
||||||
switch signal.Direction() {
|
case Unlocked:
|
||||||
case Up:
|
switch signal.Action {
|
||||||
err = UseStates(context, node,
|
case "add":
|
||||||
NewACLInfo(node, []string{"dependencies", "owner"}), func(context *StateContext) error {
|
_, exists := ext.Requirements[signal.NodeID]
|
||||||
owner_sent := false
|
if exists == true {
|
||||||
for _, dependency := range(ext.Dependencies) {
|
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_requirement")})
|
||||||
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", node.ID, dependency.ID)
|
} else {
|
||||||
SendSignal(context, dependency, node, signal)
|
if ext.Requirements == nil {
|
||||||
if ext.Owner != nil {
|
ext.Requirements = map[NodeID]ReqState{}
|
||||||
if dependency.ID == ext.Owner.ID {
|
|
||||||
owner_sent = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if ext.Owner != nil && owner_sent == false {
|
|
||||||
if ext.Owner.ID != node.ID {
|
|
||||||
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", node.ID, ext.Owner.ID)
|
|
||||||
return SendSignal(context, ext.Owner, node, signal)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
ext.Requirements[signal.NodeID] = Unlocked
|
||||||
})
|
changes = append(changes, "requirements")
|
||||||
case Down:
|
messages = append(messages, Message{source, NewSuccessSignal(signal.ID())})
|
||||||
err = UseStates(context, node, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error {
|
|
||||||
for _, requirement := range(ext.Requirements) {
|
|
||||||
err := SendSignal(context, requirement, node, signal)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
case "remove":
|
||||||
|
_, exists := ext.Requirements[signal.NodeID]
|
||||||
|
if exists == false {
|
||||||
|
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")})
|
||||||
|
} else {
|
||||||
|
delete(ext.Requirements, signal.NodeID)
|
||||||
|
changes = append(changes, "requirements")
|
||||||
|
messages = append(messages, Message{source, NewSuccessSignal(signal.ID())})
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
})
|
|
||||||
case Direct:
|
|
||||||
err = nil
|
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("invalid signal direction %d", signal.Direction())
|
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "unknown_action")})
|
||||||
}
|
}
|
||||||
if err != nil {
|
default:
|
||||||
return err
|
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_unlocked: %s", ext.State)})
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ext *LockableExt) RecordUnlock(node *Node) *Node {
|
return messages, changes
|
||||||
last_owner, exists := ext.LocksHeld[node.ID]
|
|
||||||
if exists == false {
|
|
||||||
panic("Attempted to take a get the original lock holder of a lockable we don't own")
|
|
||||||
}
|
|
||||||
delete(ext.LocksHeld, node.ID)
|
|
||||||
return last_owner
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ext *LockableExt) RecordLock(node *Node, last_owner *Node) {
|
// Handle an UnlockSignal by either transitioning to Unlocked state,
|
||||||
_, exists := ext.LocksHeld[node.ID]
|
// sending unlock signals to requirements, or returning an error signal
|
||||||
if exists == true {
|
func (ext *LockableExt) HandleUnlockSignal(ctx *Context, node *Node, source NodeID, signal *UnlockSignal) ([]Message, Changes) {
|
||||||
panic("Attempted to lock a lockable we're already holding(lock cycle)")
|
var messages []Message = nil
|
||||||
}
|
var changes Changes = nil
|
||||||
ext.LocksHeld[node.ID] = last_owner
|
|
||||||
}
|
|
||||||
|
|
||||||
// Removes requirement as a requirement from lockable
|
switch ext.State {
|
||||||
func UnlinkLockables(context *StateContext, princ *Node, lockable *Node, requirement *Node) error {
|
case Locked:
|
||||||
lockable_ext, err := GetExt[*LockableExt](lockable)
|
if source != *ext.Owner {
|
||||||
if err != nil {
|
messages = append(messages, Message{source, NewErrorSignal(signal.Id, "not_owner")})
|
||||||
return err
|
} else {
|
||||||
}
|
if len(ext.Requirements) == 0 {
|
||||||
requirement_ext, err := GetExt[*LockableExt](requirement)
|
changes = append(changes, "state", "owner", "pending_owner")
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return UpdateStates(context, princ, ACLMap{
|
|
||||||
lockable.ID: ACLInfo{Node: lockable, Resources: []string{"requirements"}},
|
|
||||||
requirement.ID: ACLInfo{Node: requirement, Resources: []string{"dependencies"}},
|
|
||||||
}, func(context *StateContext) error {
|
|
||||||
var found *Node = nil
|
|
||||||
for _, req := range(lockable_ext.Requirements) {
|
|
||||||
if requirement.ID == req.ID {
|
|
||||||
found = req
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if found == nil {
|
ext.Owner = nil
|
||||||
return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID, lockable.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(requirement_ext.Dependencies, lockable.ID)
|
ext.PendingOwner = nil
|
||||||
delete(lockable_ext.Requirements, requirement.ID)
|
|
||||||
|
|
||||||
return nil
|
ext.State = Unlocked
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Link requirements as requirements to lockable
|
messages = append(messages, Message{source, NewSuccessSignal(signal.Id)})
|
||||||
func LinkLockables(context *StateContext, princ *Node, lockable *Node, requirements []*Node) error {
|
} else {
|
||||||
if lockable == nil {
|
changes = append(changes, "state", "waiting", "requirements", "pending_owner")
|
||||||
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(requirements) == 0 {
|
ext.PendingOwner = nil
|
||||||
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link no lockables in call")
|
|
||||||
}
|
|
||||||
|
|
||||||
lockable_ext, err := GetExt[*LockableExt](lockable)
|
ext.ReqID = &signal.Id
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
req_exts := map[NodeID]*LockableExt{}
|
ext.State = Unlocking
|
||||||
for _, requirement := range(requirements) {
|
for id := range(ext.Requirements) {
|
||||||
if requirement == nil {
|
unlock_signal := NewUnlockSignal()
|
||||||
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link nil to a Lockable as a requirement")
|
|
||||||
}
|
|
||||||
|
|
||||||
if lockable.ID == requirement.ID {
|
ext.Waiting[unlock_signal.Id] = id
|
||||||
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s to itself", lockable.ID)
|
ext.Requirements[id] = Unlocking
|
||||||
}
|
|
||||||
|
|
||||||
_, exists := req_exts[requirement.ID]
|
messages = append(messages, Message{id, unlock_signal})
|
||||||
if exists == true {
|
}
|
||||||
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s twice", requirement.ID)
|
|
||||||
}
|
}
|
||||||
ext, err := GetExt[*LockableExt](requirement)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
req_exts[requirement.ID] = ext
|
default:
|
||||||
|
messages = append(messages, Message{source, NewErrorSignal(signal.Id, "not_locked")})
|
||||||
}
|
}
|
||||||
|
|
||||||
return UpdateStates(context, princ, NewACLMap(
|
return messages, changes
|
||||||
NewACLInfo(lockable, []string{"requirements"}),
|
|
||||||
ACLList(requirements, []string{"dependencies"}),
|
|
||||||
), func(context *StateContext) error {
|
|
||||||
// Check that all the requirements can be added
|
|
||||||
// If the lockable is already locked, need to lock this resource as well before we can add it
|
|
||||||
for _, requirement := range(requirements) {
|
|
||||||
requirement_ext := req_exts[requirement.ID]
|
|
||||||
for _, req := range(requirements) {
|
|
||||||
if req.ID == requirement.ID {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
is_req, err := checkIfRequirement(context, req.ID, requirement_ext)
|
// Handle a LockSignal by either transitioning to a locked state,
|
||||||
if err != nil {
|
// sending lock signals to requirements, or returning an error signal
|
||||||
return err
|
func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) ([]Message, Changes) {
|
||||||
} else if is_req {
|
var messages []Message = nil
|
||||||
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot add the same dependency", req.ID, requirement.ID)
|
var changes Changes = nil
|
||||||
|
|
||||||
}
|
switch ext.State {
|
||||||
}
|
case Unlocked:
|
||||||
|
if len(ext.Requirements) == 0 {
|
||||||
|
changes = append(changes, "state", "owner", "pending_owner")
|
||||||
|
|
||||||
is_req, err := checkIfRequirement(context, lockable.ID, requirement_ext)
|
ext.Owner = &source
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
} else if is_req {
|
|
||||||
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID, lockable.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
is_req, err = checkIfRequirement(context, requirement.ID, lockable_ext)
|
ext.PendingOwner = &source
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
} else if is_req {
|
|
||||||
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID, requirement.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if lockable_ext.Owner == nil {
|
ext.State = Locked
|
||||||
// If the new owner isn't locked, we can add the requirement
|
messages = append(messages, Message{source, NewSuccessSignal(signal.Id)})
|
||||||
} else if requirement_ext.Owner == nil {
|
|
||||||
// if the new requirement isn't already locked but the owner is, the requirement needs to be locked first
|
|
||||||
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID, requirement.ID)
|
|
||||||
} else {
|
} else {
|
||||||
// If the new requirement is already locked and the owner is already locked, their owners need to match
|
changes = append(changes, "state", "requirements", "waiting", "pending_owner")
|
||||||
if requirement_ext.Owner.ID != lockable_ext.Owner.ID {
|
|
||||||
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID, lockable.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Update the states of the requirements
|
|
||||||
for _, requirement := range(requirements) {
|
|
||||||
requirement_ext := req_exts[requirement.ID]
|
|
||||||
requirement_ext.Dependencies[lockable.ID] = lockable
|
|
||||||
lockable_ext.Requirements[lockable.ID] = requirement
|
|
||||||
context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID, lockable.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return no error
|
ext.PendingOwner = &source
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkIfRequirement(context *StateContext, id NodeID, cur *LockableExt) (bool, error) {
|
ext.ReqID = &signal.Id
|
||||||
for _, req := range(cur.Requirements) {
|
|
||||||
if req.ID == id {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
req_ext, err := GetExt[*LockableExt](req)
|
ext.State = Locking
|
||||||
if err != nil {
|
for id := range(ext.Requirements) {
|
||||||
return false, err
|
lock_signal := NewLockSignal()
|
||||||
}
|
|
||||||
|
ext.Waiting[lock_signal.Id] = id
|
||||||
|
ext.Requirements[id] = Locking
|
||||||
|
|
||||||
var is_req bool
|
messages = append(messages, Message{id, lock_signal})
|
||||||
err = UpdateStates(context, req, NewACLInfo(req, []string{"requirements"}), func(context *StateContext) error {
|
|
||||||
is_req, err = checkIfRequirement(context, id, req_ext)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
}
|
||||||
if is_req == true {
|
|
||||||
return true, nil
|
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
messages = append(messages, Message{source, NewErrorSignal(signal.Id, "not_unlocked: %s", ext.State)})
|
||||||
}
|
}
|
||||||
|
|
||||||
return false, nil
|
return messages, changes
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error
|
// Handle an error signal by aborting the lock, or retrying the unlock
|
||||||
// Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call
|
func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) ([]Message, Changes) {
|
||||||
func LockLockables(context *StateContext, to_lock NodeMap, new_owner *Node) error {
|
var messages []Message = nil
|
||||||
if to_lock == nil {
|
var changes Changes = nil
|
||||||
return fmt.Errorf("LOCKABLE_LOCK_ERR: no map provided")
|
|
||||||
}
|
|
||||||
|
|
||||||
req_exts := map[NodeID]*LockableExt{}
|
id, waiting := ext.Waiting[signal.ReqID]
|
||||||
for _, l := range(to_lock) {
|
if waiting == true {
|
||||||
var err error
|
delete(ext.Waiting, signal.ReqID)
|
||||||
if l == nil {
|
changes = append(changes, "waiting")
|
||||||
return fmt.Errorf("LOCKABLE_LOCK_ERR: Can not lock nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
req_exts[l.ID], err = GetExt[*LockableExt](l)
|
switch ext.State {
|
||||||
if err != nil {
|
case Locking:
|
||||||
return err
|
changes = append(changes, "state", "requirements")
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if new_owner == nil {
|
ext.Requirements[id] = Unlocked
|
||||||
return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks")
|
|
||||||
}
|
|
||||||
|
|
||||||
new_owner_ext, err := GetExt[*LockableExt](new_owner)
|
unlocked := 0
|
||||||
if err != nil {
|
for req_id, req_state := range(ext.Requirements) {
|
||||||
return err
|
// Unlock locked requirements, and count unlocked requirements
|
||||||
}
|
switch req_state {
|
||||||
|
case Locked:
|
||||||
|
unlock_signal := NewUnlockSignal()
|
||||||
|
|
||||||
// Called with no requirements to lock, success
|
ext.Waiting[unlock_signal.Id] = req_id
|
||||||
if len(to_lock) == 0 {
|
ext.Requirements[req_id] = Unlocking
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return UpdateStates(context, new_owner, NewACLMap(
|
messages = append(messages, Message{req_id, unlock_signal})
|
||||||
ACLListM(to_lock, []string{"lock"}),
|
case Unlocked:
|
||||||
NewACLInfo(new_owner, nil),
|
unlocked += 1
|
||||||
), func(context *StateContext) error {
|
|
||||||
// First loop is to check that the states can be locked, and locks all requirements
|
|
||||||
for _, req := range(to_lock) {
|
|
||||||
req_ext := req_exts[req.ID]
|
|
||||||
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID, new_owner.ID)
|
|
||||||
|
|
||||||
// If req is alreay locked, check that we can pass the lock
|
|
||||||
if req_ext.Owner != nil {
|
|
||||||
owner := req_ext.Owner
|
|
||||||
if owner.ID == new_owner.ID {
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
err := UpdateStates(context, new_owner, NewACLInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){
|
|
||||||
return LockLockables(context, req_ext.Requirements, req)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if unlocked == len(ext.Requirements) {
|
||||||
|
changes = append(changes, "owner", "state")
|
||||||
|
ext.State = Unlocked
|
||||||
|
ext.Owner = nil
|
||||||
} else {
|
} else {
|
||||||
err := LockLockables(context, req_ext.Requirements, req)
|
changes = append(changes, "state")
|
||||||
if err != nil {
|
ext.State = AbortingLock
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// At this point state modification will be started, so no errors can be returned
|
|
||||||
for _, req := range(to_lock) {
|
|
||||||
req_ext := req_exts[req.ID]
|
|
||||||
old_owner := req_ext.Owner
|
|
||||||
// If the lockable was previously unowned, update the state
|
|
||||||
if old_owner == nil {
|
|
||||||
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID, req.ID)
|
|
||||||
req_ext.Owner = new_owner
|
|
||||||
new_owner_ext.RecordLock(req, old_owner)
|
|
||||||
// Otherwise if the new owner already owns it, no need to update state
|
|
||||||
} else if old_owner.ID == new_owner.ID {
|
|
||||||
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID, req.ID)
|
|
||||||
// Otherwise update the state
|
|
||||||
} else {
|
|
||||||
req_ext.Owner = new_owner
|
|
||||||
new_owner_ext.RecordLock(req, old_owner)
|
|
||||||
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID, req.ID, old_owner.ID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
case Unlocking:
|
||||||
|
unlock_signal := NewUnlockSignal()
|
||||||
|
ext.Waiting[unlock_signal.Id] = id
|
||||||
|
messages = append(messages, Message{id, unlock_signal})
|
||||||
|
|
||||||
func UnlockLockables(context *StateContext, to_unlock NodeMap, old_owner *Node) error {
|
case AbortingLock:
|
||||||
if to_unlock == nil {
|
req_state := ext.Requirements[id]
|
||||||
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided")
|
// Mark failed lock as Unlocked, or retry unlock
|
||||||
}
|
switch req_state {
|
||||||
|
case Locking:
|
||||||
|
ext.Requirements[id] = Unlocked
|
||||||
|
|
||||||
req_exts := map[NodeID]*LockableExt{}
|
// Check if all requirements unlocked now
|
||||||
for _, l := range(to_unlock) {
|
unlocked := 0
|
||||||
if l == nil {
|
for _, req_state := range(ext.Requirements) {
|
||||||
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: Can not unlock nil")
|
if req_state == Unlocked {
|
||||||
|
unlocked += 1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
if unlocked == len(ext.Requirements) {
|
||||||
req_exts[l.ID], err = GetExt[*LockableExt](l)
|
changes = append(changes, "owner", "state")
|
||||||
if err != nil {
|
ext.State = Unlocked
|
||||||
return err
|
ext.Owner = nil
|
||||||
}
|
}
|
||||||
|
case Unlocking:
|
||||||
|
// Handle error for unlocking requirement while unlocking by retrying unlock
|
||||||
|
unlock_signal := NewUnlockSignal()
|
||||||
|
ext.Waiting[unlock_signal.Id] = id
|
||||||
|
messages = append(messages, Message{id, unlock_signal})
|
||||||
}
|
}
|
||||||
|
|
||||||
if old_owner == nil {
|
|
||||||
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
old_owner_ext, err := GetExt[*LockableExt](old_owner)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return messages, changes
|
||||||
// Called with no requirements to unlock, success
|
|
||||||
if len(to_unlock) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return UpdateStates(context, old_owner, NewACLMap(
|
// Handle a success signal by checking if all requirements have been locked/unlocked
|
||||||
ACLListM(to_unlock, []string{"lock"}),
|
func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source NodeID, signal *SuccessSignal) ([]Message, Changes) {
|
||||||
NewACLInfo(old_owner, nil),
|
var messages []Message = nil
|
||||||
), func(context *StateContext) error {
|
var changes Changes = nil
|
||||||
// First loop is to check that the states can be locked, and locks all requirements
|
|
||||||
for _, req := range(to_unlock) {
|
|
||||||
req_ext := req_exts[req.ID]
|
|
||||||
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID, old_owner.ID)
|
|
||||||
|
|
||||||
// Check if the owner is correct
|
id, waiting := ext.Waiting[signal.ReqID]
|
||||||
if req_ext.Owner != nil {
|
if waiting == true {
|
||||||
if req_ext.Owner.ID != old_owner.ID {
|
delete(ext.Waiting, signal.ReqID)
|
||||||
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID, old_owner.ID)
|
changes = append(changes, "waiting")
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := UnlockLockables(context, req_ext.Requirements, req)
|
switch ext.State {
|
||||||
if err != nil {
|
case Locking:
|
||||||
return err
|
ext.Requirements[id] = Locked
|
||||||
}
|
ext.Locked[id] = nil
|
||||||
}
|
delete(ext.Unlocked, id)
|
||||||
|
|
||||||
// At this point state modification will be started, so no errors can be returned
|
if len(ext.Locked) == len(ext.Requirements) {
|
||||||
for _, req := range(to_unlock) {
|
ctx.Log.Logf("lockable", "%s FULL_LOCK: %d", node.ID, len(ext.Locked))
|
||||||
req_ext := req_exts[req.ID]
|
changes = append(changes, "state", "owner", "req_id")
|
||||||
new_owner := old_owner_ext.RecordUnlock(req)
|
ext.State = Locked
|
||||||
req_ext.Owner = new_owner
|
|
||||||
if new_owner == nil {
|
|
||||||
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID, req.ID)
|
|
||||||
} else {
|
|
||||||
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID, req.ID, new_owner.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
ext.Owner = ext.PendingOwner
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func SaveNode(node *Node) string {
|
messages = append(messages, Message{*ext.Owner, NewSuccessSignal(*ext.ReqID)})
|
||||||
str := ""
|
ext.ReqID = nil
|
||||||
if node != nil {
|
} else {
|
||||||
str = node.ID.String()
|
ctx.Log.Logf("lockable", "%s PARTIAL_LOCK: %d/%d", node.ID, len(ext.Locked), len(ext.Requirements))
|
||||||
}
|
|
||||||
return str
|
|
||||||
}
|
}
|
||||||
|
case AbortingLock:
|
||||||
|
req_state := ext.Requirements[id]
|
||||||
|
switch req_state {
|
||||||
|
case Locking:
|
||||||
|
ext.Requirements[id] = Unlocking
|
||||||
|
unlock_signal := NewUnlockSignal()
|
||||||
|
ext.Waiting[unlock_signal.Id] = id
|
||||||
|
messages = append(messages, Message{id, unlock_signal})
|
||||||
|
case Unlocking:
|
||||||
|
ext.Requirements[id] = Unlocked
|
||||||
|
ext.Unlocked[id] = nil
|
||||||
|
delete(ext.Locked, id)
|
||||||
|
|
||||||
func RestoreNode(ctx *Context, id_str string) (*Node, error) {
|
unlocked := 0
|
||||||
if id_str == "" {
|
for _, req_state := range(ext.Requirements) {
|
||||||
return nil, nil
|
switch req_state {
|
||||||
|
case Unlocked:
|
||||||
|
unlocked += 1
|
||||||
}
|
}
|
||||||
id, err := ParseID(id_str)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return LoadNode(ctx, id)
|
if unlocked == len(ext.Requirements) {
|
||||||
}
|
changes = append(changes, "state", "pending_owner", "req_id")
|
||||||
|
|
||||||
func SaveNodeMap(nodes NodeMap) map[string]string {
|
messages = append(messages, Message{*ext.PendingOwner, NewErrorSignal(*ext.ReqID, "not_unlocked: %s", ext.State)})
|
||||||
m := map[string]string{}
|
ext.State = Unlocked
|
||||||
for id, node := range(nodes) {
|
ext.ReqID = nil
|
||||||
m[id.String()] = SaveNode(node)
|
ext.PendingOwner = nil
|
||||||
}
|
}
|
||||||
return m
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func RestoreNodeMap(ctx *Context, ids map[string]string) (NodeMap, error) {
|
|
||||||
nodes := NodeMap{}
|
|
||||||
for id_str_1, id_str_2 := range(ids) {
|
|
||||||
id_1, err := ParseID(id_str_1)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
node_1, err := LoadNode(ctx, id_1)
|
case Unlocking:
|
||||||
if err != nil {
|
ext.Requirements[id] = Unlocked
|
||||||
return nil, err
|
ext.Unlocked[id] = Unlocked
|
||||||
}
|
delete(ext.Locked, id)
|
||||||
|
|
||||||
|
if len(ext.Unlocked) == len(ext.Requirements) {
|
||||||
|
changes = append(changes, "state", "owner", "req_id")
|
||||||
|
|
||||||
var node_2 *Node = nil
|
messages = append(messages, Message{*ext.Owner, NewSuccessSignal(*ext.ReqID)})
|
||||||
if id_str_2 != "" {
|
ext.State = Unlocked
|
||||||
id_2, err := ParseID(id_str_2)
|
ext.ReqID = nil
|
||||||
if err != nil {
|
ext.Owner = nil
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
node_2, err = LoadNode(ctx, id_2)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
nodes[node_1.ID] = node_2
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nodes, nil
|
return messages, changes
|
||||||
}
|
}
|
||||||
|
|
||||||
func SaveNodeList(nodes NodeMap) []string {
|
func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]Message, Changes) {
|
||||||
ids := make([]string, len(nodes))
|
var messages []Message = nil
|
||||||
i := 0
|
var changes Changes = nil
|
||||||
for id, _ := range(nodes) {
|
|
||||||
ids[i] = id.String()
|
|
||||||
i += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
return ids
|
switch sig := signal.(type) {
|
||||||
|
case *StatusSignal:
|
||||||
|
// Forward StatusSignals up to the owner(unless that would be a cycle)
|
||||||
|
if ext.Owner != nil {
|
||||||
|
if *ext.Owner != node.ID {
|
||||||
|
messages = append(messages, Message{*ext.Owner, signal})
|
||||||
}
|
}
|
||||||
|
|
||||||
func RestoreNodeList(ctx *Context, ids []string) (NodeMap, error) {
|
|
||||||
nodes := NodeMap{}
|
|
||||||
|
|
||||||
for _, id_str := range(ids) {
|
|
||||||
node, err := RestoreNode(ctx, id_str)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
nodes[node.ID] = node
|
case *LinkSignal:
|
||||||
|
messages, changes = ext.HandleLinkSignal(ctx, node, source, sig)
|
||||||
|
case *LockSignal:
|
||||||
|
messages, changes = ext.HandleLockSignal(ctx, node, source, sig)
|
||||||
|
case *UnlockSignal:
|
||||||
|
messages, changes = ext.HandleUnlockSignal(ctx, node, source, sig)
|
||||||
|
case *ErrorSignal:
|
||||||
|
messages, changes = ext.HandleErrorSignal(ctx, node, source, sig)
|
||||||
|
case *SuccessSignal:
|
||||||
|
messages, changes = ext.HandleSuccessSignal(ctx, node, source, sig)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nodes, nil
|
return messages, changes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,148 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLink(t *testing.T) {
|
||||||
|
ctx := logTestContext(t, []string{"lockable", "listener"})
|
||||||
|
|
||||||
|
|
||||||
|
l2_listener := NewListenerExt(10)
|
||||||
|
l2, err := ctx.NewNode(nil, "LockableNode", l2_listener, NewLockableExt(nil))
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
l1_lockable := NewLockableExt(nil)
|
||||||
|
l1_listener := NewListenerExt(10)
|
||||||
|
l1, err := ctx.NewNode(nil, "LockableNode", l1_listener, l1_lockable)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
link_signal := NewLinkSignal("add", l2.ID)
|
||||||
|
msgs := []Message{{l1.ID, link_signal}}
|
||||||
|
err = ctx.Send(l1, msgs)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
_, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, link_signal.ID())
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
state, exists := l1_lockable.Requirements[l2.ID]
|
||||||
|
if exists == false {
|
||||||
|
t.Fatal("l2 not in l1 requirements")
|
||||||
|
} else if state != Unlocked {
|
||||||
|
t.Fatalf("l2 in bad requirement state in l1: %+v", state)
|
||||||
|
}
|
||||||
|
|
||||||
|
unlink_signal := NewLinkSignal("remove", l2.ID)
|
||||||
|
msgs = []Message{{l1.ID, unlink_signal}}
|
||||||
|
err = ctx.Send(l1, msgs)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
_, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, unlink_signal.ID())
|
||||||
|
fatalErr(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test10Lock(t *testing.T) {
|
||||||
|
testLockN(t, 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test100Lock(t *testing.T) {
|
||||||
|
testLockN(t, 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test1000Lock(t *testing.T) {
|
||||||
|
testLockN(t, 1000)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test10000Lock(t *testing.T) {
|
||||||
|
testLockN(t, 10000)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testLockN(t *testing.T, n int) {
|
||||||
|
ctx := logTestContext(t, []string{"test"})
|
||||||
|
|
||||||
|
NewLockable := func()(*Node) {
|
||||||
|
l, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil))
|
||||||
|
fatalErr(t, err)
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
reqs := make([]NodeID, n)
|
||||||
|
for i := range(reqs) {
|
||||||
|
new_lockable := NewLockable()
|
||||||
|
reqs[i] = new_lockable.ID
|
||||||
|
}
|
||||||
|
ctx.Log.Logf("test", "CREATED_%d", n)
|
||||||
|
|
||||||
|
listener := NewListenerExt(50000)
|
||||||
|
node, err := ctx.NewNode(nil, "LockableNode", listener, NewLockableExt(reqs))
|
||||||
|
fatalErr(t, err)
|
||||||
|
ctx.Log.Logf("test", "CREATED_LISTENER")
|
||||||
|
|
||||||
|
lock_id, err := LockLockable(ctx, node)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
response, _, err := WaitForResponse(listener.Chan, time.Second*60, lock_id)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
switch resp := response.(type) {
|
||||||
|
case *SuccessSignal:
|
||||||
|
default:
|
||||||
|
t.Fatalf("Unexpected response to lock - %s", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "LOCKED_%d", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLock(t *testing.T) {
|
||||||
|
ctx := logTestContext(t, []string{"test", "lockable"})
|
||||||
|
|
||||||
|
NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) {
|
||||||
|
listener := NewListenerExt(10000)
|
||||||
|
l, err := ctx.NewNode(nil, "LockableNode", listener, NewLockableExt(reqs))
|
||||||
|
fatalErr(t, err)
|
||||||
|
return l, listener
|
||||||
|
}
|
||||||
|
|
||||||
|
l2, _ := NewLockable(nil)
|
||||||
|
l3, _ := NewLockable(nil)
|
||||||
|
l4, _ := NewLockable(nil)
|
||||||
|
l5, _ := NewLockable(nil)
|
||||||
|
l0, l0_listener := NewLockable([]NodeID{l5.ID})
|
||||||
|
l1, l1_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID})
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "l0: %s", l0.ID)
|
||||||
|
ctx.Log.Logf("test", "l1: %s", l1.ID)
|
||||||
|
ctx.Log.Logf("test", "l2: %s", l2.ID)
|
||||||
|
ctx.Log.Logf("test", "l3: %s", l3.ID)
|
||||||
|
ctx.Log.Logf("test", "l4: %s", l4.ID)
|
||||||
|
ctx.Log.Logf("test", "l5: %s", l5.ID)
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "locking l0")
|
||||||
|
id_1, err := LockLockable(ctx, l0)
|
||||||
|
fatalErr(t, err)
|
||||||
|
response, _, err := WaitForResponse(l0_listener.Chan, time.Millisecond*10, id_1)
|
||||||
|
fatalErr(t, err)
|
||||||
|
ctx.Log.Logf("test", "l0 lock: %+v", response)
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "locking l1")
|
||||||
|
id_2, err := LockLockable(ctx, l1)
|
||||||
|
fatalErr(t, err)
|
||||||
|
response, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10000, id_2)
|
||||||
|
fatalErr(t, err)
|
||||||
|
ctx.Log.Logf("test", "l1 lock: %+v", response)
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "unlocking l0")
|
||||||
|
id_3, err := UnlockLockable(ctx, l0)
|
||||||
|
fatalErr(t, err)
|
||||||
|
response, _, err = WaitForResponse(l0_listener.Chan, time.Millisecond*10, id_3)
|
||||||
|
fatalErr(t, err)
|
||||||
|
ctx.Log.Logf("test", "l0 unlock: %+v", response)
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "locking l1")
|
||||||
|
id_4, err := LockLockable(ctx, l1)
|
||||||
|
fatalErr(t, err)
|
||||||
|
response, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, id_4)
|
||||||
|
fatalErr(t, err)
|
||||||
|
ctx.Log.Logf("test", "l1 lock: %+v", response)
|
||||||
|
}
|
@ -0,0 +1,68 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Node NodeID
|
||||||
|
Signal Signal
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageQueue struct {
|
||||||
|
out chan<- Message
|
||||||
|
in <-chan Message
|
||||||
|
buffer []Message
|
||||||
|
write_cursor int
|
||||||
|
read_cursor int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (queue *MessageQueue) ProcessIncoming(message Message) {
|
||||||
|
if (queue.write_cursor + 1) == queue.read_cursor || ((queue.write_cursor + 1) == len(queue.buffer) && queue.read_cursor == 0) {
|
||||||
|
new_buffer := make([]Message, len(queue.buffer) * 2)
|
||||||
|
|
||||||
|
copy(new_buffer, queue.buffer[queue.read_cursor:])
|
||||||
|
first_chunk := len(queue.buffer) - queue.read_cursor
|
||||||
|
copy(new_buffer[first_chunk:], queue.buffer[0:queue.write_cursor])
|
||||||
|
|
||||||
|
queue.write_cursor = len(queue.buffer) - 1
|
||||||
|
queue.read_cursor = 0
|
||||||
|
queue.buffer = new_buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
queue.buffer[queue.write_cursor] = message
|
||||||
|
queue.write_cursor += 1
|
||||||
|
if queue.write_cursor >= len(queue.buffer) {
|
||||||
|
queue.write_cursor = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMessageQueue(initial int) (chan<- Message, <-chan Message) {
|
||||||
|
in := make(chan Message, 0)
|
||||||
|
out := make(chan Message, 0)
|
||||||
|
|
||||||
|
queue := MessageQueue{
|
||||||
|
out: out,
|
||||||
|
in: in,
|
||||||
|
buffer: make([]Message, initial),
|
||||||
|
write_cursor: 0,
|
||||||
|
read_cursor: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func(queue *MessageQueue) {
|
||||||
|
for true {
|
||||||
|
if queue.write_cursor != queue.read_cursor {
|
||||||
|
select {
|
||||||
|
case incoming := <-queue.in:
|
||||||
|
queue.ProcessIncoming(incoming)
|
||||||
|
case queue.out <- queue.buffer[queue.read_cursor]:
|
||||||
|
queue.read_cursor += 1
|
||||||
|
if queue.read_cursor >= len(queue.buffer) {
|
||||||
|
queue.read_cursor = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
message := <-queue.in
|
||||||
|
queue.ProcessIncoming(message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(&queue)
|
||||||
|
|
||||||
|
return in, out
|
||||||
|
}
|
@ -0,0 +1,35 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func sendBatch(start, end uint64, in chan<- Message) {
|
||||||
|
for i := start; i <= end; i++ {
|
||||||
|
var id NodeID
|
||||||
|
binary.BigEndian.PutUint64(id[:], i)
|
||||||
|
in <- Message{id, nil}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageQueue(t *testing.T) {
|
||||||
|
in, out := NewMessageQueue(10)
|
||||||
|
|
||||||
|
for i := uint64(0); i < 1000; i++ {
|
||||||
|
go sendBatch(1000*i, (1000*(i+1))-1, in)
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := map[NodeID]any{}
|
||||||
|
for i := uint64(0); i < 1000*1000; i++ {
|
||||||
|
read := <-out
|
||||||
|
_, already_seen := seen[read.Node]
|
||||||
|
if already_seen {
|
||||||
|
t.Fatalf("Signal %d had duplicate NodeID %s", i, read.Node)
|
||||||
|
} else {
|
||||||
|
seen[read.Node] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Processed 1M signals through queue")
|
||||||
|
}
|
@ -1,396 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Policy interface {
|
|
||||||
Serializable[PolicyType]
|
|
||||||
Allows(context *StateContext, principal *Node, action string, node *Node) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
const RequirementOfPolicyType = PolicyType("REQUIREMENT_OF")
|
|
||||||
type RequirementOfPolicy struct {
|
|
||||||
PerNodePolicy
|
|
||||||
}
|
|
||||||
func (policy *RequirementOfPolicy) Type() PolicyType {
|
|
||||||
return RequirementOfPolicyType
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRequirementOfPolicy(nodes NodeActions) RequirementOfPolicy {
|
|
||||||
return RequirementOfPolicy{
|
|
||||||
PerNodePolicy: NewPerNodePolicy(nodes),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if any of principals dependencies are in the policy
|
|
||||||
func (policy *RequirementOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
|
|
||||||
lockable_ext, err := GetExt[*LockableExt](principal)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for dep_id, _ := range(lockable_ext.Dependencies) {
|
|
||||||
for node_id, actions := range(policy.NodeActions) {
|
|
||||||
if node_id == dep_id {
|
|
||||||
if actions.Allows(action) == true {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
const ChildOfPolicyType = PolicyType("CHILD_OF")
|
|
||||||
type ChildOfPolicy struct {
|
|
||||||
PerNodePolicy
|
|
||||||
}
|
|
||||||
func (policy *ChildOfPolicy) Type() PolicyType {
|
|
||||||
return ChildOfPolicyType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (policy *ChildOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
|
|
||||||
context.Graph.Log.Logf("policy", "CHILD_OF_POLICY: %+v", policy)
|
|
||||||
thread_ext, err := GetExt[*ThreadExt](principal)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
parent := thread_ext.Parent
|
|
||||||
if parent != nil {
|
|
||||||
actions, exists := policy.NodeActions[parent.ID]
|
|
||||||
if exists == false {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for _, a := range(actions) {
|
|
||||||
if a == action {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
type Actions []string
|
|
||||||
|
|
||||||
func (actions Actions) Allows(action string) bool {
|
|
||||||
for _, a := range(actions) {
|
|
||||||
if a == action {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
type NodeActions map[NodeID]Actions
|
|
||||||
|
|
||||||
func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context, []byte)(Policy, error) {
|
|
||||||
return func(ctx *Context, data []byte)(Policy, error){
|
|
||||||
var j PerNodePolicyJSON
|
|
||||||
err := json.Unmarshal(data, &j)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
node_actions := NodeActions{}
|
|
||||||
for id_str, actions := range(j.NodeActions) {
|
|
||||||
id, err := ParseID(id_str)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = LoadNode(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
node_actions[id] = actions
|
|
||||||
}
|
|
||||||
|
|
||||||
return init_fn(node_actions)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewChildOfPolicy(node_actions NodeActions) ChildOfPolicy {
|
|
||||||
return ChildOfPolicy{
|
|
||||||
PerNodePolicy: NewPerNodePolicy(node_actions),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const ParentOfPolicyType = PolicyType("PARENT_OF")
|
|
||||||
type ParentOfPolicy struct {
|
|
||||||
PerNodePolicy
|
|
||||||
}
|
|
||||||
func (policy *ParentOfPolicy) Type() PolicyType {
|
|
||||||
return ParentOfPolicyType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (policy *ParentOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
|
|
||||||
context.Graph.Log.Logf("policy", "PARENT_OF_POLICY: %+v", policy)
|
|
||||||
for id, actions := range(policy.NodeActions) {
|
|
||||||
thread_ext, err := GetExt[*ThreadExt](context.Graph.Nodes[id])
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
context.Graph.Log.Logf("policy", "PARENT_OF_PARENT: %s %+v", id, thread_ext.Parent)
|
|
||||||
if thread_ext.Parent != nil {
|
|
||||||
if thread_ext.Parent.ID == principal.ID {
|
|
||||||
for _, a := range(actions) {
|
|
||||||
if a == action {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewParentOfPolicy(node_actions NodeActions) ParentOfPolicy {
|
|
||||||
return ParentOfPolicy{
|
|
||||||
PerNodePolicy: NewPerNodePolicy(node_actions),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy {
|
|
||||||
if node_actions == nil {
|
|
||||||
node_actions = NodeActions{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return PerNodePolicy{
|
|
||||||
NodeActions: node_actions,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type PerNodePolicy struct {
|
|
||||||
NodeActions NodeActions
|
|
||||||
}
|
|
||||||
|
|
||||||
type PerNodePolicyJSON struct {
|
|
||||||
NodeActions map[string][]string `json:"node_actions"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const PerNodePolicyType = PolicyType("PER_NODE")
|
|
||||||
func (policy *PerNodePolicy) Type() PolicyType {
|
|
||||||
return PerNodePolicyType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (policy *PerNodePolicy) Serialize() ([]byte, error) {
|
|
||||||
node_actions := map[string][]string{}
|
|
||||||
for id, actions := range(policy.NodeActions) {
|
|
||||||
node_actions[id.String()] = actions
|
|
||||||
}
|
|
||||||
|
|
||||||
return json.MarshalIndent(&PerNodePolicyJSON{
|
|
||||||
NodeActions: node_actions,
|
|
||||||
}, "", " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (policy *PerNodePolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
|
|
||||||
for id, actions := range(policy.NodeActions) {
|
|
||||||
if id != principal.ID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, a := range(actions) {
|
|
||||||
if a == action {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// Extension to allow a node to hold ACL policies
|
|
||||||
type ACLPolicyExt struct {
|
|
||||||
Policies map[PolicyType]Policy
|
|
||||||
}
|
|
||||||
|
|
||||||
// The ACL extension stores a map of nodes to delegate ACL to, and a list of policies
|
|
||||||
type ACLExt struct {
|
|
||||||
Delegations NodeMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ACLExt) Process(context *StateContext, node *Node, signal Signal) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadACLExt(ctx *Context, data []byte) (Extension, error) {
|
|
||||||
var j struct {
|
|
||||||
Delegations []string `json:"delegation"`
|
|
||||||
}
|
|
||||||
|
|
||||||
err := json.Unmarshal(data, &j)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
delegations, err := RestoreNodeList(ctx, j.Delegations)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ACLExt{
|
|
||||||
Delegations: delegations,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NodeList(nodes ...*Node) NodeMap {
|
|
||||||
m := NodeMap{}
|
|
||||||
for _, node := range(nodes) {
|
|
||||||
m[node.ID] = node
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewACLExt(delegations NodeMap) *ACLExt {
|
|
||||||
if delegations == nil {
|
|
||||||
delegations = NodeMap{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ACLExt{
|
|
||||||
Delegations: delegations,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ACLExt) Serialize() ([]byte, error) {
|
|
||||||
delegations := make([]string, len(ext.Delegations))
|
|
||||||
i := 0
|
|
||||||
for id, _ := range(ext.Delegations) {
|
|
||||||
delegations[i] = id.String()
|
|
||||||
i += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
return json.MarshalIndent(&struct{
|
|
||||||
Delegations []string `json:"delegations"`
|
|
||||||
}{
|
|
||||||
Delegations: delegations,
|
|
||||||
}, "", " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
const ACLExtType = ExtType("ACL")
|
|
||||||
func (ext *ACLExt) Type() ExtType {
|
|
||||||
return ACLExtType
|
|
||||||
}
|
|
||||||
|
|
||||||
type PolicyLoadFunc func(*Context, []byte) (Policy, error)
|
|
||||||
type PolicyInfo struct {
|
|
||||||
Load PolicyLoadFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
type ACLPolicyExtContext struct {
|
|
||||||
Types map[PolicyType]PolicyInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewACLPolicyExtContext() *ACLPolicyExtContext {
|
|
||||||
return &ACLPolicyExtContext{
|
|
||||||
Types: map[PolicyType]PolicyInfo{
|
|
||||||
PerNodePolicyType: PolicyInfo{
|
|
||||||
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
|
|
||||||
policy := NewPerNodePolicy(nodes)
|
|
||||||
return &policy, nil
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
ParentOfPolicyType: PolicyInfo{
|
|
||||||
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
|
|
||||||
policy := NewParentOfPolicy(nodes)
|
|
||||||
return &policy, nil
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
ChildOfPolicyType: PolicyInfo{
|
|
||||||
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
|
|
||||||
policy := NewChildOfPolicy(nodes)
|
|
||||||
return &policy, nil
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
RequirementOfPolicyType: PolicyInfo{
|
|
||||||
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
|
|
||||||
policy := NewRequirementOfPolicy(nodes)
|
|
||||||
return &policy, nil
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ACLPolicyExt) Serialize() ([]byte, error) {
|
|
||||||
policies := map[string][]byte{}
|
|
||||||
for name, policy := range(ext.Policies) {
|
|
||||||
ser, err := policy.Serialize()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
policies[string(name)] = ser
|
|
||||||
}
|
|
||||||
|
|
||||||
return json.MarshalIndent(&struct{
|
|
||||||
Policies map[string][]byte `json:"policies"`
|
|
||||||
}{
|
|
||||||
Policies: policies,
|
|
||||||
}, "", " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ACLPolicyExt) Process(context *StateContext, node *Node, signal Signal) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewACLPolicyExt(policies map[PolicyType]Policy) *ACLPolicyExt {
|
|
||||||
if policies == nil {
|
|
||||||
policies = map[PolicyType]Policy{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ACLPolicyExt{
|
|
||||||
Policies: policies,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadACLPolicyExt(ctx *Context, data []byte) (Extension, error) {
|
|
||||||
var j struct {
|
|
||||||
Policies map[string][]byte `json:"policies"`
|
|
||||||
}
|
|
||||||
err := json.Unmarshal(data, &j)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
policies := map[PolicyType]Policy{}
|
|
||||||
acl_ctx := ctx.ExtByType(ACLPolicyExtType).Data.(*ACLPolicyExtContext)
|
|
||||||
for name, ser := range(j.Policies) {
|
|
||||||
policy_def, exists := acl_ctx.Types[PolicyType(name)]
|
|
||||||
if exists == false {
|
|
||||||
return nil, fmt.Errorf("%s is not a known policy type", name)
|
|
||||||
}
|
|
||||||
policy, err := policy_def.Load(ctx, ser)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
policies[PolicyType(name)] = policy
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewACLPolicyExt(policies), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const ACLPolicyExtType = ExtType("ACL_POLICIES")
|
|
||||||
func (ext *ACLPolicyExt) Type() ExtType {
|
|
||||||
return ACLPolicyExtType
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the extension allows the principal to perform action on node
|
|
||||||
func (ext *ACLPolicyExt) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
|
|
||||||
context.Graph.Log.Logf("policy", "POLICY_EXT_ALLOWED: %+v", ext)
|
|
||||||
for _, policy := range(ext.Policies) {
|
|
||||||
context.Graph.Log.Logf("policy", "POLICY_CHECK_POLICY: %+v", policy)
|
|
||||||
if policy.Allows(context, principal, action, node) == true {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
@ -0,0 +1,744 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha512"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"math"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SerializedType uint64
|
||||||
|
|
||||||
|
func (t SerializedType) String() string {
|
||||||
|
return fmt.Sprintf("0x%x", uint64(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExtType SerializedType
|
||||||
|
|
||||||
|
func (t ExtType) String() string {
|
||||||
|
return fmt.Sprintf("0x%x", uint64(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
type NodeType SerializedType
|
||||||
|
|
||||||
|
func (t NodeType) String() string {
|
||||||
|
return fmt.Sprintf("0x%x", uint64(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
type SignalType SerializedType
|
||||||
|
|
||||||
|
func (t SignalType) String() string {
|
||||||
|
return fmt.Sprintf("0x%x", uint64(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
type FieldTag SerializedType
|
||||||
|
|
||||||
|
func (t FieldTag) String() string {
|
||||||
|
return fmt.Sprintf("0x%x", uint64(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func NodeTypeFor(name string) NodeType {
|
||||||
|
digest := []byte("GRAPHVENT_NODE - " + name)
|
||||||
|
|
||||||
|
hash := sha512.Sum512(digest)
|
||||||
|
return NodeType(binary.BigEndian.Uint64(hash[0:8]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func SerializeType(t fmt.Stringer) SerializedType {
|
||||||
|
digest := []byte(t.String())
|
||||||
|
hash := sha512.Sum512(digest)
|
||||||
|
return SerializedType(binary.BigEndian.Uint64(hash[0:8]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func SerializedTypeFor[T any]() SerializedType {
|
||||||
|
return SerializeType(reflect.TypeFor[T]())
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtTypeFor[E any, T interface { *E; Extension}]() ExtType {
|
||||||
|
return ExtType(SerializedTypeFor[E]())
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtTypeOf(t reflect.Type) ExtType {
|
||||||
|
return ExtType(SerializeType(t.Elem()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func SignalTypeFor[S Signal]() SignalType {
|
||||||
|
return SignalType(SerializedTypeFor[S]())
|
||||||
|
}
|
||||||
|
|
||||||
|
func Hash(base, data string) SerializedType {
|
||||||
|
digest := []byte(base + ":" + data)
|
||||||
|
hash := sha512.Sum512(digest)
|
||||||
|
return SerializedType(binary.BigEndian.Uint64(hash[0:8]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetFieldTag(tag string) FieldTag {
|
||||||
|
return FieldTag(Hash("GRAPHVENT_FIELD_TAG", tag))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TypeStack(ctx *Context, t reflect.Type, data []byte) (int, error) {
|
||||||
|
info, registered := ctx.Types[t]
|
||||||
|
if registered {
|
||||||
|
binary.BigEndian.PutUint64(data, uint64(info.Serialized))
|
||||||
|
return 8, nil
|
||||||
|
} else {
|
||||||
|
switch t.Kind() {
|
||||||
|
case reflect.Map:
|
||||||
|
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Map)))
|
||||||
|
|
||||||
|
key_written, err := TypeStack(ctx, t.Key(), data[8:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
elem_written, err := TypeStack(ctx, t.Elem(), data[8 + key_written:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return 8 + key_written + elem_written, nil
|
||||||
|
|
||||||
|
case reflect.Pointer:
|
||||||
|
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Pointer)))
|
||||||
|
|
||||||
|
elem_written, err := TypeStack(ctx, t.Elem(), data[8:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return 8 + elem_written, nil
|
||||||
|
case reflect.Slice:
|
||||||
|
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Slice)))
|
||||||
|
|
||||||
|
elem_written, err := TypeStack(ctx, t.Elem(), data[8:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return 8 + elem_written, nil
|
||||||
|
case reflect.Array:
|
||||||
|
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Array)))
|
||||||
|
binary.BigEndian.PutUint64(data[8:], uint64(t.Len()))
|
||||||
|
|
||||||
|
elem_written, err := TypeStack(ctx, t.Elem(), data[16:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return 16 + elem_written, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("Hit %s, which is not a registered type", t.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnwrapStack(ctx *Context, stack []byte) (reflect.Type, []byte, error) {
|
||||||
|
first_bytes, left := split(stack, 8)
|
||||||
|
first := SerializedType(binary.BigEndian.Uint64(first_bytes))
|
||||||
|
|
||||||
|
info, registered := ctx.TypesReverse[first]
|
||||||
|
if registered {
|
||||||
|
return info.Reflect, left, nil
|
||||||
|
} else {
|
||||||
|
switch first {
|
||||||
|
case SerializeType(reflect.Map):
|
||||||
|
key_type, after_key, err := UnwrapStack(ctx, left)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
elem_type, after_elem, err := UnwrapStack(ctx, after_key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return reflect.MapOf(key_type, elem_type), after_elem, nil
|
||||||
|
case SerializeType(reflect.Pointer):
|
||||||
|
elem_type, rest, err := UnwrapStack(ctx, left)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return reflect.PointerTo(elem_type), rest, nil
|
||||||
|
case SerializeType(reflect.Slice):
|
||||||
|
elem_type, rest, err := UnwrapStack(ctx, left)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return reflect.SliceOf(elem_type), rest, nil
|
||||||
|
case SerializeType(reflect.Array):
|
||||||
|
length_bytes, left := split(left, 8)
|
||||||
|
length := int(binary.BigEndian.Uint64(length_bytes))
|
||||||
|
|
||||||
|
elem_type, rest, err := UnwrapStack(ctx, left)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return reflect.ArrayOf(length, elem_type), rest, nil
|
||||||
|
default:
|
||||||
|
return nil, nil, fmt.Errorf("Type stack %+v not recognized", stack)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Serialize[T any](ctx *Context, value T, data []byte) (int, error) {
|
||||||
|
return SerializeValue(ctx, reflect.ValueOf(&value).Elem(), data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Deserialize[T any](ctx *Context, data []byte) (T, error) {
|
||||||
|
reflect_type := reflect.TypeFor[T]()
|
||||||
|
var zero T
|
||||||
|
value, left, err := DeserializeValue(ctx, data, reflect_type)
|
||||||
|
if err != nil {
|
||||||
|
return zero, err
|
||||||
|
} else if len(left) != 0 {
|
||||||
|
return zero, fmt.Errorf("%d/%d bytes left after deserializing %+v", len(left), len(data), value)
|
||||||
|
} else if value.Type() != reflect_type {
|
||||||
|
return zero, fmt.Errorf("Deserialized type %s does not match %s", value.Type(), reflect_type)
|
||||||
|
}
|
||||||
|
|
||||||
|
return value.Interface().(T), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SerializedSize(ctx *Context, value reflect.Value) (int, error) {
|
||||||
|
var sizefn SerializedSizeFn = nil
|
||||||
|
|
||||||
|
info, registered := ctx.Types[value.Type()]
|
||||||
|
if registered {
|
||||||
|
sizefn = info.SerializedSize
|
||||||
|
}
|
||||||
|
|
||||||
|
if sizefn == nil {
|
||||||
|
switch value.Type().Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
return 1, nil
|
||||||
|
|
||||||
|
case reflect.Int8:
|
||||||
|
return 1, nil
|
||||||
|
case reflect.Int16:
|
||||||
|
return 2, nil
|
||||||
|
case reflect.Int32:
|
||||||
|
return 4, nil
|
||||||
|
case reflect.Int64:
|
||||||
|
fallthrough
|
||||||
|
case reflect.Int:
|
||||||
|
return 8, nil
|
||||||
|
|
||||||
|
case reflect.Uint8:
|
||||||
|
return 1, nil
|
||||||
|
case reflect.Uint16:
|
||||||
|
return 2, nil
|
||||||
|
case reflect.Uint32:
|
||||||
|
return 4, nil
|
||||||
|
case reflect.Uint64:
|
||||||
|
fallthrough
|
||||||
|
case reflect.Uint:
|
||||||
|
return 8, nil
|
||||||
|
|
||||||
|
case reflect.Float32:
|
||||||
|
return 4, nil
|
||||||
|
case reflect.Float64:
|
||||||
|
return 8, nil
|
||||||
|
|
||||||
|
case reflect.String:
|
||||||
|
return 8 + value.Len(), nil
|
||||||
|
|
||||||
|
case reflect.Pointer:
|
||||||
|
if value.IsNil() {
|
||||||
|
return 1, nil
|
||||||
|
} else {
|
||||||
|
elem_len, err := SerializedSize(ctx, value.Elem())
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
} else {
|
||||||
|
return 1 + elem_len, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Slice:
|
||||||
|
if value.IsNil() {
|
||||||
|
return 1, nil
|
||||||
|
} else {
|
||||||
|
elem_total := 0
|
||||||
|
for i := 0; i < value.Len(); i++ {
|
||||||
|
elem_len, err := SerializedSize(ctx, value.Index(i))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
elem_total += elem_len
|
||||||
|
}
|
||||||
|
return 9 + elem_total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Array:
|
||||||
|
total := 0
|
||||||
|
for i := 0; i < value.Len(); i++ {
|
||||||
|
elem_len, err := SerializedSize(ctx, value.Index(i))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
total += elem_len
|
||||||
|
}
|
||||||
|
return total, nil
|
||||||
|
|
||||||
|
case reflect.Map:
|
||||||
|
if value.IsNil() {
|
||||||
|
return 1, nil
|
||||||
|
} else {
|
||||||
|
key := reflect.New(value.Type().Key()).Elem()
|
||||||
|
val := reflect.New(value.Type().Elem()).Elem()
|
||||||
|
iter := value.MapRange()
|
||||||
|
|
||||||
|
total := 0
|
||||||
|
for iter.Next() {
|
||||||
|
key.SetIterKey(iter)
|
||||||
|
k, err := SerializedSize(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
total += k
|
||||||
|
|
||||||
|
val.SetIterValue(iter)
|
||||||
|
v, err := SerializedSize(ctx, val)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
total += v
|
||||||
|
}
|
||||||
|
|
||||||
|
return 9 + total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Struct:
|
||||||
|
if registered == false {
|
||||||
|
return 0, fmt.Errorf("Can't serialize unregistered struct %s", value.Type())
|
||||||
|
} else {
|
||||||
|
field_total := 0
|
||||||
|
for _, field_info := range(info.Fields) {
|
||||||
|
field_size, err := SerializedSize(ctx, value.FieldByIndex(field_info.Index))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
field_total += 8
|
||||||
|
field_total += field_size
|
||||||
|
}
|
||||||
|
|
||||||
|
return 8 + field_total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Interface:
|
||||||
|
// TODO get size of TypeStack instead of just using 128
|
||||||
|
elem_size, err := SerializedSize(ctx, value.Elem())
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return 128 + elem_size, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("Don't know how to serialize %s", value.Type())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return sizefn(ctx, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SerializeValue(ctx *Context, value reflect.Value, data []byte) (int, error) {
|
||||||
|
var serialize SerializeFn = nil
|
||||||
|
|
||||||
|
info, registered := ctx.Types[value.Type()]
|
||||||
|
if registered {
|
||||||
|
serialize = info.Serialize
|
||||||
|
}
|
||||||
|
|
||||||
|
if serialize == nil {
|
||||||
|
switch value.Type().Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
if value.Bool() {
|
||||||
|
data[0] = 0xFF
|
||||||
|
} else {
|
||||||
|
data[0] = 0x00
|
||||||
|
}
|
||||||
|
return 1, nil
|
||||||
|
|
||||||
|
case reflect.Int8:
|
||||||
|
data[0] = byte(value.Int())
|
||||||
|
return 1, nil
|
||||||
|
case reflect.Int16:
|
||||||
|
binary.BigEndian.PutUint16(data, uint16(value.Int()))
|
||||||
|
return 2, nil
|
||||||
|
case reflect.Int32:
|
||||||
|
binary.BigEndian.PutUint32(data, uint32(value.Int()))
|
||||||
|
return 4, nil
|
||||||
|
case reflect.Int64:
|
||||||
|
fallthrough
|
||||||
|
case reflect.Int:
|
||||||
|
binary.BigEndian.PutUint64(data, uint64(value.Int()))
|
||||||
|
return 8, nil
|
||||||
|
|
||||||
|
case reflect.Uint8:
|
||||||
|
data[0] = byte(value.Uint())
|
||||||
|
return 1, nil
|
||||||
|
case reflect.Uint16:
|
||||||
|
binary.BigEndian.PutUint16(data, uint16(value.Uint()))
|
||||||
|
return 2, nil
|
||||||
|
case reflect.Uint32:
|
||||||
|
binary.BigEndian.PutUint32(data, uint32(value.Uint()))
|
||||||
|
return 4, nil
|
||||||
|
case reflect.Uint64:
|
||||||
|
fallthrough
|
||||||
|
case reflect.Uint:
|
||||||
|
binary.BigEndian.PutUint64(data, value.Uint())
|
||||||
|
return 8, nil
|
||||||
|
|
||||||
|
case reflect.Float32:
|
||||||
|
binary.BigEndian.PutUint32(data, math.Float32bits(float32(value.Float())))
|
||||||
|
return 4, nil
|
||||||
|
case reflect.Float64:
|
||||||
|
binary.BigEndian.PutUint64(data, math.Float64bits(value.Float()))
|
||||||
|
return 8, nil
|
||||||
|
|
||||||
|
case reflect.String:
|
||||||
|
binary.BigEndian.PutUint64(data, uint64(value.Len()))
|
||||||
|
copy(data[8:], []byte(value.String()))
|
||||||
|
return 8 + value.Len(), nil
|
||||||
|
|
||||||
|
case reflect.Pointer:
|
||||||
|
if value.IsNil() {
|
||||||
|
data[0] = 0x00
|
||||||
|
return 1, nil
|
||||||
|
} else {
|
||||||
|
data[0] = 0x01
|
||||||
|
written, err := SerializeValue(ctx, value.Elem(), data[1:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return 1 + written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Slice:
|
||||||
|
if value.IsNil() {
|
||||||
|
data[0] = 0x00
|
||||||
|
return 8, nil
|
||||||
|
} else {
|
||||||
|
data[0] = 0x01
|
||||||
|
binary.BigEndian.PutUint64(data[1:], uint64(value.Len()))
|
||||||
|
total_written := 0
|
||||||
|
for i := 0; i < value.Len(); i++ {
|
||||||
|
written, err := SerializeValue(ctx, value.Index(i), data[9+total_written:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
total_written += written
|
||||||
|
}
|
||||||
|
return 9 + total_written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Array:
|
||||||
|
total_written := 0
|
||||||
|
for i := 0; i < value.Len(); i++ {
|
||||||
|
written, err := SerializeValue(ctx, value.Index(i), data[total_written:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
total_written += written
|
||||||
|
}
|
||||||
|
return total_written, nil
|
||||||
|
|
||||||
|
case reflect.Map:
|
||||||
|
if value.IsNil() {
|
||||||
|
data[0] = 0x00
|
||||||
|
return 1, nil
|
||||||
|
} else {
|
||||||
|
data[0] = 0x01
|
||||||
|
binary.BigEndian.PutUint64(data[1:], uint64(value.Len()))
|
||||||
|
|
||||||
|
key := reflect.New(value.Type().Key()).Elem()
|
||||||
|
val := reflect.New(value.Type().Elem()).Elem()
|
||||||
|
iter := value.MapRange()
|
||||||
|
total_written := 0
|
||||||
|
for iter.Next() {
|
||||||
|
key.SetIterKey(iter)
|
||||||
|
val.SetIterValue(iter)
|
||||||
|
|
||||||
|
k, err := SerializeValue(ctx, key, data[9+total_written:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
total_written += k
|
||||||
|
|
||||||
|
v, err := SerializeValue(ctx, val, data[9+total_written:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
total_written += v
|
||||||
|
}
|
||||||
|
return 9 + total_written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Struct:
|
||||||
|
if registered == false {
|
||||||
|
return 0, fmt.Errorf("Cannot serialize unregistered struct %s", value.Type())
|
||||||
|
} else {
|
||||||
|
binary.BigEndian.PutUint64(data, uint64(len(info.Fields)))
|
||||||
|
|
||||||
|
total_written := 0
|
||||||
|
for field_tag, field_info := range(info.Fields) {
|
||||||
|
binary.BigEndian.PutUint64(data[8+total_written:], uint64(field_tag))
|
||||||
|
total_written += 8
|
||||||
|
written, err := SerializeValue(ctx, value.FieldByIndex(field_info.Index), data[8+total_written:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
total_written += written
|
||||||
|
}
|
||||||
|
return 8 + total_written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Interface:
|
||||||
|
type_written, err := TypeStack(ctx, value.Elem().Type(), data)
|
||||||
|
|
||||||
|
elem_written, err := SerializeValue(ctx, value.Elem(), data[type_written:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return type_written + elem_written, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("Don't know how to serialize %s", value.Type())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return serialize(ctx, value, data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func split(data []byte, n int) ([]byte, []byte) {
|
||||||
|
return data[:n], data[n:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, []byte, error) {
|
||||||
|
var deserialize DeserializeFn = nil
|
||||||
|
|
||||||
|
info, registered := ctx.Types[t]
|
||||||
|
if registered {
|
||||||
|
deserialize = info.Deserialize
|
||||||
|
}
|
||||||
|
|
||||||
|
if deserialize == nil {
|
||||||
|
switch t.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
used, left := split(data, 1)
|
||||||
|
value := reflect.New(t).Elem()
|
||||||
|
value.SetBool(used[0] != 0x00)
|
||||||
|
return value, left, nil
|
||||||
|
|
||||||
|
case reflect.Int8:
|
||||||
|
used, left := split(data, 1)
|
||||||
|
value := reflect.New(t).Elem()
|
||||||
|
value.SetInt(int64(used[0]))
|
||||||
|
return value, left, nil
|
||||||
|
case reflect.Int16:
|
||||||
|
used, left := split(data, 2)
|
||||||
|
value := reflect.New(t).Elem()
|
||||||
|
value.SetInt(int64(binary.BigEndian.Uint16(used)))
|
||||||
|
return value, left, nil
|
||||||
|
case reflect.Int32:
|
||||||
|
used, left := split(data, 4)
|
||||||
|
value := reflect.New(t).Elem()
|
||||||
|
value.SetInt(int64(binary.BigEndian.Uint32(used)))
|
||||||
|
return value, left, nil
|
||||||
|
case reflect.Int64:
|
||||||
|
fallthrough
|
||||||
|
case reflect.Int:
|
||||||
|
used, left := split(data, 8)
|
||||||
|
value := reflect.New(t).Elem()
|
||||||
|
value.SetInt(int64(binary.BigEndian.Uint64(used)))
|
||||||
|
return value, left, nil
|
||||||
|
|
||||||
|
case reflect.Uint8:
|
||||||
|
used, left := split(data, 1)
|
||||||
|
value := reflect.New(t).Elem()
|
||||||
|
value.SetUint(uint64(used[0]))
|
||||||
|
return value, left, nil
|
||||||
|
case reflect.Uint16:
|
||||||
|
used, left := split(data, 2)
|
||||||
|
value := reflect.New(t).Elem()
|
||||||
|
value.SetUint(uint64(binary.BigEndian.Uint16(used)))
|
||||||
|
return value, left, nil
|
||||||
|
case reflect.Uint32:
|
||||||
|
used, left := split(data, 4)
|
||||||
|
value := reflect.New(t).Elem()
|
||||||
|
value.SetUint(uint64(binary.BigEndian.Uint32(used)))
|
||||||
|
return value, left, nil
|
||||||
|
case reflect.Uint64:
|
||||||
|
fallthrough
|
||||||
|
case reflect.Uint:
|
||||||
|
used, left := split(data, 8)
|
||||||
|
value := reflect.New(t).Elem()
|
||||||
|
value.SetUint(binary.BigEndian.Uint64(used))
|
||||||
|
return value, left, nil
|
||||||
|
|
||||||
|
case reflect.Float32:
|
||||||
|
used, left := split(data, 4)
|
||||||
|
value := reflect.New(t).Elem()
|
||||||
|
value.SetFloat(float64(math.Float32frombits(binary.BigEndian.Uint32(used))))
|
||||||
|
return value, left, nil
|
||||||
|
case reflect.Float64:
|
||||||
|
used, left := split(data, 8)
|
||||||
|
value := reflect.New(t).Elem()
|
||||||
|
value.SetFloat(math.Float64frombits(binary.BigEndian.Uint64(used)))
|
||||||
|
return value, left, nil
|
||||||
|
|
||||||
|
case reflect.String:
|
||||||
|
length, after_len := split(data, 8)
|
||||||
|
used, left := split(after_len, int(binary.BigEndian.Uint64(length)))
|
||||||
|
value := reflect.New(t).Elem()
|
||||||
|
value.SetString(string(used))
|
||||||
|
return value, left, nil
|
||||||
|
|
||||||
|
case reflect.Pointer:
|
||||||
|
flags, after_flags := split(data, 1)
|
||||||
|
value := reflect.New(t).Elem()
|
||||||
|
if flags[0] == 0x00 {
|
||||||
|
value.SetZero()
|
||||||
|
return value, after_flags, nil
|
||||||
|
} else {
|
||||||
|
elem_value, after_elem, err := DeserializeValue(ctx, after_flags, t.Elem())
|
||||||
|
if err != nil {
|
||||||
|
return reflect.Value{}, nil, err
|
||||||
|
}
|
||||||
|
value.Set(elem_value.Addr())
|
||||||
|
return value, after_elem, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Slice:
|
||||||
|
nil_byte := data[0]
|
||||||
|
data = data[1:]
|
||||||
|
if nil_byte == 0x00 {
|
||||||
|
return reflect.New(t).Elem(), data, nil
|
||||||
|
} else {
|
||||||
|
len_bytes, left := split(data, 8)
|
||||||
|
length := int(binary.BigEndian.Uint64(len_bytes))
|
||||||
|
value := reflect.MakeSlice(t, length, length)
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
var elem_value reflect.Value
|
||||||
|
var err error
|
||||||
|
elem_value, left, err = DeserializeValue(ctx, left, t.Elem())
|
||||||
|
if err != nil {
|
||||||
|
return reflect.Value{}, nil, err
|
||||||
|
}
|
||||||
|
value.Index(i).Set(elem_value)
|
||||||
|
}
|
||||||
|
return value, left, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Array:
|
||||||
|
value := reflect.New(t).Elem()
|
||||||
|
left := data
|
||||||
|
for i := 0; i < t.Len(); i++ {
|
||||||
|
var elem_value reflect.Value
|
||||||
|
var err error
|
||||||
|
elem_value, left, err = DeserializeValue(ctx, left, t.Elem())
|
||||||
|
if err != nil {
|
||||||
|
return reflect.Value{}, nil, err
|
||||||
|
}
|
||||||
|
value.Index(i).Set(elem_value)
|
||||||
|
}
|
||||||
|
return value, left, nil
|
||||||
|
|
||||||
|
case reflect.Map:
|
||||||
|
flags, after_flags := split(data, 1)
|
||||||
|
if flags[0] == 0x00 {
|
||||||
|
return reflect.New(t).Elem(), after_flags, nil
|
||||||
|
} else {
|
||||||
|
len_bytes, left := split(after_flags, 8)
|
||||||
|
length := int(binary.BigEndian.Uint64(len_bytes))
|
||||||
|
|
||||||
|
value := reflect.MakeMapWithSize(t, length)
|
||||||
|
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
var key_value reflect.Value
|
||||||
|
var val_value reflect.Value
|
||||||
|
var err error
|
||||||
|
|
||||||
|
key_value, left, err = DeserializeValue(ctx, left, t.Key())
|
||||||
|
if err != nil {
|
||||||
|
return reflect.Value{}, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
val_value, left, err = DeserializeValue(ctx, left, t.Elem())
|
||||||
|
if err != nil {
|
||||||
|
return reflect.Value{}, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
value.SetMapIndex(key_value, val_value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return value, left, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Struct:
|
||||||
|
info, mapped := ctx.Types[t]
|
||||||
|
if mapped {
|
||||||
|
value := reflect.New(t).Elem()
|
||||||
|
|
||||||
|
num_field_bytes, left := split(data, 8)
|
||||||
|
num_fields := int(binary.BigEndian.Uint64(num_field_bytes))
|
||||||
|
|
||||||
|
for i := 0; i < num_fields; i++ {
|
||||||
|
var tag_bytes []byte
|
||||||
|
|
||||||
|
tag_bytes, left = split(left, 8)
|
||||||
|
field_tag := FieldTag(binary.BigEndian.Uint64(tag_bytes))
|
||||||
|
|
||||||
|
field_info, mapped := info.Fields[field_tag]
|
||||||
|
if mapped {
|
||||||
|
var field_val reflect.Value
|
||||||
|
var err error
|
||||||
|
field_val, left, err = DeserializeValue(ctx, left, field_info.Type)
|
||||||
|
if err != nil {
|
||||||
|
return reflect.Value{}, nil, err
|
||||||
|
}
|
||||||
|
value.FieldByIndex(field_info.Index).Set(field_val)
|
||||||
|
} else {
|
||||||
|
return reflect.Value{}, nil, fmt.Errorf("Unknown field %s on struct %s", field_tag, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if info.PostDeserializeIndex != -1 {
|
||||||
|
post_deserialize_method := value.Addr().Method(info.PostDeserializeIndex)
|
||||||
|
post_deserialize_method.Call([]reflect.Value{reflect.ValueOf(ctx)})
|
||||||
|
}
|
||||||
|
return value, left, nil
|
||||||
|
} else {
|
||||||
|
return reflect.Value{}, nil, fmt.Errorf("Cannot deserialize unregistered struct %s", t)
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Interface:
|
||||||
|
elem_type, rest, err := UnwrapStack(ctx, data)
|
||||||
|
if err != nil {
|
||||||
|
return reflect.Value{}, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
elem_val, left, err := DeserializeValue(ctx, rest, elem_type)
|
||||||
|
if err != nil {
|
||||||
|
return reflect.Value{}, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
val := reflect.New(t).Elem()
|
||||||
|
val.Set(elem_val)
|
||||||
|
|
||||||
|
return val, left, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return reflect.Value{}, nil, fmt.Errorf("Don't know how to deserialize %s", t)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return deserialize(ctx, data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,176 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"reflect"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testTypeStack[T any](t *testing.T, ctx *Context) {
|
||||||
|
buffer := [1024]byte{}
|
||||||
|
reflect_type := reflect.TypeFor[T]()
|
||||||
|
written, err := TypeStack(ctx, reflect_type, buffer[:])
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
stack := buffer[:written]
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "TypeStack[%s]: %+v", reflect_type, stack)
|
||||||
|
|
||||||
|
unwrapped_type, rest, err := UnwrapStack(ctx, stack)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
if len(rest) != 0 {
|
||||||
|
t.Errorf("Types remaining after unwrapping %s: %+v", unwrapped_type, stack)
|
||||||
|
}
|
||||||
|
|
||||||
|
if unwrapped_type != reflect_type {
|
||||||
|
t.Errorf("Unwrapped type[%+v] doesn't match original[%+v]", unwrapped_type, reflect_type)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "Unwrapped type[%s]: %s", reflect_type, reflect_type)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerializeTypes(t *testing.T) {
|
||||||
|
ctx := logTestContext(t, []string{"test"})
|
||||||
|
|
||||||
|
testTypeStack[int](t, ctx)
|
||||||
|
testTypeStack[map[int]string](t, ctx)
|
||||||
|
testTypeStack[string](t, ctx)
|
||||||
|
testTypeStack[*string](t, ctx)
|
||||||
|
testTypeStack[*map[string]*map[*string]int](t, ctx)
|
||||||
|
testTypeStack[[5]int](t, ctx)
|
||||||
|
testTypeStack[uuid.UUID](t, ctx)
|
||||||
|
testTypeStack[NodeID](t, ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSerializeCompare[T comparable](t *testing.T, ctx *Context, value T) {
|
||||||
|
buffer := [1024]byte{}
|
||||||
|
written, err := Serialize(ctx, value, buffer[:])
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
serialized := buffer[:written]
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
|
||||||
|
|
||||||
|
deserialized, err := Deserialize[T](ctx, serialized)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
if value != deserialized {
|
||||||
|
t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSerializeList[L []T, T comparable](t *testing.T, ctx *Context, value L) {
|
||||||
|
buffer := [1024]byte{}
|
||||||
|
written, err := Serialize(ctx, value, buffer[:])
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
serialized := buffer[:written]
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[L](), value, serialized)
|
||||||
|
|
||||||
|
deserialized, err := Deserialize[L](ctx, serialized)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
for i, item := range(value) {
|
||||||
|
if item != deserialized[i] {
|
||||||
|
t.Fatalf("Deserialized list %+v does not match original %+v", value, deserialized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSerializePointer[P interface {*T}, T comparable](t *testing.T, ctx *Context, value P) {
|
||||||
|
buffer := [1024]byte{}
|
||||||
|
|
||||||
|
written, err := Serialize(ctx, value, buffer[:])
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
serialized := buffer[:written]
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[P](), value, serialized)
|
||||||
|
|
||||||
|
deserialized, err := Deserialize[P](ctx, serialized)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
if value == nil && deserialized == nil {
|
||||||
|
ctx.Log.Logf("test", "Deserialized nil")
|
||||||
|
} else if value == nil {
|
||||||
|
t.Fatalf("Non-nil value[%+v] returned for nil value", deserialized)
|
||||||
|
} else if deserialized == nil {
|
||||||
|
t.Fatalf("Nil value returned for non-nil value[%+v]", value)
|
||||||
|
} else if *deserialized != *value {
|
||||||
|
t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized)
|
||||||
|
} else {
|
||||||
|
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", *value, *deserialized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSerialize[T any](t *testing.T, ctx *Context, value T) {
|
||||||
|
buffer := [1024]byte{}
|
||||||
|
written, err := Serialize(ctx, value, buffer[:])
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
serialized := buffer[:written]
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
|
||||||
|
|
||||||
|
deserialized, err := Deserialize[T](ctx, serialized)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerializeValues(t *testing.T) {
|
||||||
|
ctx := logTestContext(t, []string{"test"})
|
||||||
|
|
||||||
|
testSerialize(t, ctx, Extension(NewLockableExt(nil)))
|
||||||
|
|
||||||
|
testSerializeCompare[int8](t, ctx, -64)
|
||||||
|
testSerializeCompare[int16](t, ctx, -64)
|
||||||
|
testSerializeCompare[int32](t, ctx, -64)
|
||||||
|
testSerializeCompare[int64](t, ctx, -64)
|
||||||
|
testSerializeCompare[int](t, ctx, -64)
|
||||||
|
|
||||||
|
testSerializeCompare[uint8](t, ctx, 64)
|
||||||
|
testSerializeCompare[uint16](t, ctx, 64)
|
||||||
|
testSerializeCompare[uint32](t, ctx, 64)
|
||||||
|
testSerializeCompare[uint64](t, ctx, 64)
|
||||||
|
testSerializeCompare[uint](t, ctx, 64)
|
||||||
|
|
||||||
|
testSerializeCompare[string](t, ctx, "test")
|
||||||
|
|
||||||
|
a := 12
|
||||||
|
testSerializePointer[*int](t, ctx, &a)
|
||||||
|
|
||||||
|
b := "test"
|
||||||
|
testSerializePointer[*string](t, ctx, nil)
|
||||||
|
testSerializePointer[*string](t, ctx, &b)
|
||||||
|
|
||||||
|
testSerializeList(t, ctx, []int{1, 2, 3, 4, 5})
|
||||||
|
|
||||||
|
testSerializeCompare[bool](t, ctx, true)
|
||||||
|
testSerializeCompare[bool](t, ctx, false)
|
||||||
|
testSerializeCompare[int](t, ctx, -1)
|
||||||
|
testSerializeCompare[uint](t, ctx, 1)
|
||||||
|
testSerializeCompare[NodeID](t, ctx, RandID())
|
||||||
|
testSerializeCompare[*int](t, ctx, nil)
|
||||||
|
testSerializeCompare(t, ctx, "string")
|
||||||
|
|
||||||
|
testSerialize(t, ctx, map[string]string{
|
||||||
|
"Test": "Test",
|
||||||
|
"key": "String",
|
||||||
|
"": "",
|
||||||
|
})
|
||||||
|
|
||||||
|
testSerialize[map[string]string](t, ctx, nil)
|
||||||
|
|
||||||
|
testSerialize(t, ctx, NewListenerExt(10))
|
||||||
|
|
||||||
|
node, err := ctx.NewNode(nil, "Node")
|
||||||
|
fatalErr(t, err)
|
||||||
|
testSerialize(t, ctx, node)
|
||||||
|
}
|
@ -1,111 +1,261 @@
|
|||||||
package graphvent
|
package graphvent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"fmt"
|
||||||
)
|
"time"
|
||||||
|
|
||||||
type SignalDirection int
|
"github.com/google/uuid"
|
||||||
const (
|
|
||||||
Up SignalDirection = iota
|
|
||||||
Down
|
|
||||||
Direct
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type SignalType string
|
type TimeoutSignal struct {
|
||||||
|
ResponseHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTimeoutSignal(req_id uuid.UUID) *TimeoutSignal {
|
||||||
|
return &TimeoutSignal{
|
||||||
|
NewResponseHeader(req_id),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (signal TimeoutSignal) String() string {
|
||||||
|
return fmt.Sprintf("TimeoutSignal(%s)", &signal.ResponseHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
type SignalHeader struct {
|
||||||
|
Id uuid.UUID `gv:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (signal SignalHeader) ID() uuid.UUID {
|
||||||
|
return signal.Id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (header SignalHeader) String() string {
|
||||||
|
return fmt.Sprintf("%s", header.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResponseSignal interface {
|
||||||
|
Signal
|
||||||
|
ResponseID() uuid.UUID
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResponseHeader struct {
|
||||||
|
SignalHeader
|
||||||
|
ReqID uuid.UUID `gv:"req_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (header ResponseHeader) ResponseID() uuid.UUID {
|
||||||
|
return header.ReqID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (header ResponseHeader) String() string {
|
||||||
|
return fmt.Sprintf("%s for %s", header.Id, header.ReqID)
|
||||||
|
}
|
||||||
|
|
||||||
type Signal interface {
|
type Signal interface {
|
||||||
Serializable[SignalType]
|
fmt.Stringer
|
||||||
Direction() SignalDirection
|
ID() uuid.UUID
|
||||||
}
|
}
|
||||||
|
|
||||||
type BaseSignal struct {
|
func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UUID) (ResponseSignal, []Signal, error) {
|
||||||
SignalDirection SignalDirection `json:"direction"`
|
signals := []Signal{}
|
||||||
SignalType SignalType `json:"type"`
|
var timeout_channel <- chan time.Time
|
||||||
|
if timeout > 0 {
|
||||||
|
timeout_channel = time.After(timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (signal BaseSignal) Type() SignalType {
|
for true {
|
||||||
return signal.SignalType
|
select {
|
||||||
|
case signal := <- listener:
|
||||||
|
if signal == nil {
|
||||||
|
return nil, signals, fmt.Errorf("LISTENER_CLOSED")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (signal BaseSignal) Direction() SignalDirection {
|
resp_signal, ok := signal.(ResponseSignal)
|
||||||
return signal.SignalDirection
|
if ok == true && resp_signal.ResponseID() == req_id {
|
||||||
|
return resp_signal, signals, nil
|
||||||
|
} else {
|
||||||
|
signals = append(signals, signal)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (signal BaseSignal) Serialize() ([]byte, error) {
|
case <-timeout_channel:
|
||||||
return json.MarshalIndent(signal, "", " ")
|
return nil, signals, fmt.Errorf("LISTENER_TIMEOUT")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, signals, fmt.Errorf("UNREACHABLE")
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal {
|
//TODO: Add []Signal return as well for other signals
|
||||||
signal := BaseSignal{
|
func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, check func(S)bool) (S, error) {
|
||||||
SignalDirection: direction,
|
var zero S
|
||||||
SignalType: signal_type,
|
var timeout_channel <- chan time.Time
|
||||||
|
if timeout > 0 {
|
||||||
|
timeout_channel = time.After(timeout)
|
||||||
|
}
|
||||||
|
for true {
|
||||||
|
select {
|
||||||
|
case signal := <- listener:
|
||||||
|
if signal == nil {
|
||||||
|
return zero, fmt.Errorf("LISTENER_CLOSED")
|
||||||
}
|
}
|
||||||
return signal
|
sig, ok := signal.(S)
|
||||||
|
if ok == true {
|
||||||
|
if check(sig) == true {
|
||||||
|
return sig, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case <-timeout_channel:
|
||||||
|
return zero, fmt.Errorf("LISTENER_TIMEOUT")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return zero, fmt.Errorf("LOOP_ENDED")
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDownSignal(signal_type SignalType) BaseSignal {
|
func NewSignalHeader() SignalHeader {
|
||||||
return NewBaseSignal(signal_type, Down)
|
return SignalHeader{
|
||||||
|
uuid.New(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUpSignal(signal_type SignalType) BaseSignal {
|
func NewResponseHeader(req_id uuid.UUID) ResponseHeader {
|
||||||
return NewBaseSignal(signal_type, Up)
|
return ResponseHeader{
|
||||||
|
NewSignalHeader(),
|
||||||
|
req_id,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDirectSignal(signal_type SignalType) BaseSignal {
|
type SuccessSignal struct {
|
||||||
return NewBaseSignal(signal_type, Direct)
|
ResponseHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
var AbortSignal = NewBaseSignal("abort", Down)
|
func (signal SuccessSignal) String() string {
|
||||||
var StopSignal = NewBaseSignal("stop", Down)
|
return fmt.Sprintf("SuccessSignal(%s)", signal.ResponseHeader)
|
||||||
|
}
|
||||||
|
|
||||||
type IDSignal struct {
|
func NewSuccessSignal(req_id uuid.UUID) *SuccessSignal {
|
||||||
BaseSignal
|
return &SuccessSignal{
|
||||||
ID NodeID `json:"id"`
|
NewResponseHeader(req_id),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (signal IDSignal) String() string {
|
type ErrorSignal struct {
|
||||||
ser, err := json.Marshal(signal)
|
ResponseHeader
|
||||||
if err != nil {
|
Error string
|
||||||
return "STATE_SER_ERR"
|
}
|
||||||
|
func (signal ErrorSignal) String() string {
|
||||||
|
return fmt.Sprintf("ErrorSignal(%s, %s)", signal.ResponseHeader, signal.Error)
|
||||||
|
}
|
||||||
|
func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) *ErrorSignal {
|
||||||
|
return &ErrorSignal{
|
||||||
|
NewResponseHeader(req_id),
|
||||||
|
fmt.Sprintf(fmt_string, args...),
|
||||||
}
|
}
|
||||||
return string(ser)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewIDSignal(signal_type SignalType, direction SignalDirection, id NodeID) IDSignal {
|
type ACLTimeoutSignal struct {
|
||||||
return IDSignal{
|
ResponseHeader
|
||||||
BaseSignal: NewBaseSignal(signal_type, direction),
|
}
|
||||||
ID: id,
|
func NewACLTimeoutSignal(req_id uuid.UUID) *ACLTimeoutSignal {
|
||||||
|
sig := &ACLTimeoutSignal{
|
||||||
|
NewResponseHeader(req_id),
|
||||||
}
|
}
|
||||||
|
return sig
|
||||||
}
|
}
|
||||||
|
|
||||||
type StatusSignal struct {
|
type StatusSignal struct {
|
||||||
IDSignal
|
SignalHeader
|
||||||
Status string `json:"status"`
|
Source NodeID `gv:"source"`
|
||||||
|
Fields []string `gv:"fields"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (signal StatusSignal) String() string {
|
func (signal StatusSignal) String() string {
|
||||||
ser, err := json.Marshal(signal)
|
return fmt.Sprintf("StatusSignal(%s: %+v)", signal.Source, signal.Fields)
|
||||||
if err != nil {
|
}
|
||||||
return "STATE_SER_ERR"
|
func NewStatusSignal(source NodeID, fields []string) *StatusSignal {
|
||||||
|
return &StatusSignal{
|
||||||
|
NewSignalHeader(),
|
||||||
|
source,
|
||||||
|
fields,
|
||||||
}
|
}
|
||||||
return string(ser)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStatusSignal(status string, source NodeID) StatusSignal {
|
type LinkSignal struct {
|
||||||
return StatusSignal{
|
SignalHeader
|
||||||
IDSignal: NewIDSignal("status", Up, source),
|
NodeID NodeID
|
||||||
Status: status,
|
Action string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
LinkActionBase = "LINK_ACTION"
|
||||||
|
LinkActionAdd = "ADD"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewLinkSignal(action string, id NodeID) Signal {
|
||||||
|
return &LinkSignal{
|
||||||
|
NewSignalHeader(),
|
||||||
|
id,
|
||||||
|
action,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type LockSignal struct {
|
||||||
|
SignalHeader
|
||||||
|
}
|
||||||
|
func (signal LockSignal) String() string {
|
||||||
|
return fmt.Sprintf("LockSignal(%s)", signal.SignalHeader)
|
||||||
}
|
}
|
||||||
|
|
||||||
type StartChildSignal struct {
|
func NewLockSignal() *LockSignal {
|
||||||
IDSignal
|
return &LockSignal{
|
||||||
Action string `json:"action"`
|
NewSignalHeader(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type UnlockSignal struct {
|
||||||
|
SignalHeader
|
||||||
|
}
|
||||||
|
func (signal UnlockSignal) String() string {
|
||||||
|
return fmt.Sprintf("UnlockSignal(%s)", signal.SignalHeader)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStartChildSignal(child_id NodeID, action string) StartChildSignal {
|
func NewUnlockSignal() *UnlockSignal {
|
||||||
return StartChildSignal{
|
return &UnlockSignal{
|
||||||
IDSignal: NewIDSignal("start_child", Direct, child_id),
|
NewSignalHeader(),
|
||||||
Action: action,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
type ReadSignal struct {
|
||||||
|
SignalHeader
|
||||||
|
Fields []string `json:"extensions"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (signal ReadSignal) String() string {
|
||||||
|
return fmt.Sprintf("ReadSignal(%s, %+v)", signal.SignalHeader, signal.Fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewReadSignal(fields []string) *ReadSignal {
|
||||||
|
return &ReadSignal{
|
||||||
|
NewSignalHeader(),
|
||||||
|
fields,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReadResultSignal struct {
|
||||||
|
ResponseHeader
|
||||||
|
NodeID NodeID
|
||||||
|
NodeType NodeType
|
||||||
|
Fields map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (signal ReadResultSignal) String() string {
|
||||||
|
return fmt.Sprintf("ReadResultSignal(%s, %s, %+v)", signal.ResponseHeader, signal.NodeID, signal.Fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, fields map[string]any) *ReadResultSignal {
|
||||||
|
return &ReadResultSignal{
|
||||||
|
NewResponseHeader(req_id),
|
||||||
|
node_id,
|
||||||
|
node_type,
|
||||||
|
fields,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -1,736 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
"sync"
|
|
||||||
"errors"
|
|
||||||
"encoding/json"
|
|
||||||
"crypto/sha512"
|
|
||||||
"encoding/binary"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ThreadAction func(*Context, *Node, *ThreadExt)(string, error)
|
|
||||||
type ThreadActions map[string]ThreadAction
|
|
||||||
type ThreadHandler func(*Context, *Node, *ThreadExt, Signal)(string, error)
|
|
||||||
type ThreadHandlers map[SignalType]ThreadHandler
|
|
||||||
|
|
||||||
type InfoType string
|
|
||||||
func (t InfoType) String() string {
|
|
||||||
return string(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Info interface {
|
|
||||||
Serializable[InfoType]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Data required by a parent thread to restore it's children
|
|
||||||
type ParentInfo struct {
|
|
||||||
Start bool `json:"start"`
|
|
||||||
StartAction string `json:"start_action"`
|
|
||||||
RestoreAction string `json:"restore_action"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const ParentInfoType = InfoType("PARENT")
|
|
||||||
func (info *ParentInfo) Type() InfoType {
|
|
||||||
return ParentInfoType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (info *ParentInfo) Serialize() ([]byte, error) {
|
|
||||||
return json.MarshalIndent(info, "", " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
type QueuedAction struct {
|
|
||||||
Timeout time.Time `json:"time"`
|
|
||||||
Action string `json:"action"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ThreadType string
|
|
||||||
func (thread ThreadType) Hash() uint64 {
|
|
||||||
hash := sha512.Sum512([]byte(fmt.Sprintf("THREAD: %s", string(thread))))
|
|
||||||
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
|
|
||||||
}
|
|
||||||
|
|
||||||
type ThreadInfo struct {
|
|
||||||
Actions ThreadActions
|
|
||||||
Handlers ThreadHandlers
|
|
||||||
}
|
|
||||||
|
|
||||||
type InfoLoadFunc func([]byte)(Info, error)
|
|
||||||
type ThreadExtContext struct {
|
|
||||||
Types map[ThreadType]ThreadInfo
|
|
||||||
Loads map[InfoType]InfoLoadFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
const BaseThreadType = ThreadType("BASE")
|
|
||||||
func NewThreadExtContext() *ThreadExtContext {
|
|
||||||
return &ThreadExtContext{
|
|
||||||
Types: map[ThreadType]ThreadInfo{
|
|
||||||
BaseThreadType: ThreadInfo{
|
|
||||||
Actions: BaseThreadActions,
|
|
||||||
Handlers: BaseThreadHandlers,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Loads: map[InfoType]InfoLoadFunc{
|
|
||||||
ParentInfoType: func(data []byte) (Info, error) {
|
|
||||||
var info ParentInfo
|
|
||||||
err := json.Unmarshal(data, &info)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &info, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *ThreadExtContext) RegisterThreadType(thread_type ThreadType, actions ThreadActions, handlers ThreadHandlers) error {
|
|
||||||
if actions == nil || handlers == nil {
|
|
||||||
return fmt.Errorf("Cannot register ThreadType %s with nil actions or handlers", thread_type)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, exists := ctx.Types[thread_type]
|
|
||||||
if exists == true {
|
|
||||||
return fmt.Errorf("ThreadType %s already registered in ThreadExtContext, cannot register again", thread_type)
|
|
||||||
}
|
|
||||||
ctx.Types[thread_type] = ThreadInfo{
|
|
||||||
Actions: actions,
|
|
||||||
Handlers: handlers,
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *ThreadExtContext) RegisterInfoType(info_type InfoType, load_fn InfoLoadFunc) error {
|
|
||||||
if load_fn == nil {
|
|
||||||
return fmt.Errorf("Cannot register %s with nil load_fn", info_type)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, exists := ctx.Loads[info_type]
|
|
||||||
if exists == true {
|
|
||||||
return fmt.Errorf("InfoType %s is already registered in ThreadExtContext, cannot register again", info_type)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.Loads[info_type] = load_fn
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type ThreadExt struct {
|
|
||||||
Actions ThreadActions
|
|
||||||
Handlers ThreadHandlers
|
|
||||||
|
|
||||||
ThreadType ThreadType
|
|
||||||
|
|
||||||
SignalChan chan Signal
|
|
||||||
TimeoutChan <-chan time.Time
|
|
||||||
|
|
||||||
ChildWaits sync.WaitGroup
|
|
||||||
|
|
||||||
ActiveLock sync.Mutex
|
|
||||||
Active bool
|
|
||||||
State string
|
|
||||||
|
|
||||||
Parent *Node
|
|
||||||
Children map[NodeID]ChildInfo
|
|
||||||
|
|
||||||
ActionQueue []QueuedAction
|
|
||||||
NextAction *QueuedAction
|
|
||||||
}
|
|
||||||
|
|
||||||
type ThreadExtJSON struct {
|
|
||||||
State string `json:"state"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Parent string `json:"parent"`
|
|
||||||
Children map[string]map[string][]byte `json:"children"`
|
|
||||||
ActionQueue []QueuedAction
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ThreadExt) Serialize() ([]byte, error) {
|
|
||||||
children := map[string]map[string][]byte{}
|
|
||||||
for id, child := range(ext.Children) {
|
|
||||||
id_str := id.String()
|
|
||||||
children[id_str] = map[string][]byte{}
|
|
||||||
for info_type, info := range(child.Infos) {
|
|
||||||
var err error
|
|
||||||
children[id_str][string(info_type)], err = info.Serialize()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return json.MarshalIndent(&ThreadExtJSON{
|
|
||||||
State: ext.State,
|
|
||||||
Type: string(ext.ThreadType),
|
|
||||||
Parent: SaveNode(ext.Parent),
|
|
||||||
Children: children,
|
|
||||||
ActionQueue: ext.ActionQueue,
|
|
||||||
}, "", " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewThreadExt(ctx*Context, thread_type ThreadType, parent *Node, children map[NodeID]ChildInfo, state string, action_queue []QueuedAction) (*ThreadExt, error) {
|
|
||||||
if children == nil {
|
|
||||||
children = map[NodeID]ChildInfo{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if action_queue == nil {
|
|
||||||
action_queue = []QueuedAction{}
|
|
||||||
}
|
|
||||||
|
|
||||||
thread_ctx, err := GetCtx[*ThreadExt, *ThreadExtContext](ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
type_info, exists := thread_ctx.Types[thread_type]
|
|
||||||
if exists == false {
|
|
||||||
return nil, fmt.Errorf("Tried to load thread type %s which is not in context", thread_type)
|
|
||||||
}
|
|
||||||
next_action, timeout_chan := SoonestAction(action_queue)
|
|
||||||
|
|
||||||
return &ThreadExt{
|
|
||||||
ThreadType: thread_type,
|
|
||||||
Actions: type_info.Actions,
|
|
||||||
Handlers: type_info.Handlers,
|
|
||||||
SignalChan: make(chan Signal, THREAD_BUFFER_SIZE),
|
|
||||||
TimeoutChan: timeout_chan,
|
|
||||||
Active: false,
|
|
||||||
State: state,
|
|
||||||
Parent: parent,
|
|
||||||
Children: children,
|
|
||||||
ActionQueue: action_queue,
|
|
||||||
NextAction: next_action,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const THREAD_BUFFER_SIZE int = 1024
|
|
||||||
func LoadThreadExt(ctx *Context, data []byte) (Extension, error) {
|
|
||||||
var j ThreadExtJSON
|
|
||||||
err := json.Unmarshal(data, &j)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ctx.Log.Logf("db", "DB_LOAD_THREAD_EXT_JSON: %+v", j)
|
|
||||||
|
|
||||||
parent, err := RestoreNode(ctx, j.Parent)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
thread_ctx, err := GetCtx[*ThreadExt, *ThreadExtContext](ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
children := map[NodeID]ChildInfo{}
|
|
||||||
for id_str, infos := range(j.Children) {
|
|
||||||
child_node, err := RestoreNode(ctx, id_str)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
child_infos := map[InfoType]Info{}
|
|
||||||
for info_type_str, info_data := range(infos) {
|
|
||||||
info_type := InfoType(info_type_str)
|
|
||||||
info_load, exists := thread_ctx.Loads[info_type]
|
|
||||||
if exists == false {
|
|
||||||
return nil, fmt.Errorf("%s is not a known InfoType in ThreacExrContxt", info_type)
|
|
||||||
}
|
|
||||||
|
|
||||||
info, err := info_load(info_data)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
child_infos[info_type] = info
|
|
||||||
}
|
|
||||||
|
|
||||||
children[child_node.ID] = ChildInfo{
|
|
||||||
Child: child_node,
|
|
||||||
Infos: child_infos,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewThreadExt(ctx, ThreadType(j.Type), parent, children, j.State, j.ActionQueue)
|
|
||||||
}
|
|
||||||
|
|
||||||
const ThreadExtType = ExtType("THREAD")
|
|
||||||
func (ext *ThreadExt) Type() ExtType {
|
|
||||||
return ThreadExtType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ThreadExt) QueueAction(end time.Time, action string) {
|
|
||||||
ext.ActionQueue = append(ext.ActionQueue, QueuedAction{end, action})
|
|
||||||
ext.NextAction, ext.TimeoutChan = SoonestAction(ext.ActionQueue)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ThreadExt) ClearActionQueue() {
|
|
||||||
ext.ActionQueue = []QueuedAction{}
|
|
||||||
ext.NextAction = nil
|
|
||||||
ext.TimeoutChan = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func SoonestAction(actions []QueuedAction) (*QueuedAction, <-chan time.Time) {
|
|
||||||
var soonest_action *QueuedAction
|
|
||||||
var soonest_time time.Time
|
|
||||||
for _, action := range(actions) {
|
|
||||||
if action.Timeout.Compare(soonest_time) == -1 || soonest_action == nil {
|
|
||||||
soonest_action = &action
|
|
||||||
soonest_time = action.Timeout
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if soonest_action != nil {
|
|
||||||
return soonest_action, time.After(time.Until(soonest_action.Timeout))
|
|
||||||
} else {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ThreadExt) ChildList() []*Node {
|
|
||||||
ret := make([]*Node, len(ext.Children))
|
|
||||||
i := 0
|
|
||||||
for _, info := range(ext.Children) {
|
|
||||||
ret[i] = info.Child
|
|
||||||
i += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assumed that thread is already locked for signal
|
|
||||||
func (ext *ThreadExt) Process(context *StateContext, node *Node, signal Signal) error {
|
|
||||||
context.Graph.Log.Logf("signal", "THREAD_PROCESS: %s", node.ID)
|
|
||||||
|
|
||||||
var err error
|
|
||||||
switch signal.Direction() {
|
|
||||||
case Up:
|
|
||||||
err = UseStates(context, node, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error {
|
|
||||||
if ext.Parent != nil {
|
|
||||||
if ext.Parent.ID != node.ID {
|
|
||||||
return SendSignal(context, ext.Parent, node, signal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
case Down:
|
|
||||||
err = UseStates(context, node, NewACLInfo(node, []string{"children"}), func(context *StateContext) error {
|
|
||||||
for _, info := range(ext.Children) {
|
|
||||||
err := SendSignal(context, info.Child, node, signal)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
case Direct:
|
|
||||||
err = nil
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("Invalid signal direction %d", signal.Direction())
|
|
||||||
}
|
|
||||||
ext.SignalChan <- signal
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func UnlinkThreads(context *StateContext, principal *Node, thread *Node, child *Node) error {
|
|
||||||
thread_ext, err := GetExt[*ThreadExt](thread)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
child_ext, err := GetExt[*ThreadExt](child)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return UpdateStates(context, principal, ACLMap{
|
|
||||||
thread.ID: ACLInfo{thread, []string{"children"}},
|
|
||||||
child.ID: ACLInfo{child, []string{"parent"}},
|
|
||||||
}, func(context *StateContext) error {
|
|
||||||
_, is_child := thread_ext.Children[child.ID]
|
|
||||||
if is_child == false {
|
|
||||||
return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID, thread.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(thread_ext.Children, child.ID)
|
|
||||||
child_ext.Parent = nil
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkIfChild(context *StateContext, id NodeID, cur *ThreadExt) (bool, error) {
|
|
||||||
for _, info := range(cur.Children) {
|
|
||||||
child := info.Child
|
|
||||||
if child.ID == id {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
child_ext, err := GetExt[*ThreadExt](child)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var is_child bool
|
|
||||||
err = UpdateStates(context, child, NewACLInfo(child, []string{"children"}), func(context *StateContext) error {
|
|
||||||
is_child, err = checkIfChild(context, id, child_ext)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if is_child {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Links child to parent with info as the associated info
|
|
||||||
// Continues the write context with princ, getting children for thread and parent for child
|
|
||||||
func LinkThreads(context *StateContext, principal *Node, thread *Node, info ChildInfo) error {
|
|
||||||
if context == nil || principal == nil || thread == nil || info.Child == nil {
|
|
||||||
return fmt.Errorf("invalid input")
|
|
||||||
}
|
|
||||||
|
|
||||||
child := info.Child
|
|
||||||
if thread.ID == child.ID {
|
|
||||||
return fmt.Errorf("Will not link %s as a child of itself", thread.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
thread_ext, err := GetExt[*ThreadExt](thread)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
child_ext, err := GetExt[*ThreadExt](child)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return UpdateStates(context, principal, ACLMap{
|
|
||||||
child.ID: ACLInfo{Node: child, Resources: []string{"parent"}},
|
|
||||||
thread.ID: ACLInfo{Node: thread, Resources: []string{"children"}},
|
|
||||||
}, func(context *StateContext) error {
|
|
||||||
if child_ext.Parent != nil {
|
|
||||||
return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
is_child, err := checkIfChild(context, thread.ID, child_ext)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
} else if is_child == true {
|
|
||||||
return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID, child.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
is_child, err = checkIfChild(context, child.ID, thread_ext)
|
|
||||||
if err != nil {
|
|
||||||
|
|
||||||
return err
|
|
||||||
} else if is_child == true {
|
|
||||||
return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID, child.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO check for info types
|
|
||||||
|
|
||||||
thread_ext.Children[child.ID] = info
|
|
||||||
child_ext.Parent = thread
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChildInfo struct {
|
|
||||||
Child *Node
|
|
||||||
Infos map[InfoType]Info
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewChildInfo(child *Node, infos map[InfoType]Info) ChildInfo {
|
|
||||||
if infos == nil {
|
|
||||||
infos = map[InfoType]Info{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ChildInfo{
|
|
||||||
Child: child,
|
|
||||||
Infos: infos,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ThreadExt) SetActive(active bool) error {
|
|
||||||
ext.ActiveLock.Lock()
|
|
||||||
defer ext.ActiveLock.Unlock()
|
|
||||||
if ext.Active == true && active == true {
|
|
||||||
return fmt.Errorf("alreday active, cannot set active")
|
|
||||||
} else if ext.Active == false && active == false {
|
|
||||||
return fmt.Errorf("already inactive, canot set inactive")
|
|
||||||
}
|
|
||||||
ext.Active = active
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ThreadExt) SetState(state string) error {
|
|
||||||
ext.State = state
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Requires the read permission of threads children
|
|
||||||
func FindChild(context *StateContext, principal *Node, thread *Node, id NodeID) (*Node, error) {
|
|
||||||
if thread == nil {
|
|
||||||
panic("cannot recurse through nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
if id == thread.ID {
|
|
||||||
return thread, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
thread_ext, err := GetExt[*ThreadExt](thread)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var found *Node = nil
|
|
||||||
err = UseStates(context, principal, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
|
|
||||||
for _, info := range(thread_ext.Children) {
|
|
||||||
found, err = FindChild(context, principal, info.Child, id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if found != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
return found, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func ChildGo(ctx * Context, thread_ext *ThreadExt, child *Node, first_action string) {
|
|
||||||
thread_ext.ChildWaits.Add(1)
|
|
||||||
go func(child *Node) {
|
|
||||||
defer thread_ext.ChildWaits.Done()
|
|
||||||
err := ThreadLoop(ctx, child, first_action)
|
|
||||||
if err != nil {
|
|
||||||
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %s", child.ID, err)
|
|
||||||
} else {
|
|
||||||
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_DONE: %s", child.ID)
|
|
||||||
}
|
|
||||||
}(child)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Main Loop for Threads, starts a write context, so cannot be called from a write or read context
|
|
||||||
func ThreadLoop(ctx * Context, thread *Node, first_action string) error {
|
|
||||||
thread_ext, err := GetExt[*ThreadExt](thread)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID, first_action)
|
|
||||||
|
|
||||||
err = thread_ext.SetActive(true)
|
|
||||||
if err != nil {
|
|
||||||
ctx.Log.Logf("thread", "THREAD_LOOP_START_ERR: %e", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
next_action := first_action
|
|
||||||
for next_action != "" {
|
|
||||||
action, exists := thread_ext.Actions[next_action]
|
|
||||||
if exists == false {
|
|
||||||
error_str := fmt.Sprintf("%s is not a valid action", next_action)
|
|
||||||
return errors.New(error_str)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.Log.Logf("thread", "THREAD_ACTION: %s - %s", thread.ID, next_action)
|
|
||||||
next_action, err = action(ctx, thread, thread_ext)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = thread_ext.SetActive(false)
|
|
||||||
if err != nil {
|
|
||||||
ctx.Log.Logf("thread", "THREAD_LOOP_STOP_ERR: %e", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.Log.Logf("thread", "THREAD_LOOP_DONE: %s", thread.ID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ThreadChildLinked(ctx *Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
|
|
||||||
ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %s - %+v", thread.ID, signal)
|
|
||||||
context := NewWriteContext(ctx)
|
|
||||||
err := UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
|
|
||||||
sig, ok := signal.(IDSignal)
|
|
||||||
if ok == false {
|
|
||||||
ctx.Log.Logf("thread", "THREAD_NODE_LINKED_BAD_CAST")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
info, exists := thread_ext.Children[sig.ID]
|
|
||||||
if exists == false {
|
|
||||||
ctx.Log.Logf("thread", "THREAD_NODE_LINKED: %s is not a child of %s", sig.ID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
parent_info, exists := info.Infos["parent"].(*ParentInfo)
|
|
||||||
if exists == false {
|
|
||||||
panic("ran ThreadChildLinked from a thread that doesn't require 'parent' child info. library party foul")
|
|
||||||
}
|
|
||||||
|
|
||||||
if parent_info.Start == true {
|
|
||||||
ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
|
|
||||||
} else {
|
|
||||||
|
|
||||||
}
|
|
||||||
return "wait", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to start a child from a thread during a signal handler
|
|
||||||
// Starts a write context, so cannot be called from either a write or read context
|
|
||||||
func ThreadStartChild(ctx *Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
|
|
||||||
sig, ok := signal.(StartChildSignal)
|
|
||||||
if ok == false {
|
|
||||||
return "wait", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
context := NewWriteContext(ctx)
|
|
||||||
return "wait", UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
|
|
||||||
info, exists:= thread_ext.Children[sig.ID]
|
|
||||||
if exists == false {
|
|
||||||
return fmt.Errorf("%s is not a child of %s", sig.ID, thread.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
parent_info, exists := info.Infos["parent"].(*ParentInfo)
|
|
||||||
if exists == false {
|
|
||||||
return fmt.Errorf("Called ThreadStartChild from a thread that doesn't require parent child info")
|
|
||||||
}
|
|
||||||
parent_info.Start = true
|
|
||||||
ChildGo(ctx, thread_ext, info.Child, sig.Action)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to restore threads that should be running from a parents restore action
|
|
||||||
// Starts a write context, so cannot be called from either a write or read context
|
|
||||||
func ThreadRestore(ctx * Context, thread *Node, thread_ext *ThreadExt, start bool) error {
|
|
||||||
context := NewWriteContext(ctx)
|
|
||||||
return UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
|
|
||||||
return UpdateStates(context, thread, ACLList(thread_ext.ChildList(), []string{"state"}), func(context *StateContext) error {
|
|
||||||
for _, info := range(thread_ext.Children) {
|
|
||||||
child_ext, err := GetExt[*ThreadExt](info.Child)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
parent_info := info.Infos[ParentInfoType].(*ParentInfo)
|
|
||||||
if parent_info.Start == true && child_ext.State != "finished" {
|
|
||||||
ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID, info.Child.ID)
|
|
||||||
if start == true {
|
|
||||||
ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction)
|
|
||||||
} else {
|
|
||||||
ChildGo(ctx, thread_ext, info.Child, parent_info.RestoreAction)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to be called during a threads start action, sets the thread state to started
|
|
||||||
// Starts a write context, so cannot be called from either a write or read context
|
|
||||||
// Returns "wait", nil on success, so the first return value can be ignored safely
|
|
||||||
func ThreadStart(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
|
|
||||||
context := NewWriteContext(ctx)
|
|
||||||
err := UpdateStates(context, thread, NewACLInfo(thread, []string{"state"}), func(context *StateContext) error {
|
|
||||||
err := LockLockables(context, map[NodeID]*Node{thread.ID: thread}, thread)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return thread_ext.SetState("started")
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
context = NewReadContext(ctx)
|
|
||||||
return "wait", SendSignal(context, thread, thread, NewStatusSignal("started", thread.ID))
|
|
||||||
}
|
|
||||||
|
|
||||||
func ThreadWait(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
|
|
||||||
ctx.Log.Logf("thread", "THREAD_WAIT: %s - %+v", thread.ID, thread_ext.ActionQueue)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case signal := <- thread_ext.SignalChan:
|
|
||||||
ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID, signal)
|
|
||||||
signal_fn, exists := thread_ext.Handlers[signal.Type()]
|
|
||||||
if exists == true {
|
|
||||||
ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID, signal.Type())
|
|
||||||
return signal_fn(ctx, thread, thread_ext, signal)
|
|
||||||
} else {
|
|
||||||
ctx.Log.Logf("thread", "THREAD_NOHANDLER: %s - %s", thread.ID, signal.Type())
|
|
||||||
}
|
|
||||||
case <- thread_ext.TimeoutChan:
|
|
||||||
timeout_action := ""
|
|
||||||
context := NewWriteContext(ctx)
|
|
||||||
err := UpdateStates(context, thread, NewACLMap(NewACLInfo(thread, []string{"timeout"})), func(context *StateContext) error {
|
|
||||||
timeout_action = thread_ext.NextAction.Action
|
|
||||||
thread_ext.NextAction, thread_ext.TimeoutChan = SoonestAction(thread_ext.ActionQueue)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
ctx.Log.Logf("thread", "THREAD_TIMEOUT_ERR: %s - %e", thread.ID, err)
|
|
||||||
}
|
|
||||||
ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID, timeout_action)
|
|
||||||
return timeout_action, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ThreadFinish(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) {
|
|
||||||
context := NewWriteContext(ctx)
|
|
||||||
return "", UpdateStates(context, thread, NewACLInfo(thread, []string{"state"}), func(context *StateContext) error {
|
|
||||||
err := thread_ext.SetState("finished")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return UnlockLockables(context, map[NodeID]*Node{thread.ID: thread}, thread)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
var ThreadAbortedError = errors.New("Thread aborted by signal")
|
|
||||||
|
|
||||||
// Default thread action function for "abort", sends a signal and returns a ThreadAbortedError
|
|
||||||
func ThreadAbort(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
|
|
||||||
context := NewReadContext(ctx)
|
|
||||||
err := SendSignal(context, thread, thread, NewStatusSignal("aborted", thread.ID))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return "", ThreadAbortedError
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default thread action for "stop", sends a signal and returns no error
|
|
||||||
func ThreadStop(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
|
|
||||||
context := NewReadContext(ctx)
|
|
||||||
err := SendSignal(context, thread, thread, NewStatusSignal("stopped", thread.ID))
|
|
||||||
return "finish", err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default thread actions
|
|
||||||
var BaseThreadActions = ThreadActions{
|
|
||||||
"wait": ThreadWait,
|
|
||||||
"start": ThreadStart,
|
|
||||||
"finish": ThreadFinish,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default thread signal handlers
|
|
||||||
var BaseThreadHandlers = ThreadHandlers{
|
|
||||||
"abort": ThreadAbort,
|
|
||||||
"stop": ThreadStop,
|
|
||||||
}
|
|
@ -1,120 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
"fmt"
|
|
||||||
"encoding/json"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/x509"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ECDHExt struct {
|
|
||||||
Granted time.Time
|
|
||||||
Pubkey *ecdsa.PublicKey
|
|
||||||
Shared []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type ECDHExtJSON struct {
|
|
||||||
Granted time.Time `json:"granted"`
|
|
||||||
Pubkey []byte `json:"pubkey"`
|
|
||||||
Shared []byte `json:"shared"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ECDHExt) Process(context *StateContext, node *Node, signal Signal) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const ECDHExtType = ExtType("ECDH")
|
|
||||||
func (ext *ECDHExt) Type() ExtType {
|
|
||||||
return ECDHExtType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ECDHExt) Serialize() ([]byte, error) {
|
|
||||||
pubkey, err := x509.MarshalPKIXPublicKey(ext.Pubkey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return json.MarshalIndent(&ECDHExtJSON{
|
|
||||||
Granted: ext.Granted,
|
|
||||||
Pubkey: pubkey,
|
|
||||||
Shared: ext.Shared,
|
|
||||||
}, "", " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadECDHExt(ctx *Context, data []byte) (Extension, error) {
|
|
||||||
var j ECDHExtJSON
|
|
||||||
err := json.Unmarshal(data, &j)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pub, err := x509.ParsePKIXPublicKey(j.Pubkey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var pubkey *ecdsa.PublicKey
|
|
||||||
switch pub.(type) {
|
|
||||||
case *ecdsa.PublicKey:
|
|
||||||
pubkey = pub.(*ecdsa.PublicKey)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("Invalid key type: %+v", pub)
|
|
||||||
}
|
|
||||||
|
|
||||||
extension := ECDHExt{
|
|
||||||
Granted: j.Granted,
|
|
||||||
Pubkey: pubkey,
|
|
||||||
Shared: j.Shared,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &extension, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type GroupExt struct {
|
|
||||||
Members NodeMap
|
|
||||||
}
|
|
||||||
|
|
||||||
const GroupExtType = ExtType("GROUP")
|
|
||||||
func (ext *GroupExt) Type() ExtType {
|
|
||||||
return GroupExtType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *GroupExt) Serialize() ([]byte, error) {
|
|
||||||
return json.MarshalIndent(&struct{
|
|
||||||
Members []string `json:"members"`
|
|
||||||
}{
|
|
||||||
Members: SaveNodeList(ext.Members),
|
|
||||||
}, "", " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewGroupExt(members NodeMap) *GroupExt {
|
|
||||||
if members == nil {
|
|
||||||
members = NodeMap{}
|
|
||||||
}
|
|
||||||
return &GroupExt{
|
|
||||||
Members: members,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadGroupExt(ctx *Context, data []byte) (Extension, error) {
|
|
||||||
var j struct {
|
|
||||||
Members []string `json:"members"`
|
|
||||||
}
|
|
||||||
|
|
||||||
err := json.Unmarshal(data, &j)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
members, err := RestoreNodeList(ctx, j.Members)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewGroupExt(members), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *GroupExt) Process(context *StateContext, node *Node, signal Signal) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
Loading…
Reference in New Issue