Compare commits

..

No commits in common. "master" and "graph-rework-2" have entirely different histories.

33 changed files with 4177 additions and 4461 deletions

3
.gitignore vendored

@ -1,15 +1,12 @@
# Ignore everything
*
!/go-capnp
# But not these files...
!/.gitignore
!*.go
*.capnp.go
!go.sum
!go.mod
!*.capnp
!README.md
!LICENSE

3
.gitmodules vendored

@ -1,3 +0,0 @@
[submodule "graphql"]
path = graphql
url = https://github.com/graphql-go/graphql

@ -1,4 +0,0 @@
.PHONY: test
test:
clear && go test

@ -1,48 +0,0 @@
package main
import (
"fmt"
badger "github.com/dgraph-io/badger/v3"
gv "github.com/mekkanized/graphvent"
)
func check(err error) {
if err != nil {
panic(err)
}
}
func main() {
db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true))
check(err)
ctx, err := gv.NewContext(&gv.BadgerDB{
DB: db,
}, gv.NewConsoleLogger([]string{"test"}))
check(err)
gql_ext, err := gv.NewGQLExt(ctx, ":8080", nil, nil)
check(err)
listener_ext := gv.NewListenerExt(1000)
n1, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt(nil))
check(err)
n2, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt([]gv.NodeID{n1.ID}))
check(err)
n3, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt(nil))
check(err)
_, err = gv.NewNode(ctx, nil, "LockableNode", 1000, gql_ext, listener_ext, gv.NewLockableExt([]gv.NodeID{n2.ID, n3.ID}))
check(err)
for true {
select {
case message := <- listener_ext.Chan:
fmt.Printf("Listener Message: %+v\n", message)
}
}
}

File diff suppressed because it is too large Load Diff

284
db.go

@ -1,284 +0,0 @@
package graphvent
import (
"encoding/binary"
"fmt"
"reflect"
"sync"
badger "github.com/dgraph-io/badger/v3"
)
type Database interface {
WriteNodeInit(*Context, *Node) error
WriteNodeChanges(*Context, *Node, map[ExtType]Changes) error
LoadNode(*Context, NodeID) (*Node, error)
}
const WRITE_BUFFER_SIZE = 1000000
type BadgerDB struct {
*badger.DB
sync.Mutex
buffer [WRITE_BUFFER_SIZE]byte
}
func (db *BadgerDB) WriteNodeInit(ctx *Context, node *Node) error {
if node == nil {
return fmt.Errorf("Cannot serialize nil *Node")
}
return db.Update(func(tx *badger.Txn) error {
db.Lock()
defer db.Unlock()
// Get the base key bytes
id_ser, err := node.ID.MarshalBinary()
if err != nil {
return err
}
cur := 0
// Write Node value
written, err := Serialize(ctx, node, db.buffer[cur:])
if err != nil {
return err
}
err = tx.Set(id_ser, db.buffer[cur:cur+written])
if err != nil {
return err
}
cur += written
// Write empty signal queue
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
written, err = Serialize(ctx, node.SignalQueue, db.buffer[cur:])
if err != nil {
return err
}
err = tx.Set(sigqueue_id, db.buffer[cur:cur+written])
if err != nil {
return err
}
cur += written
// Write node extension list
ext_list := []ExtType{}
for ext_type := range(node.Extensions) {
ext_list = append(ext_list, ext_type)
}
written, err = Serialize(ctx, ext_list, db.buffer[cur:])
if err != nil {
return err
}
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
err = tx.Set(ext_list_id, db.buffer[cur:cur+written])
if err != nil {
return err
}
cur += written
// For each extension:
for ext_type, ext := range(node.Extensions) {
ext_info, exists := ctx.Extensions[ext_type]
if exists == false {
return fmt.Errorf("Cannot serialize node with unknown extension %s", reflect.TypeOf(ext))
}
ext_value := reflect.ValueOf(ext).Elem()
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
// Write each field to a seperate key
for field_tag, field_info := range(ext_info.Fields) {
field_value := ext_value.FieldByIndex(field_info.Index)
field_id := make([]byte, len(ext_id) + 8)
tmp := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(field_tag))))
copy(field_id, tmp)
written, err := SerializeValue(ctx, field_value, db.buffer[cur:])
if err != nil {
return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err)
}
err = tx.Set(field_id, db.buffer[cur:cur+written])
if err != nil {
return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err)
}
cur += written
}
}
return nil
})
}
func (db *BadgerDB) WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error {
return db.Update(func(tx *badger.Txn) error {
db.Lock()
defer db.Unlock()
// Get the base key bytes
id_bytes := ([16]byte)(node.ID)
cur := 0
// Write the signal queue if it needs to be written
if node.writeSignalQueue {
node.writeSignalQueue = false
sigqueue_id := append(id_bytes[:], []byte(" - SIGQUEUE")...)
written, err := Serialize(ctx, node.SignalQueue, db.buffer[cur:])
if err != nil {
return fmt.Errorf("SignalQueue Serialize Error: %+v, %w", node.SignalQueue, err)
}
err = tx.Set(sigqueue_id, db.buffer[cur:cur+written])
if err != nil {
return fmt.Errorf("SignalQueue set error: %+v, %w", node.SignalQueue, err)
}
cur += written
}
// For each ext in changes
for ext_type, changes := range(changes) {
ext_info, exists := ctx.Extensions[ext_type]
if exists == false {
return fmt.Errorf("%s is not an extension in ctx", ext_type)
}
ext, exists := node.Extensions[ext_type]
if exists == false {
return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID)
}
ext_id := binary.BigEndian.AppendUint64(id_bytes[:], uint64(ext_type))
ext_value := reflect.ValueOf(ext)
// Write each field
for _, tag := range(changes) {
field_info, exists := ext_info.Fields[tag]
if exists == false {
return fmt.Errorf("Cannot serialize field %s of extension %s, does not exist", tag, ext_type)
}
field_value := ext_value.FieldByIndex(field_info.Index)
field_id := make([]byte, len(ext_id) + 8)
tmp := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(tag))))
copy(field_id, tmp)
written, err := SerializeValue(ctx, field_value, db.buffer[cur:])
if err != nil {
return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err)
}
err = tx.Set(field_id, db.buffer[cur:cur+written])
if err != nil {
return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err)
}
cur += written
}
}
return nil
})
}
func (db *BadgerDB) LoadNode(ctx *Context, id NodeID) (*Node, error) {
var node *Node = nil
err := db.View(func(tx *badger.Txn) error {
// Get the base key bytes
id_ser, err := id.MarshalBinary()
if err != nil {
return fmt.Errorf("Failed to serialize node_id: %w", err)
}
// Get the node value
node_item, err := tx.Get(id_ser)
if err != nil {
return fmt.Errorf("Failed to get node_item: %w", NodeNotFoundError)
}
err = node_item.Value(func(val []byte) error {
ctx.Log.Logf("db", "DESERIALIZE_NODE(%d bytes): %+v", len(val), val)
node, err = Deserialize[*Node](ctx, val)
return err
})
if err != nil {
return fmt.Errorf("Failed to deserialize Node %s - %w", id, err)
}
// Get the signal queue
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
sigqueue_item, err := tx.Get(sigqueue_id)
if err != nil {
return fmt.Errorf("Failed to get sigqueue_id: %w", err)
}
err = sigqueue_item.Value(func(val []byte) error {
node.SignalQueue, err = Deserialize[[]QueuedSignal](ctx, val)
return err
})
if err != nil {
return fmt.Errorf("Failed to deserialize []QueuedSignal for %s: %w", id, err)
}
// Get the extension list
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
ext_list_item, err := tx.Get(ext_list_id)
if err != nil {
return err
}
var ext_list []ExtType
ext_list_item.Value(func(val []byte) error {
ext_list, err = Deserialize[[]ExtType](ctx, val)
return err
})
// Get the extensions
for _, ext_type := range(ext_list) {
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
ext_info, exists := ctx.Extensions[ext_type]
if exists == false {
return fmt.Errorf("Extension %s not in context", ext_type)
}
ext := reflect.New(ext_info.Type)
for field_tag, field_info := range(ext_info.Fields) {
field_id := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(field_tag))))
field_item, err := tx.Get(field_id)
if err != nil {
return fmt.Errorf("Failed to find key for %s:%s(%x) - %w", ext_type, field_tag, field_id, err)
}
err = field_item.Value(func(val []byte) error {
value, _, err := DeserializeValue(ctx, val, field_info.Type)
if err != nil {
return err
}
ext.Elem().FieldByIndex(field_info.Index).Set(value)
return nil
})
if err != nil {
return err
}
}
node.Extensions[ext_type] = ext.Interface().(Extension)
}
return nil
})
if err != nil {
return nil, err
} else if node == nil {
return nil, fmt.Errorf("Tried to return nil *Node from BadgerDB.LoadNode without error")
}
return node, nil
}

@ -1,16 +0,0 @@
package graphvent
type Tag string
type Changes []Tag
// Extensions are data attached to nodes that process signals
type Extension interface {
// Called to process incoming signals, returning changes and messages to send
Process(*Context, *Node, NodeID, Signal) ([]Message, Changes)
// Called when the node is loaded into a context(creation or move), so extension data can be initialized
Load(*Context, *Node) error
// Called when the node is unloaded from a context(deletion or move), so extension data can be cleaned up
Unload(*Context, *Node)
}

@ -1,6 +1,6 @@
module github.com/mekkanized/graphvent
go 1.22.0
go 1.20
require (
github.com/dgraph-io/badger/v3 v3.2103.5
@ -8,13 +8,12 @@ require (
github.com/google/uuid v1.3.0
github.com/graphql-go/graphql v0.8.1
github.com/rs/zerolog v1.29.1
golang.org/x/net v0.7.0
)
require (
filippo.io/edwards25519 v1.0.0 // indirect
github.com/cespare/xxhash v1.1.0 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/dgraph-io/badger/v4 v4.1.0 // indirect
github.com/dgraph-io/ristretto v0.1.1 // indirect
github.com/dustin/go-humanize v1.0.0 // indirect
github.com/gobwas/httphead v0.1.0 // indirect
@ -25,12 +24,12 @@ require (
github.com/golang/protobuf v1.3.1 // indirect
github.com/golang/snappy v0.0.3 // indirect
github.com/google/flatbuffers v1.12.1 // indirect
github.com/graphql-go/handler v0.2.3 // indirect
github.com/klauspost/compress v1.12.3 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/stretchr/testify v1.8.2 // indirect
go.opencensus.io v0.22.5 // indirect
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/net v0.7.0 // indirect
golang.org/x/sys v0.6.0 // indirect
)

@ -1,8 +1,5 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek=
filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
@ -17,17 +14,20 @@ github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3Ee
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgraph-io/badger/v3 v3.2103.5 h1:ylPa6qzbjYRQMU6jokoj4wzcaweHylt//CH0AKt0akg=
github.com/dgraph-io/badger/v3 v3.2103.5/go.mod h1:4MPiseMeDQ3FNCYwRbbcBOGJLf5jsE0PPFzRiKjtcdw=
github.com/dgraph-io/badger/v4 v4.1.0 h1:E38jc0f+RATYrycSUf9LMv/t47XAy+3CApyYSq4APOQ=
github.com/dgraph-io/badger/v4 v4.1.0/go.mod h1:P50u28d39ibBRmIJuQC/NSdBOg46HnHw7al2SW5QRHg=
github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8=
github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA=
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA=
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
@ -51,10 +51,15 @@ github.com/google/flatbuffers v1.12.1 h1:MVlul7pQNoDzWRLTw5imwYsl+usrS1TXG2H4jg6
github.com/google/flatbuffers v1.12.1/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/graph-gophers/graphql-go v1.5.0 h1:fDqblo50TEpD0LY7RXk/LFVYEVqo3+tXMNMPSVXA1yc=
github.com/graph-gophers/graphql-go v1.5.0/go.mod h1:YtmJZDLbF1YYNrlNAuiO5zAStUWc3XZT07iGsVqe1Os=
github.com/graphql-go/graphql v0.8.1 h1:p7/Ou/WpmulocJeEx7wjQy611rtXGQaAcXGqanuMMgc=
github.com/graphql-go/graphql v0.8.1/go.mod h1:nKiHzRM0qopJEwCITUuIsxk9PlVlwIiiI8pnJEhordQ=
github.com/graphql-go/handler v0.2.3 h1:CANh8WPnl5M9uA25c2GBhPqJhE53Fg0Iue/fRNla71E=
github.com/graphql-go/handler v0.2.3/go.mod h1:leLF6RpV5uZMN1CdImAxuiayrYYhOk33bZciaUGaXeU=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
@ -64,6 +69,8 @@ github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/looplab/fsm v1.0.1 h1:OEW0ORrIx095N/6lgoGkFkotqH6s7vaFPsgjLAaF5QU=
github.com/looplab/fsm v1.0.1/go.mod h1:PmD3fFvQEIsjMEfvZdrCDZ6y8VwKTwWNjlpEr6IKPO4=
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
@ -71,17 +78,16 @@ github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc=
github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
@ -90,29 +96,23 @@ github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb6
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.opencensus.io v0.22.5 h1:dntmOdLpSpHlVqbW5Eay97DelsZHe+55D+xC6i0dDS0=
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opentelemetry.io/otel v1.6.3/go.mod h1:7BgNga5fNlF/iZjG06hM3yofffp0ofKCDwSXx1GC4dI=
go.opentelemetry.io/otel/trace v1.6.3/go.mod h1:GNJQusJlUgZl9/TQBPKU/Y/ty+0iVB5fjhKeJGZPGFs=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
@ -142,12 +142,11 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6 h1:foEbQz/B0Oz6YIqu/69kfXPYeFQAuuMYFkjaqXzl5Wo=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@ -171,6 +170,4 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

976
gql.go

File diff suppressed because it is too large Load Diff

@ -0,0 +1,172 @@
package graphvent
import (
"github.com/graphql-go/graphql"
"reflect"
"fmt"
)
func NewField(init func()*graphql.Field) *graphql.Field {
return init()
}
type Singleton[K graphql.Type] struct {
Type K
List *graphql.List
}
func NewSingleton[K graphql.Type](init func() K, post_init func(K, *graphql.List)) *Singleton[K] {
val := init()
list := graphql.NewList(val)
if post_init != nil {
post_init(val, list)
}
return &Singleton[K]{
Type: val,
List: list,
}
}
func AddNodeInterfaceFields(i *graphql.Interface) {
i.AddFieldConfig("ID", &graphql.Field{
Type: graphql.String,
})
i.AddFieldConfig("TypeHash", &graphql.Field{
Type: graphql.String,
})
}
func PrepTypeResolve(p graphql.ResolveTypeParams) (*ResolveContext, error) {
resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
if ok == false {
return nil, fmt.Errorf("Bad resolve in params context")
}
return resolve_context, nil
}
var GQLInterfaceNode = NewSingleton(func() *graphql.Interface {
i := graphql.NewInterface(graphql.InterfaceConfig{
Name: "Node",
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
ctx, err := PrepTypeResolve(p)
if err != nil {
return nil
}
valid_nodes := ctx.GQLContext.ValidNodes
p_type := reflect.TypeOf(p.Value)
for key, value := range(valid_nodes) {
if p_type == key {
return value
}
}
_, ok := p.Value.(Node)
if ok == true {
return ctx.GQLContext.BaseNodeType
}
return nil
},
Fields: graphql.Fields{},
})
AddNodeInterfaceFields(i)
return i
}, nil)
var GQLInterfaceLockable = NewSingleton(func() *graphql.Interface {
gql_interface_lockable := graphql.NewInterface(graphql.InterfaceConfig{
Name: "Lockable",
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
ctx, err := PrepTypeResolve(p)
if err != nil {
return nil
}
valid_lockables := ctx.GQLContext.ValidLockables
p_type := reflect.TypeOf(p.Value)
for key, value := range(valid_lockables) {
if p_type == key {
return value
}
}
_, ok := p.Value.(*Node)
if ok == false {
return ctx.GQLContext.BaseLockableType
}
return nil
},
Fields: graphql.Fields{},
})
return gql_interface_lockable
}, func(lockable *graphql.Interface, lockable_list *graphql.List) {
lockable.AddFieldConfig("Requirements", &graphql.Field{
Type: lockable_list,
})
lockable.AddFieldConfig("Dependencies", &graphql.Field{
Type: lockable_list,
})
lockable.AddFieldConfig("Owner", &graphql.Field{
Type: lockable,
})
AddNodeInterfaceFields(lockable)
})
var GQLInterfaceThread = NewSingleton(func() *graphql.Interface {
gql_interface_thread := graphql.NewInterface(graphql.InterfaceConfig{
Name: "Thread",
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
ctx, err := PrepTypeResolve(p)
if err != nil {
return nil
}
valid_threads := ctx.GQLContext.ValidThreads
p_type := reflect.TypeOf(p.Value)
for key, value := range(valid_threads) {
if p_type == key {
return value
}
}
node, ok := p.Value.(*Node)
if ok == false {
return nil
}
_, err = GetExt[*ThreadExt](node)
if err == nil {
return ctx.GQLContext.BaseThreadType
}
return nil
},
Fields: graphql.Fields{},
})
return gql_interface_thread
}, func(thread *graphql.Interface, thread_list *graphql.List) {
thread.AddFieldConfig("Children", &graphql.Field{
Type: thread_list,
})
thread.AddFieldConfig("Parent", &graphql.Field{
Type: thread,
})
thread.AddFieldConfig("State", &graphql.Field{
Type: graphql.String,
})
AddNodeInterfaceFields(thread)
})

@ -0,0 +1,114 @@
package graphvent
import (
"fmt"
"github.com/graphql-go/graphql"
)
var GQLMutationAbort = NewField(func()*graphql.Field {
gql_mutation_abort := &graphql.Field{
Type: GQLTypeSignal.Type,
Args: graphql.FieldConfigArgument{
"id": &graphql.ArgumentConfig{
Type: graphql.String,
},
},
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
_, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
id, err := ExtractID(p, "id")
if err != nil {
return nil, err
}
var node *Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLMap(
NewACLInfo(ctx.Server, []string{"children"}),
), func(context *StateContext) (error){
node, err = FindChild(context, ctx.User, ctx.Server, id)
if err != nil {
return err
}
if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
}
return SendSignal(context, node, ctx.User, AbortSignal)
})
if err != nil {
return nil, err
}
return AbortSignal, nil
},
}
return gql_mutation_abort
})
var GQLMutationStartChild = NewField(func()*graphql.Field{
gql_mutation_start_child := &graphql.Field{
Type: GQLTypeSignal.Type,
Args: graphql.FieldConfigArgument{
"parent_id": &graphql.ArgumentConfig{
Type: graphql.String,
},
"child_id": &graphql.ArgumentConfig{
Type: graphql.String,
},
"action": &graphql.ArgumentConfig{
Type: graphql.String,
DefaultValue: "start",
},
},
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
_, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
parent_id, err := ExtractID(p, "parent_id")
if err != nil {
return nil, err
}
child_id, err := ExtractID(p, "child_id")
if err != nil {
return nil, err
}
action, err := ExtractParam[string](p, "action")
if err != nil {
return nil, err
}
var signal Signal
context := NewWriteContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLMap(
NewACLInfo(ctx.Server, []string{"children"}),
), func(context *StateContext) error {
parent, err := FindChild(context, ctx.User, ctx.Server, parent_id)
if err != nil {
return err
}
if parent == nil {
return fmt.Errorf("%s is not a child of %s", parent_id, ctx.Server.ID)
}
signal = NewStartChildSignal(child_id, action)
return SendSignal(context, ctx.User, parent, signal)
})
if err != nil {
return nil, err
}
// TODO: wait for the result of the signal to send back instead of just the signal
return signal, nil
},
}
return gql_mutation_start_child
})

@ -1,147 +0,0 @@
package graphvent
import (
"reflect"
"fmt"
"time"
"github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/ast"
)
func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) {
node, ok := p.Source.(NodeResult)
if ok == false {
return nil, fmt.Errorf("Can't get NodeID from %+v", reflect.TypeOf(p.Source))
}
return node.NodeID, nil
}
func ResolveNodeType(p graphql.ResolveParams) (interface{}, error) {
node, ok := p.Source.(NodeResult)
if ok == false {
return nil, fmt.Errorf("Can't get TypeHash from %+v", reflect.TypeOf(p.Source))
}
return uint64(node.NodeType), nil
}
type FieldIndex struct {
Extension ExtType
Tag string
}
func GetFields(selection_set *ast.SelectionSet) []string {
names := []string{}
if selection_set == nil {
return names
}
for _, sel := range(selection_set.Selections) {
switch field := sel.(type) {
case *ast.Field:
if field.Name.Value == "ID" || field.Name.Value == "Type" {
continue
}
names = append(names, field.Name.Value)
case *ast.InlineFragment:
names = append(names, GetFields(field.SelectionSet)...)
}
}
return names
}
// Returns the fields that need to be resolved
func GetResolveFields(p graphql.ResolveParams) []string {
fields := []string{}
for _, field := range(p.Info.FieldASTs) {
fields = append(fields, GetFields(field.SelectionSet)...)
}
return fields
}
func ResolveNode(id NodeID, p graphql.ResolveParams) (NodeResult, error) {
ctx, err := PrepResolve(p)
if err != nil {
return NodeResult{}, err
}
switch source := p.Source.(type) {
case *StatusSignal:
cached_node, cached := ctx.NodeCache[source.Source]
if cached {
for _, field_name := range(source.Fields) {
_, cached := cached_node.Data[field_name]
if cached {
delete(cached_node.Data, field_name)
}
}
ctx.NodeCache[source.Source] = cached_node
}
}
cache, node_cached := ctx.NodeCache[id]
fields := GetResolveFields(p)
var not_cached []string
if node_cached {
not_cached = []string{}
for _, field := range(fields) {
if node_cached {
_, field_cached := cache.Data[field]
if field_cached {
continue
}
}
not_cached = append(not_cached, field)
}
} else {
not_cached = fields
}
if (len(not_cached) == 0) && (node_cached == true) {
ctx.Context.Log.Logf("gql", "No new fields to resolve for %s", id)
return cache, nil
} else {
ctx.Context.Log.Logf("gql", "Resolving fields %+v on node %s", not_cached, id)
signal := NewReadSignal(not_cached)
response_chan := ctx.Ext.GetResponseChannel(signal.ID())
// TODO: TIMEOUT DURATION
err = ctx.Context.Send(ctx.Server, []Message{{
Node: id,
Signal: signal,
}})
if err != nil {
ctx.Ext.FreeResponseChannel(signal.ID())
return NodeResult{}, err
}
response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID())
ctx.Ext.FreeResponseChannel(signal.ID())
if err != nil {
return NodeResult{}, err
}
switch response := response.(type) {
case *ReadResultSignal:
if node_cached == false {
cache = NodeResult{
NodeID: id,
NodeType: response.NodeType,
Data: response.Fields,
}
} else {
for field_name, field_value := range(response.Fields) {
cache.Data[field_name] = field_value
}
}
ctx.NodeCache[id] = cache
return ctx.NodeCache[id], nil
default:
return NodeResult{}, fmt.Errorf("Bad read response: %+v", response)
}
}
}

@ -0,0 +1,28 @@
package graphvent
import (
"github.com/graphql-go/graphql"
)
var GQLQuerySelf = &graphql.Field{
Type: GQLTypeBaseThread.Type,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
_, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
return ctx.Server, nil
},
}
var GQLQueryUser = &graphql.Field{
Type: GQLTypeBaseNode.Type,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
_, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
return ctx.User, nil
},
}

@ -0,0 +1,338 @@
package graphvent
import (
"fmt"
"reflect"
"github.com/graphql-go/graphql"
)
func PrepResolve(p graphql.ResolveParams) (*Node, *ResolveContext, error) {
resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
if ok == false {
return nil, nil, fmt.Errorf("Bad resolve in params context")
}
node, ok := p.Source.(*Node)
if ok == false {
return nil, nil, fmt.Errorf("Source is not a *Node in PrepResolve")
}
return node, resolve_context, nil
}
// TODO: Make composabe by checkinf if K is a slice, then recursing in the same way that ExtractList does
func ExtractParam[K interface{}](p graphql.ResolveParams, name string) (K, error) {
var zero K
arg_if, ok := p.Args[name]
if ok == false {
return zero, fmt.Errorf("No Arg of name %s", name)
}
arg, ok := arg_if.(K)
if ok == false {
return zero, fmt.Errorf("Failed to cast arg %s(%+v) to %+v", name, arg_if, reflect.TypeOf(zero))
}
return arg, nil
}
func ExtractList[K interface{}](p graphql.ResolveParams, name string) ([]K, error) {
var zero K
arg_list, err := ExtractParam[[]interface{}](p, name)
if err != nil {
return nil, err
}
ret := make([]K, len(arg_list))
for i, val := range(arg_list) {
val_conv, ok := arg_list[i].(K)
if ok == false {
return nil, fmt.Errorf("Failed to cast arg %s[%d](%+v) to %+v", name, i, val, reflect.TypeOf(zero))
}
ret[i] = val_conv
}
return ret, nil
}
func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) {
id_str, err := ExtractParam[string](p, name)
if err != nil {
return ZeroID, err
}
id, err := ParseID(id_str)
if err != nil {
return ZeroID, err
}
return id, nil
}
// TODO: think about what permissions should be needed to read ID, and if there's ever a case where they haven't already been granted
func GQLNodeID(p graphql.ResolveParams) (interface{}, error) {
node, _, err := PrepResolve(p)
if err != nil {
return nil, err
}
return node.ID, nil
}
func GQLNodeTypeHash(p graphql.ResolveParams) (interface{}, error) {
node, _, err := PrepResolve(p)
if err != nil {
return nil, err
}
return string(node.Type), nil
}
func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
gql_ext, err := GetExt[*GQLExt](node)
if err != nil {
return nil, err
}
listen := ""
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"listen"}), func(context *StateContext) error {
listen = gql_ext.Listen
return nil
})
if err != nil {
return nil, err
}
return listen, nil
}
func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
thread_ext, err := GetExt[*ThreadExt](node)
if err != nil {
return nil, err
}
var parent *Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error {
parent = thread_ext.Parent
return nil
})
if err != nil {
return nil, err
}
return parent, nil
}
func GQLThreadState(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
thread_ext, err := GetExt[*ThreadExt](node)
if err != nil {
return nil, err
}
var state string
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"state"}), func(context *StateContext) error {
state = thread_ext.State
return nil
})
if err != nil {
return nil, err
}
return state, nil
}
func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
thread_ext, err := GetExt[*ThreadExt](node)
if err != nil {
return nil, err
}
var children []*Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"children"}), func(context *StateContext) error {
children = thread_ext.ChildList()
return nil
})
if err != nil {
return nil, err
}
return children, nil
}
func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
lockable_ext, err := GetExt[*LockableExt](node)
if err != nil {
return nil, err
}
var requirements []*Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error {
requirements = make([]*Node, len(lockable_ext.Requirements))
i := 0
for _, req := range(lockable_ext.Requirements) {
requirements[i] = req
i += 1
}
return nil
})
if err != nil {
return nil, err
}
return requirements, nil
}
func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
lockable_ext, err := GetExt[*LockableExt](node)
if err != nil {
return nil, err
}
var dependencies []*Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"dependencies"}), func(context *StateContext) error {
dependencies = make([]*Node, len(lockable_ext.Dependencies))
i := 0
for _, dep := range(lockable_ext.Dependencies) {
dependencies[i] = dep
i += 1
}
return nil
})
if err != nil {
return nil, err
}
return dependencies, nil
}
func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
lockable_ext, err := GetExt[*LockableExt](node)
if err != nil {
return nil, err
}
var owner *Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"owner"}), func(context *StateContext) error {
owner = lockable_ext.Owner
return nil
})
if err != nil {
return nil, err
}
return owner, nil
}
func GQLGroupMembers(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
group_ext, err := GetExt[*GroupExt](node)
if err != nil {
return nil, err
}
var members []*Node
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"users"}), func(context *StateContext) error {
members = make([]*Node, len(group_ext.Members))
i := 0
for _, member := range(group_ext.Members) {
members[i] = member
i += 1
}
return nil
})
if err != nil {
return nil, err
}
return members, nil
}
func GQLSignalFn(p graphql.ResolveParams, fn func(Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) {
if signal, ok := p.Source.(Signal); ok {
return fn(signal, p)
}
return nil, fmt.Errorf("Failed to cast source to event")
}
func GQLSignalType(p graphql.ResolveParams) (interface{}, error) {
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
return signal.Type(), nil
})
}
func GQLSignalDirection(p graphql.ResolveParams) (interface{}, error) {
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
direction := signal.Direction()
if direction == Up {
return "up", nil
} else if direction == Down {
return "down", nil
} else if direction == Direct {
return "direct", nil
}
return nil, fmt.Errorf("Invalid direction: %+v", direction)
})
}
func GQLSignalString(p graphql.ResolveParams) (interface{}, error) {
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
ser, err := signal.Serialize()
return string(ser), err
})
}

@ -0,0 +1,69 @@
package graphvent
import (
"github.com/graphql-go/graphql"
)
func GQLSubscribeSignal(p graphql.ResolveParams) (interface{}, error) {
return GQLSubscribeFn(p, false, func(ctx *Context, server *Node, ext *GQLExt, signal Signal, p graphql.ResolveParams)(interface{}, error) {
return signal, nil
})
}
func GQLSubscribeSelf(p graphql.ResolveParams) (interface{}, error) {
return GQLSubscribeFn(p, true, func(ctx *Context, server *Node, ext *GQLExt, signal Signal, p graphql.ResolveParams)(interface{}, error) {
return server, nil
})
}
func GQLSubscribeFn(p graphql.ResolveParams, send_nil bool, fn func(*Context, *Node, *GQLExt, Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) {
_, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
c := make(chan interface{})
go func(c chan interface{}, ext *GQLExt, server *Node) {
ctx.Context.Log.Logf("gqlws", "GQL_SUBSCRIBE_THREAD_START")
sig_c := ext.NewSubscriptionChannel(1)
if send_nil == true {
sig_c <- nil
}
for {
val, ok := <- sig_c
if ok == false {
return
}
ret, err := fn(ctx.Context, server, ext, val, p)
if err != nil {
ctx.Context.Log.Logf("gqlws", "type convertor error %s", err)
return
}
c <- ret
}
}(c, ctx.Ext, ctx.Server)
return c, nil
}
var GQLSubscriptionSelf = NewField(func()*graphql.Field{
gql_subscription_self := &graphql.Field{
Type: GQLTypeBaseThread.Type,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
return p.Source, nil
},
Subscribe: GQLSubscribeSelf,
}
return gql_subscription_self
})
var GQLSubscriptionUpdate = NewField(func()*graphql.Field{
gql_subscription_update := &graphql.Field{
Type: GQLTypeSignal.Type,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
return p.Source, nil
},
Subscribe: GQLSubscribeSignal,
}
return gql_subscription_update
})

@ -1,223 +1,152 @@
package graphvent
import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"reflect"
"testing"
"time"
"github.com/google/uuid"
"golang.org/x/net/websocket"
"testing"
"time"
"errors"
"crypto/rand"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
)
func TestGQLSubscribe(t *testing.T) {
ctx := logTestContext(t, []string{"test", "gql"})
func TestGQL(t * testing.T) {
ctx := logTestContext(t, []string{"test", "db"})
n1, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil))
TestUserNodeType := NodeType("TEST_USER")
err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{ACLExtType, ACLPolicyExtType})
fatalErr(t, err)
listener_ext := NewListenerExt(10)
u1 := NewNode(ctx, RandID(), TestUserNodeType)
u1_policy := NewPerNodePolicy(NodeActions{
u1.ID: Actions{"users.write", "children.write", "parent.write", "dependencies.write", "requirements.write"},
})
u1.Extensions[ACLExtType] = NewACLExt(nil)
u1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
PerNodePolicyType: &u1_policy,
})
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
fatalErr(t, err)
ctx.Log.Logf("test", "U1_ID: %s", u1.ID)
gql, err := ctx.NewNode(nil, "LockableNode", NewLockableExt([]NodeID{n1.ID}), gql_ext, listener_ext)
ListenerNodeType := NodeType("LISTENER")
err = ctx.RegisterNodeType(ListenerNodeType, []ExtType{ACLExtType, ACLPolicyExtType, ListenerExtType, LockableExtType})
fatalErr(t, err)
query := "subscription { Self { ID, Type ... on Lockable { LockableState } } }"
ctx.Log.Logf("test", "GQL: %s", gql.ID)
ctx.Log.Logf("test", "Node: %s", n1.ID)
ctx.Log.Logf("test", "Query: %s", query)
sub_1 := GQLPayload{
Query: query,
}
port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port
url := fmt.Sprintf("http://localhost:%d/gql", port)
ws_url := fmt.Sprintf("ws://127.0.0.1:%d/gqlws", port)
SubGQL := func(payload GQLPayload) {
config, err := websocket.NewConfig(ws_url, url)
fatalErr(t, err)
config.Protocol = append(config.Protocol, "graphql-ws")
config.TlsConfig = &tls.Config{InsecureSkipVerify: true}
ws, err := websocket.DialConfig(config)
fatalErr(t, err)
type payload_struct struct {
Token string `json:"token"`
}
init := struct{
ID uuid.UUID `json:"id"`
Type string `json:"type"`
}{
uuid.New(),
"connection_init",
}
ser, err := json.Marshal(&init)
fatalErr(t, err)
_, err = ws.Write(ser)
fatalErr(t, err)
l1 := NewNode(ctx, RandID(), ListenerNodeType)
l1_policy := NewRequirementOfPolicy(NodeActions{
l1.ID: Actions{"signal.status"},
})
resp := make([]byte, 1024)
n, err := ws.Read(resp)
l1.Extensions[ACLExtType] = NewACLExt(NodeList(u1))
listener_ext := NewListenerExt(10)
l1.Extensions[ListenerExtType] = listener_ext
l1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
RequirementOfPolicyType: &l1_policy,
})
l1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil)
var init_resp GQLWSMsg
err = json.Unmarshal(resp[:n], &init_resp)
fatalErr(t, err)
ctx.Log.Logf("test", "L1_ID: %s", l1.ID)
if init_resp.Type != "connection_ack" {
t.Fatal("Didn't receive connection_ack")
}
TestThreadNodeType := NodeType("TEST_THREAD")
err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLExtType, ACLPolicyExtType, ThreadExtType, LockableExtType})
fatalErr(t, err)
sub := GQLWSMsg{
ID: uuid.New().String(),
Type: "subscribe",
Payload: sub_1,
}
t1 := NewNode(ctx, RandID(), TestThreadNodeType)
t1_policy := NewParentOfPolicy(NodeActions{
t1.ID: Actions{"signal.abort", "state.write"},
})
t1.Extensions[ACLExtType] = NewACLExt(NodeList(u1))
t1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
ParentOfPolicyType: &t1_policy,
})
t1.Extensions[ThreadExtType], err = NewThreadExt(ctx, BaseThreadType, nil, nil, "init", nil)
fatalErr(t, err)
t1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil)
ser, err = json.Marshal(&sub)
fatalErr(t, err)
_, err = ws.Write(ser)
fatalErr(t, err)
ctx.Log.Logf("test", "T1_ID: %s", t1.ID)
n, err = ws.Read(resp)
fatalErr(t, err)
ctx.Log.Logf("test", "SUB1: %s", resp[:n])
TestGQLNodeType := NodeType("TEST_GQL")
err = ctx.RegisterNodeType(TestGQLNodeType, []ExtType{ACLExtType, ACLPolicyExtType, GroupExtType, GQLExtType, ThreadExtType, LockableExtType})
fatalErr(t, err)
lock_id, err := LockLockable(ctx, gql)
fatalErr(t, err)
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
response, _, err := WaitForResponse(listener_ext.Chan, 100*time.Millisecond, lock_id)
fatalErr(t, err)
gql := NewNode(ctx, RandID(), TestGQLNodeType)
gql_policy := NewChildOfPolicy(NodeActions{
gql.ID: Actions{"signal.status"},
})
gql.Extensions[ACLExtType] = NewACLExt(NodeList(u1))
gql.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
ChildOfPolicyType: &gql_policy,
})
gql.Extensions[GroupExtType] = NewGroupExt(nil)
gql.Extensions[GQLExtType] = NewGQLExt(":0", ecdh.P256(), key, nil, nil)
gql.Extensions[ThreadExtType], err = NewThreadExt(ctx, GQLThreadType, nil, nil, "ini", nil)
fatalErr(t, err)
gql.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil)
switch response.(type) {
case *SuccessSignal:
ctx.Log.Logf("test", "Locked %s", gql.ID)
default:
t.Errorf("Unexpected lock response: %s", response)
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
info := ParentInfo{true, "start", "restore"}
context := NewWriteContext(ctx)
err = UpdateStates(context, u1, NewACLInfo(gql, []string{"users"}), func(context *StateContext) error {
err := LinkThreads(context, u1, gql, ChildInfo{t1, map[InfoType]Info{
ParentInfoType: &info,
}})
if err != nil {
return err
}
n, err = ws.Read(resp)
fatalErr(t, err)
ctx.Log.Logf("test", "SUB2: %s", resp[:n])
n, err = ws.Read(resp)
fatalErr(t, err)
ctx.Log.Logf("test", "SUB3: %s", resp[:n])
// TODO: check that there are no more messages sent to ws within a timeout
}
SubGQL(sub_1)
}
func TestGQLQuery(t *testing.T) {
ctx := logTestContext(t, []string{"test", "lockable"})
n1_listener := NewListenerExt(10)
n1, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil), n1_listener)
return LinkLockables(context, u1, l1, []*Node{gql})
})
fatalErr(t, err)
gql_listener := NewListenerExt(10)
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
context = NewReadContext(ctx)
err = SendSignal(context, gql, gql, NewStatusSignal("child_linked", t1.ID))
fatalErr(t, err)
gql, err := ctx.NewNode(nil, "LockableNode", NewLockableExt([]NodeID{n1.ID}), gql_ext, gql_listener)
context = NewReadContext(ctx)
err = SendSignal(context, gql, gql, AbortSignal)
fatalErr(t, err)
ctx.Log.Logf("test", "GQL: %s", gql.ID)
ctx.Log.Logf("test", "NODE: %s", n1.ID)
skipVerifyTransport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client := &http.Client{Transport: skipVerifyTransport}
port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port
url := fmt.Sprintf("http://localhost:%d/gql", port)
req_1 := GQLPayload{
Query: "query Node($id:graphvent_NodeID) { Node(id:$id) { ID, Type, ... on Lockable { LockableState } } }",
Variables: map[string]interface{}{
"id": n1.ID.String(),
},
}
req_2 := GQLPayload{
Query: "query Self { Self { ID, Type, ... on Lockable { LockableState, Requirements { Key { ID ... on Lockable { LockableState } } } } } }",
}
SendGQL := func(payload GQLPayload) []byte {
ser, err := json.MarshalIndent(&payload, "", " ")
fatalErr(t, err)
req_data := bytes.NewBuffer(ser)
req, err := http.NewRequest("GET", url, req_data)
fatalErr(t, err)
resp, err := client.Do(req)
fatalErr(t, err)
body, err := io.ReadAll(resp.Body)
err = ThreadLoop(ctx, gql, "start")
if errors.Is(err, ThreadAbortedError) == false {
fatalErr(t, err)
resp.Body.Close()
return body
}
resp_1 := SendGQL(req_1)
ctx.Log.Logf("test", "RESP_1: %s", resp_1)
resp_2 := SendGQL(req_2)
ctx.Log.Logf("test", "RESP_2: %s", resp_2)
lock_id, err := LockLockable(ctx, n1)
fatalErr(t, err)
response, _, err := WaitForResponse(n1_listener.Chan, 100*time.Millisecond, lock_id)
fatalErr(t, err)
switch response := response.(type) {
case *SuccessSignal:
default:
t.Fatalf("Wrong response: %s", reflect.TypeOf(response))
}
resp_3 := SendGQL(req_1)
ctx.Log.Logf("test", "RESP_3: %s", resp_3)
resp_4 := SendGQL(req_2)
ctx.Log.Logf("test", "RESP_4: %s", resp_4)
}
func TestGQLDB(t *testing.T) {
ctx := logTestContext(t, []string{"test", "db", "node", "serialize"})
(*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "aborted", 100*time.Millisecond, "Didn't receive aborted on listener")
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
fatalErr(t, err)
listener_ext := NewListenerExt(10)
context = NewReadContext(ctx)
err = UseStates(context, gql, ACLList([]*Node{gql, u1}, nil), func(context *StateContext) error {
ser1, err := gql.Serialize()
ser2, err := u1.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser1)
ctx.Log.Logf("test", "\n%s\n\n", ser2)
return err
})
gql, err := ctx.NewNode(nil, "Node", gql_ext, listener_ext)
fatalErr(t, err)
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
err = ctx.Stop()
// Clear all loaded nodes from the context so it loads them from the database
ctx.Nodes = NodeMap{}
gql_loaded, err := LoadNode(ctx, gql.ID)
fatalErr(t, err)
context = NewReadContext(ctx)
err = UseStates(context, gql_loaded, NewACLInfo(gql_loaded, []string{"users", "children", "requirements"}), func(context *StateContext) error {
ser, err := gql_loaded.Serialize()
lockable_ext, err := GetExt[*LockableExt](gql_loaded)
if err != nil {
return err
}
ctx.Log.Logf("test", "\n%s\n\n", ser)
dependency := lockable_ext.Dependencies[l1.ID]
listener_ext, err = GetExt[*ListenerExt](dependency)
if err != nil {
return err
}
SendSignal(context, gql_loaded, gql_loaded, StopSignal)
return err
})
gql_loaded, err := ctx.GetNode(gql.ID)
err = ThreadLoop(ctx, gql_loaded, "start")
fatalErr(t, err)
(*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2")
listener_ext, err = GetExt[ListenerExt](gql_loaded)
fatalErr(t, err)
}

@ -0,0 +1,148 @@
package graphvent
import (
"github.com/graphql-go/graphql"
)
func AddNodeFields(obj *graphql.Object) {
obj.AddFieldConfig("ID", &graphql.Field{
Type: graphql.String,
Resolve: GQLNodeID,
})
obj.AddFieldConfig("TypeHash", &graphql.Field{
Type: graphql.String,
Resolve: GQLNodeTypeHash,
})
}
func AddLockableFields(obj *graphql.Object) {
AddNodeFields(obj)
obj.AddFieldConfig("Requirements", &graphql.Field{
Type: GQLInterfaceLockable.List,
Resolve: GQLLockableRequirements,
})
obj.AddFieldConfig("Owner", &graphql.Field{
Type: GQLInterfaceLockable.Type,
Resolve: GQLLockableOwner,
})
obj.AddFieldConfig("Dependencies", &graphql.Field{
Type: GQLInterfaceLockable.List,
Resolve: GQLLockableDependencies,
})
}
func AddThreadFields(obj *graphql.Object) {
AddNodeFields(obj)
obj.AddFieldConfig("State", &graphql.Field{
Type: graphql.String,
Resolve: GQLThreadState,
})
obj.AddFieldConfig("Children", &graphql.Field{
Type: GQLInterfaceThread.List,
Resolve: GQLThreadChildren,
})
obj.AddFieldConfig("Parent", &graphql.Field{
Type: GQLInterfaceThread.Type,
Resolve: GQLThreadParent,
})
}
var GQLTypeBaseThread = NewSingleton(func() *graphql.Object {
gql_type_simple_thread := graphql.NewObject(graphql.ObjectConfig{
Name: "SimpleThread",
Interfaces: []*graphql.Interface{
GQLInterfaceNode.Type,
GQLInterfaceThread.Type,
GQLInterfaceLockable.Type,
},
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
node, ok := p.Value.(*Node)
if ok == false {
return false
}
_, err := GetExt[*ThreadExt](node)
return err == nil
},
Fields: graphql.Fields{},
})
AddThreadFields(gql_type_simple_thread)
return gql_type_simple_thread
}, nil)
var GQLTypeBaseLockable = NewSingleton(func() *graphql.Object {
gql_type_simple_lockable := graphql.NewObject(graphql.ObjectConfig{
Name: "SimpleLockable",
Interfaces: []*graphql.Interface{
GQLInterfaceNode.Type,
GQLInterfaceLockable.Type,
},
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
node, ok := p.Value.(*Node)
if ok == false {
return false
}
_, err := GetExt[*LockableExt](node)
return err == nil
},
Fields: graphql.Fields{},
})
AddLockableFields(gql_type_simple_lockable)
return gql_type_simple_lockable
}, nil)
var GQLTypeBaseNode = NewSingleton(func() *graphql.Object {
object := graphql.NewObject(graphql.ObjectConfig{
Name: "SimpleNode",
Interfaces: []*graphql.Interface{
GQLInterfaceNode.Type,
},
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
_, ok := p.Value.(*Node)
return ok
},
Fields: graphql.Fields{},
})
AddNodeFields(object)
return object
}, nil)
var GQLTypeSignal = NewSingleton(func() *graphql.Object {
gql_type_signal := graphql.NewObject(graphql.ObjectConfig{
Name: "Signal",
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
_, ok := p.Value.(Signal)
return ok
},
Fields: graphql.Fields{},
})
gql_type_signal.AddFieldConfig("Type", &graphql.Field{
Type: graphql.String,
Resolve: GQLSignalType,
})
gql_type_signal.AddFieldConfig("Direction", &graphql.Field{
Type: graphql.String,
Resolve: GQLSignalDirection,
})
gql_type_signal.AddFieldConfig("String", &graphql.Field{
Type: graphql.String,
Resolve: GQLSignalString,
})
return gql_type_signal
}, nil)

@ -2,34 +2,76 @@ package graphvent
import (
"testing"
"runtime/debug"
"fmt"
"time"
"runtime/pprof"
"runtime/debug"
"os"
badger "github.com/dgraph-io/badger/v3"
)
func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt, error) {
listener_extension := NewListenerExt(buffer)
listener, err := ctx.NewNode(nil, "LockableNode", nil, listener_extension, NewLockableExt(nil))
type GraphTester testing.T
const listner_timeout = 50 * time.Millisecond
func (t * GraphTester) WaitForStatus(ctx * Context, listener chan Signal, status string, timeout time.Duration, str string) Signal {
timeout_channel := time.After(timeout)
for true {
select {
case signal := <- listener:
if signal == nil {
ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener)
t.Fatal(str)
}
if signal.Type() == "status" {
sig, ok := signal.(StatusSignal)
if ok == true {
if sig.Status == status {
return signal
}
ctx.Log.Logf("test", "Different status received: %s", sig.Status)
} else {
ctx.Log.Logf("test", "Failed to cast status to StatusSignal: %+v", signal)
}
}
case <-timeout_channel:
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
t.Fatal(str)
return nil
}
}
return nil
}
return listener, listener_extension, err
func (t * GraphTester) CheckForNone(listener chan Signal, str string) {
timeout := time.After(listner_timeout)
select {
case sig := <- listener:
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
t.Fatal(fmt.Sprintf("%s : %+v", str, sig))
case <-timeout:
}
}
func logTestContext(t * testing.T, components []string) *Context {
db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true).WithSyncWrites(true))
db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true))
if err != nil {
t.Fatal(err)
}
ctx, err := NewContext(&BadgerDB{
DB: db,
}, NewConsoleLogger(components))
ctx, err := NewContext(db, NewConsoleLogger(components))
fatalErr(t, err)
return ctx
}
func testContext(t * testing.T) * Context {
return logTestContext(t, []string{})
db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true))
if err != nil {
t.Fatal(err)
}
ctx, err := NewContext(db, NewConsoleLogger([]string{}))
fatalErr(t, err)
return ctx
}
func fatalErr(t * testing.T, err error) {
@ -38,16 +80,3 @@ func fatalErr(t * testing.T, err error) {
t.Fatal(err)
}
}
func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *Node) (ResponseSignal, []Signal) {
source_listener, err := GetExt[ListenerExt](source)
fatalErr(t, err)
messages := []Message{{destination.ID, signal}}
fatalErr(t, ctx.Send(source, messages))
response, signals, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID())
fatalErr(t, err)
return response, signals
}

@ -1,66 +0,0 @@
package graphvent
import (
"reflect"
)
// A Listener extension provides a channel that can receive signals on a different thread
type ListenerExt struct {
Buffer int `gv:"buffer"`
Chan chan Signal
}
type LoadedSignal struct {
SignalHeader
}
func NewLoadedSignal() *LoadedSignal {
return &LoadedSignal{
SignalHeader: NewSignalHeader(),
}
}
type UnloadedSignal struct {
SignalHeader
}
func NewUnloadedSignal() *UnloadedSignal {
return &UnloadedSignal{
SignalHeader: NewSignalHeader(),
}
}
func (ext *ListenerExt) Load(ctx *Context, node *Node) error {
ext.Chan = make(chan Signal, ext.Buffer)
ext.Chan <- NewLoadedSignal()
return nil
}
func (ext *ListenerExt) Unload(ctx *Context, node *Node) {
ext.Chan <- NewUnloadedSignal()
close(ext.Chan)
}
// Create a new listener extension with a given buffer size
func NewListenerExt(buffer int) *ListenerExt {
return &ListenerExt{
Buffer: buffer,
Chan: make(chan Signal, buffer),
}
}
// Send the signal to the channel, logging an overflow if it occurs
func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]Message, Changes) {
ctx.Log.Logf("listener", "%s - %+v", node.ID, reflect.TypeOf(signal))
ctx.Log.Logf("listener_debug", "%s->%s - %+v", source, node.ID, signal)
select {
case ext.Chan <- signal:
default:
ctx.Log.Logf("listener", "LISTENER_OVERFLOW: %s", node.ID)
}
switch sig := signal.(type) {
case *StatusSignal:
ctx.Log.Logf("listener_status", "%s - %+v", sig.Source, sig.Fields)
}
return nil, nil
}

@ -1,411 +1,596 @@
package graphvent
import (
"github.com/google/uuid"
"fmt"
"encoding/json"
)
type ReqState byte
const (
Unlocked = ReqState(0)
Unlocking = ReqState(1)
Locked = ReqState(2)
Locking = ReqState(3)
AbortingLock = ReqState(4)
)
type ListenerExt struct {
Buffer int
Chan chan Signal
}
func NewListenerExt(buffer int) *ListenerExt {
return &ListenerExt{
Buffer: buffer,
Chan: make(chan Signal, buffer),
}
}
func LoadListenerExt(ctx *Context, data []byte) (Extension, error) {
var j int
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
return NewListenerExt(j), nil
}
var ReqStateStrings = map[ReqState]string {
Unlocked: "Unlocked",
Unlocking: "Unlocking",
Locked: "Locked",
Locking: "Locking",
AbortingLock: "AbortingLock",
const ListenerExtType = ExtType("LISTENER")
func (listener *ListenerExt) Type() ExtType {
return ListenerExtType
}
func (state ReqState) String() string {
str, mapped := ReqStateStrings[state]
if mapped == false {
return "UNKNOWN_REQSTATE"
} else {
return str
func (ext *ListenerExt) Process(context *StateContext, node *Node, signal Signal) error {
select {
case ext.Chan <- signal:
default:
return fmt.Errorf("LISTENER_OVERFLOW - %+v", signal)
}
return nil
}
func (ext *ListenerExt) Serialize() ([]byte, error) {
return json.MarshalIndent(ext.Buffer, "", " ")
}
type LockableExt struct{
State ReqState `gv:"state"`
ReqID *uuid.UUID `gv:"req_id"`
Owner *NodeID `gv:"owner"`
PendingOwner *NodeID `gv:"pending_owner"`
Requirements map[NodeID]ReqState `gv:"requirements" node:"Lockable:"`
type LockableExt struct {
Owner *Node
Requirements map[NodeID]*Node
Dependencies map[NodeID]*Node
LocksHeld map[NodeID]*Node
}
Locked map[NodeID]any
Unlocked map[NodeID]any
const LockableExtType = ExtType("LOCKABLE")
func (ext *LockableExt) Type() ExtType {
return LockableExtType
}
Waiting WaitMap `gv:"waiting_locks" node:":Lockable"`
type LockableExtJSON struct {
Owner string `json:"owner"`
Requirements []string `json:"requirements"`
Dependencies []string `json:"dependencies"`
LocksHeld map[string]string `json:"locks_held"`
}
func NewLockableExt(requirements []NodeID) *LockableExt {
var reqs map[NodeID]ReqState = nil
var unlocked map[NodeID]any = map[NodeID]any{}
func (ext *LockableExt) Serialize() ([]byte, error) {
return json.MarshalIndent(&LockableExtJSON{
Owner: SaveNode(ext.Owner),
Requirements: SaveNodeList(ext.Requirements),
Dependencies: SaveNodeList(ext.Dependencies),
LocksHeld: SaveNodeMap(ext.LocksHeld),
}, "", " ")
}
if len(requirements) != 0 {
reqs = map[NodeID]ReqState{}
for _, req := range(requirements) {
reqs[req] = Unlocked
unlocked[req] = nil
}
func NewLockableExt(owner *Node, requirements NodeMap, dependencies NodeMap, locks_held NodeMap) *LockableExt {
if requirements == nil {
requirements = NodeMap{}
}
return &LockableExt{
State: Unlocked,
Owner: nil,
PendingOwner: nil,
Requirements: reqs,
Waiting: WaitMap{},
if dependencies == nil {
dependencies = NodeMap{}
}
Locked: map[NodeID]any{},
Unlocked: unlocked,
if locks_held == nil {
locks_held = NodeMap{}
}
}
func UnlockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
signal := NewUnlockSignal()
messages := []Message{{node.ID, signal}}
return signal.ID(), ctx.Send(node, messages)
return &LockableExt{
Owner: owner,
Requirements: requirements,
Dependencies: dependencies,
LocksHeld: locks_held,
}
}
func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
signal := NewLockSignal()
messages := []Message{{node.ID, signal}}
return signal.ID(), ctx.Send(node, messages)
}
func LoadLockableExt(ctx *Context, data []byte) (Extension, error) {
var j LockableExtJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
func (ext *LockableExt) Load(ctx *Context, node *Node) error {
ext.Locked = map[NodeID]any{}
ext.Unlocked = map[NodeID]any{}
ctx.Log.Logf("db", "DB_LOADING_LOCKABLE_EXT_JSON: %+v", j)
for id, state := range(ext.Requirements) {
if state == Unlocked {
ext.Unlocked[id] = nil
} else if state == Locked {
ext.Locked[id] = nil
}
owner, err := RestoreNode(ctx, j.Owner)
if err != nil {
return nil, err
}
return nil
}
func (ext *LockableExt) Unload(ctx *Context, node *Node) {
return
requirements, err := RestoreNodeList(ctx, j.Requirements)
if err != nil {
return nil, err
}
dependencies, err := RestoreNodeList(ctx, j.Dependencies)
if err != nil {
return nil, err
}
locks_held, err := RestoreNodeMap(ctx, j.LocksHeld)
if err != nil {
return nil, err
}
return NewLockableExt(owner, requirements, dependencies, locks_held), nil
}
// Handle link signal by adding/removing the requested NodeID
// returns an error if the node is not unlocked
func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID, signal *LinkSignal) ([]Message, Changes) {
var messages []Message = nil
var changes Changes = nil
switch ext.State {
case Unlocked:
switch signal.Action {
case "add":
_, exists := ext.Requirements[signal.NodeID]
if exists == true {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_requirement")})
} else {
if ext.Requirements == nil {
ext.Requirements = map[NodeID]ReqState{}
func (ext *LockableExt) Process(context *StateContext, node *Node, signal Signal) error {
context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID)
var err error
switch signal.Direction() {
case Up:
err = UseStates(context, node,
NewACLInfo(node, []string{"dependencies", "owner"}), func(context *StateContext) error {
owner_sent := false
for _, dependency := range(ext.Dependencies) {
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", node.ID, dependency.ID)
SendSignal(context, dependency, node, signal)
if ext.Owner != nil {
if dependency.ID == ext.Owner.ID {
owner_sent = true
}
}
ext.Requirements[signal.NodeID] = Unlocked
changes = append(changes, "requirements")
messages = append(messages, Message{source, NewSuccessSignal(signal.ID())})
}
case "remove":
_, exists := ext.Requirements[signal.NodeID]
if exists == false {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")})
} else {
delete(ext.Requirements, signal.NodeID)
changes = append(changes, "requirements")
messages = append(messages, Message{source, NewSuccessSignal(signal.ID())})
if ext.Owner != nil && owner_sent == false {
if ext.Owner.ID != node.ID {
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", node.ID, ext.Owner.ID)
return SendSignal(context, ext.Owner, node, signal)
}
}
default:
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "unknown_action")})
}
return nil
})
case Down:
err = UseStates(context, node, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error {
for _, requirement := range(ext.Requirements) {
err := SendSignal(context, requirement, node, signal)
if err != nil {
return err
}
}
return nil
})
case Direct:
err = nil
default:
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_unlocked: %s", ext.State)})
err = fmt.Errorf("invalid signal direction %d", signal.Direction())
}
if err != nil {
return err
}
return nil
}
return messages, changes
func (ext *LockableExt) RecordUnlock(node *Node) *Node {
last_owner, exists := ext.LocksHeld[node.ID]
if exists == false {
panic("Attempted to take a get the original lock holder of a lockable we don't own")
}
delete(ext.LocksHeld, node.ID)
return last_owner
}
// Handle an UnlockSignal by either transitioning to Unlocked state,
// sending unlock signals to requirements, or returning an error signal
func (ext *LockableExt) HandleUnlockSignal(ctx *Context, node *Node, source NodeID, signal *UnlockSignal) ([]Message, Changes) {
var messages []Message = nil
var changes Changes = nil
func (ext *LockableExt) RecordLock(node *Node, last_owner *Node) {
_, exists := ext.LocksHeld[node.ID]
if exists == true {
panic("Attempted to lock a lockable we're already holding(lock cycle)")
}
ext.LocksHeld[node.ID] = last_owner
}
switch ext.State {
case Locked:
if source != *ext.Owner {
messages = append(messages, Message{source, NewErrorSignal(signal.Id, "not_owner")})
} else {
if len(ext.Requirements) == 0 {
changes = append(changes, "state", "owner", "pending_owner")
// Removes requirement as a requirement from lockable
func UnlinkLockables(context *StateContext, princ *Node, lockable *Node, requirement *Node) error {
lockable_ext, err := GetExt[*LockableExt](lockable)
if err != nil {
return err
}
requirement_ext, err := GetExt[*LockableExt](requirement)
if err != nil {
return err
}
return UpdateStates(context, princ, ACLMap{
lockable.ID: ACLInfo{Node: lockable, Resources: []string{"requirements"}},
requirement.ID: ACLInfo{Node: requirement, Resources: []string{"dependencies"}},
}, func(context *StateContext) error {
var found *Node = nil
for _, req := range(lockable_ext.Requirements) {
if requirement.ID == req.ID {
found = req
break
}
}
ext.Owner = nil
if found == nil {
return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID, lockable.ID)
}
ext.PendingOwner = nil
delete(requirement_ext.Dependencies, lockable.ID)
delete(lockable_ext.Requirements, requirement.ID)
ext.State = Unlocked
return nil
})
}
messages = append(messages, Message{source, NewSuccessSignal(signal.Id)})
} else {
changes = append(changes, "state", "waiting", "requirements", "pending_owner")
// Link requirements as requirements to lockable
func LinkLockables(context *StateContext, princ *Node, lockable *Node, requirements []*Node) error {
if lockable == nil {
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements")
}
ext.PendingOwner = nil
if len(requirements) == 0 {
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link no lockables in call")
}
ext.ReqID = &signal.Id
lockable_ext, err := GetExt[*LockableExt](lockable)
if err != nil {
return err
}
ext.State = Unlocking
for id := range(ext.Requirements) {
unlock_signal := NewUnlockSignal()
req_exts := map[NodeID]*LockableExt{}
for _, requirement := range(requirements) {
if requirement == nil {
return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link nil to a Lockable as a requirement")
}
ext.Waiting[unlock_signal.Id] = id
ext.Requirements[id] = Unlocking
if lockable.ID == requirement.ID {
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s to itself", lockable.ID)
}
messages = append(messages, Message{id, unlock_signal})
}
}
_, exists := req_exts[requirement.ID]
if exists == true {
return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s twice", requirement.ID)
}
default:
messages = append(messages, Message{source, NewErrorSignal(signal.Id, "not_locked")})
ext, err := GetExt[*LockableExt](requirement)
if err != nil {
return err
}
req_exts[requirement.ID] = ext
}
return messages, changes
}
// Handle a LockSignal by either transitioning to a locked state,
// sending lock signals to requirements, or returning an error signal
func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) ([]Message, Changes) {
var messages []Message = nil
var changes Changes = nil
return UpdateStates(context, princ, NewACLMap(
NewACLInfo(lockable, []string{"requirements"}),
ACLList(requirements, []string{"dependencies"}),
), func(context *StateContext) error {
// Check that all the requirements can be added
// If the lockable is already locked, need to lock this resource as well before we can add it
for _, requirement := range(requirements) {
requirement_ext := req_exts[requirement.ID]
for _, req := range(requirements) {
if req.ID == requirement.ID {
continue
}
switch ext.State {
case Unlocked:
if len(ext.Requirements) == 0 {
changes = append(changes, "state", "owner", "pending_owner")
is_req, err := checkIfRequirement(context, req.ID, requirement_ext)
if err != nil {
return err
} else if is_req {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot add the same dependency", req.ID, requirement.ID)
ext.Owner = &source
}
}
ext.PendingOwner = &source
is_req, err := checkIfRequirement(context, lockable.ID, requirement_ext)
if err != nil {
return err
} else if is_req {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID, lockable.ID)
}
ext.State = Locked
messages = append(messages, Message{source, NewSuccessSignal(signal.Id)})
} else {
changes = append(changes, "state", "requirements", "waiting", "pending_owner")
is_req, err = checkIfRequirement(context, requirement.ID, lockable_ext)
if err != nil {
return err
} else if is_req {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID, requirement.ID)
}
ext.PendingOwner = &source
if lockable_ext.Owner == nil {
// If the new owner isn't locked, we can add the requirement
} else if requirement_ext.Owner == nil {
// if the new requirement isn't already locked but the owner is, the requirement needs to be locked first
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID, requirement.ID)
} else {
// If the new requirement is already locked and the owner is already locked, their owners need to match
if requirement_ext.Owner.ID != lockable_ext.Owner.ID {
return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID, lockable.ID)
}
}
}
// Update the states of the requirements
for _, requirement := range(requirements) {
requirement_ext := req_exts[requirement.ID]
requirement_ext.Dependencies[lockable.ID] = lockable
lockable_ext.Requirements[lockable.ID] = requirement
context.Graph.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID, lockable.ID)
}
ext.ReqID = &signal.Id
// Return no error
return nil
})
}
ext.State = Locking
for id := range(ext.Requirements) {
lock_signal := NewLockSignal()
func checkIfRequirement(context *StateContext, id NodeID, cur *LockableExt) (bool, error) {
for _, req := range(cur.Requirements) {
if req.ID == id {
return true, nil
}
ext.Waiting[lock_signal.Id] = id
ext.Requirements[id] = Locking
req_ext, err := GetExt[*LockableExt](req)
if err != nil {
return false, err
}
messages = append(messages, Message{id, lock_signal})
}
var is_req bool
err = UpdateStates(context, req, NewACLInfo(req, []string{"requirements"}), func(context *StateContext) error {
is_req, err = checkIfRequirement(context, id, req_ext)
return err
})
if err != nil {
return false, err
}
if is_req == true {
return true, nil
}
default:
messages = append(messages, Message{source, NewErrorSignal(signal.Id, "not_unlocked: %s", ext.State)})
}
return messages, changes
return false, nil
}
// Handle an error signal by aborting the lock, or retrying the unlock
func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) ([]Message, Changes) {
var messages []Message = nil
var changes Changes = nil
// Lock nodes in the to_lock slice with new_owner, does not modify any states if returning an error
// Assumes that new_owner will be written to after returning, even though it doesn't get locked during the call
func LockLockables(context *StateContext, to_lock NodeMap, new_owner *Node) error {
if to_lock == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: no map provided")
}
id, waiting := ext.Waiting[signal.ReqID]
if waiting == true {
delete(ext.Waiting, signal.ReqID)
changes = append(changes, "waiting")
req_exts := map[NodeID]*LockableExt{}
for _, l := range(to_lock) {
var err error
if l == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: Can not lock nil")
}
switch ext.State {
case Locking:
changes = append(changes, "state", "requirements")
req_exts[l.ID], err = GetExt[*LockableExt](l)
if err != nil {
return err
}
}
ext.Requirements[id] = Unlocked
if new_owner == nil {
return fmt.Errorf("LOCKABLE_LOCK_ERR: nil cannot hold locks")
}
unlocked := 0
for req_id, req_state := range(ext.Requirements) {
// Unlock locked requirements, and count unlocked requirements
switch req_state {
case Locked:
unlock_signal := NewUnlockSignal()
new_owner_ext, err := GetExt[*LockableExt](new_owner)
if err != nil {
return err
}
ext.Waiting[unlock_signal.Id] = req_id
ext.Requirements[req_id] = Unlocking
// Called with no requirements to lock, success
if len(to_lock) == 0 {
return nil
}
messages = append(messages, Message{req_id, unlock_signal})
case Unlocked:
unlocked += 1
return UpdateStates(context, new_owner, NewACLMap(
ACLListM(to_lock, []string{"lock"}),
NewACLInfo(new_owner, nil),
), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements
for _, req := range(to_lock) {
req_ext := req_exts[req.ID]
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID, new_owner.ID)
// If req is alreay locked, check that we can pass the lock
if req_ext.Owner != nil {
owner := req_ext.Owner
if owner.ID == new_owner.ID {
continue
} else {
err := UpdateStates(context, new_owner, NewACLInfo(owner, []string{"take_lock"}), func(context *StateContext)(error){
return LockLockables(context, req_ext.Requirements, req)
})
if err != nil {
return err
}
}
} else {
err := LockLockables(context, req_ext.Requirements, req)
if err != nil {
return err
}
}
}
if unlocked == len(ext.Requirements) {
changes = append(changes, "owner", "state")
ext.State = Unlocked
ext.Owner = nil
// At this point state modification will be started, so no errors can be returned
for _, req := range(to_lock) {
req_ext := req_exts[req.ID]
old_owner := req_ext.Owner
// If the lockable was previously unowned, update the state
if old_owner == nil {
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s locked %s", new_owner.ID, req.ID)
req_ext.Owner = new_owner
new_owner_ext.RecordLock(req, old_owner)
// Otherwise if the new owner already owns it, no need to update state
} else if old_owner.ID == new_owner.ID {
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s already owns %s", new_owner.ID, req.ID)
// Otherwise update the state
} else {
changes = append(changes, "state")
ext.State = AbortingLock
req_ext.Owner = new_owner
new_owner_ext.RecordLock(req, old_owner)
context.Graph.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", new_owner.ID, req.ID, old_owner.ID)
}
}
return nil
})
case Unlocking:
unlock_signal := NewUnlockSignal()
ext.Waiting[unlock_signal.Id] = id
messages = append(messages, Message{id, unlock_signal})
case AbortingLock:
req_state := ext.Requirements[id]
// Mark failed lock as Unlocked, or retry unlock
switch req_state {
case Locking:
ext.Requirements[id] = Unlocked
// Check if all requirements unlocked now
unlocked := 0
for _, req_state := range(ext.Requirements) {
if req_state == Unlocked {
unlocked += 1
}
}
}
if unlocked == len(ext.Requirements) {
changes = append(changes, "owner", "state")
ext.State = Unlocked
ext.Owner = nil
}
case Unlocking:
// Handle error for unlocking requirement while unlocking by retrying unlock
unlock_signal := NewUnlockSignal()
ext.Waiting[unlock_signal.Id] = id
messages = append(messages, Message{id, unlock_signal})
}
}
func UnlockLockables(context *StateContext, to_unlock NodeMap, old_owner *Node) error {
if to_unlock == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided")
}
return messages, changes
}
req_exts := map[NodeID]*LockableExt{}
for _, l := range(to_unlock) {
if l == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: Can not unlock nil")
}
// Handle a success signal by checking if all requirements have been locked/unlocked
func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source NodeID, signal *SuccessSignal) ([]Message, Changes) {
var messages []Message = nil
var changes Changes = nil
var err error
req_exts[l.ID], err = GetExt[*LockableExt](l)
if err != nil {
return err
}
}
id, waiting := ext.Waiting[signal.ReqID]
if waiting == true {
delete(ext.Waiting, signal.ReqID)
changes = append(changes, "waiting")
if old_owner == nil {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: nil cannot hold locks")
}
switch ext.State {
case Locking:
ext.Requirements[id] = Locked
ext.Locked[id] = nil
delete(ext.Unlocked, id)
old_owner_ext, err := GetExt[*LockableExt](old_owner)
if err != nil {
return err
}
if len(ext.Locked) == len(ext.Requirements) {
ctx.Log.Logf("lockable", "%s FULL_LOCK: %d", node.ID, len(ext.Locked))
changes = append(changes, "state", "owner", "req_id")
ext.State = Locked
ext.Owner = ext.PendingOwner
// Called with no requirements to unlock, success
if len(to_unlock) == 0 {
return nil
}
messages = append(messages, Message{*ext.Owner, NewSuccessSignal(*ext.ReqID)})
ext.ReqID = nil
return UpdateStates(context, old_owner, NewACLMap(
ACLListM(to_unlock, []string{"lock"}),
NewACLInfo(old_owner, nil),
), func(context *StateContext) error {
// First loop is to check that the states can be locked, and locks all requirements
for _, req := range(to_unlock) {
req_ext := req_exts[req.ID]
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID, old_owner.ID)
// Check if the owner is correct
if req_ext.Owner != nil {
if req_ext.Owner.ID != old_owner.ID {
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID, old_owner.ID)
}
} else {
ctx.Log.Logf("lockable", "%s PARTIAL_LOCK: %d/%d", node.ID, len(ext.Locked), len(ext.Requirements))
return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID)
}
case AbortingLock:
req_state := ext.Requirements[id]
switch req_state {
case Locking:
ext.Requirements[id] = Unlocking
unlock_signal := NewUnlockSignal()
ext.Waiting[unlock_signal.Id] = id
messages = append(messages, Message{id, unlock_signal})
case Unlocking:
ext.Requirements[id] = Unlocked
ext.Unlocked[id] = nil
delete(ext.Locked, id)
unlocked := 0
for _, req_state := range(ext.Requirements) {
switch req_state {
case Unlocked:
unlocked += 1
}
}
if unlocked == len(ext.Requirements) {
changes = append(changes, "state", "pending_owner", "req_id")
err := UnlockLockables(context, req_ext.Requirements, req)
if err != nil {
return err
}
}
messages = append(messages, Message{*ext.PendingOwner, NewErrorSignal(*ext.ReqID, "not_unlocked: %s", ext.State)})
ext.State = Unlocked
ext.ReqID = nil
ext.PendingOwner = nil
}
// At this point state modification will be started, so no errors can be returned
for _, req := range(to_unlock) {
req_ext := req_exts[req.ID]
new_owner := old_owner_ext.RecordUnlock(req)
req_ext.Owner = new_owner
if new_owner == nil {
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s unlocked %s", old_owner.ID, req.ID)
} else {
context.Graph.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", old_owner.ID, req.ID, new_owner.ID)
}
}
return nil
})
}
func SaveNode(node *Node) string {
str := ""
if node != nil {
str = node.ID.String()
}
return str
}
func RestoreNode(ctx *Context, id_str string) (*Node, error) {
if id_str == "" {
return nil, nil
}
id, err := ParseID(id_str)
if err != nil {
return nil, err
}
return LoadNode(ctx, id)
}
case Unlocking:
ext.Requirements[id] = Unlocked
ext.Unlocked[id] = Unlocked
delete(ext.Locked, id)
func SaveNodeMap(nodes NodeMap) map[string]string {
m := map[string]string{}
for id, node := range(nodes) {
m[id.String()] = SaveNode(node)
}
return m
}
if len(ext.Unlocked) == len(ext.Requirements) {
changes = append(changes, "state", "owner", "req_id")
func RestoreNodeMap(ctx *Context, ids map[string]string) (NodeMap, error) {
nodes := NodeMap{}
for id_str_1, id_str_2 := range(ids) {
id_1, err := ParseID(id_str_1)
if err != nil {
return nil, err
}
messages = append(messages, Message{*ext.Owner, NewSuccessSignal(*ext.ReqID)})
ext.State = Unlocked
ext.ReqID = nil
ext.Owner = nil
node_1, err := LoadNode(ctx, id_1)
if err != nil {
return nil, err
}
var node_2 *Node = nil
if id_str_2 != "" {
id_2, err := ParseID(id_str_2)
if err != nil {
return nil, err
}
node_2, err = LoadNode(ctx, id_2)
if err != nil {
return nil, err
}
}
nodes[node_1.ID] = node_2
}
return messages, changes
return nodes, nil
}
func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]Message, Changes) {
var messages []Message = nil
var changes Changes = nil
func SaveNodeList(nodes NodeMap) []string {
ids := make([]string, len(nodes))
i := 0
for id, _ := range(nodes) {
ids[i] = id.String()
i += 1
}
switch sig := signal.(type) {
case *StatusSignal:
// Forward StatusSignals up to the owner(unless that would be a cycle)
if ext.Owner != nil {
if *ext.Owner != node.ID {
messages = append(messages, Message{*ext.Owner, signal})
}
return ids
}
func RestoreNodeList(ctx *Context, ids []string) (NodeMap, error) {
nodes := NodeMap{}
for _, id_str := range(ids) {
node, err := RestoreNode(ctx, id_str)
if err != nil {
return nil, err
}
case *LinkSignal:
messages, changes = ext.HandleLinkSignal(ctx, node, source, sig)
case *LockSignal:
messages, changes = ext.HandleLockSignal(ctx, node, source, sig)
case *UnlockSignal:
messages, changes = ext.HandleUnlockSignal(ctx, node, source, sig)
case *ErrorSignal:
messages, changes = ext.HandleErrorSignal(ctx, node, source, sig)
case *SuccessSignal:
messages, changes = ext.HandleSuccessSignal(ctx, node, source, sig)
}
return messages, changes
nodes[node.ID] = node
}
return nodes, nil
}

@ -1,148 +0,0 @@
package graphvent
import (
"testing"
"time"
)
func TestLink(t *testing.T) {
ctx := logTestContext(t, []string{"lockable", "listener"})
l2_listener := NewListenerExt(10)
l2, err := ctx.NewNode(nil, "LockableNode", l2_listener, NewLockableExt(nil))
fatalErr(t, err)
l1_lockable := NewLockableExt(nil)
l1_listener := NewListenerExt(10)
l1, err := ctx.NewNode(nil, "LockableNode", l1_listener, l1_lockable)
fatalErr(t, err)
link_signal := NewLinkSignal("add", l2.ID)
msgs := []Message{{l1.ID, link_signal}}
err = ctx.Send(l1, msgs)
fatalErr(t, err)
_, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, link_signal.ID())
fatalErr(t, err)
state, exists := l1_lockable.Requirements[l2.ID]
if exists == false {
t.Fatal("l2 not in l1 requirements")
} else if state != Unlocked {
t.Fatalf("l2 in bad requirement state in l1: %+v", state)
}
unlink_signal := NewLinkSignal("remove", l2.ID)
msgs = []Message{{l1.ID, unlink_signal}}
err = ctx.Send(l1, msgs)
fatalErr(t, err)
_, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, unlink_signal.ID())
fatalErr(t, err)
}
func Test10Lock(t *testing.T) {
testLockN(t, 10)
}
func Test100Lock(t *testing.T) {
testLockN(t, 100)
}
func Test1000Lock(t *testing.T) {
testLockN(t, 1000)
}
func Test10000Lock(t *testing.T) {
testLockN(t, 10000)
}
func testLockN(t *testing.T, n int) {
ctx := logTestContext(t, []string{"test"})
NewLockable := func()(*Node) {
l, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil))
fatalErr(t, err)
return l
}
reqs := make([]NodeID, n)
for i := range(reqs) {
new_lockable := NewLockable()
reqs[i] = new_lockable.ID
}
ctx.Log.Logf("test", "CREATED_%d", n)
listener := NewListenerExt(50000)
node, err := ctx.NewNode(nil, "LockableNode", listener, NewLockableExt(reqs))
fatalErr(t, err)
ctx.Log.Logf("test", "CREATED_LISTENER")
lock_id, err := LockLockable(ctx, node)
fatalErr(t, err)
response, _, err := WaitForResponse(listener.Chan, time.Second*60, lock_id)
fatalErr(t, err)
switch resp := response.(type) {
case *SuccessSignal:
default:
t.Fatalf("Unexpected response to lock - %s", resp)
}
ctx.Log.Logf("test", "LOCKED_%d", n)
}
func TestLock(t *testing.T) {
ctx := logTestContext(t, []string{"test", "lockable"})
NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) {
listener := NewListenerExt(10000)
l, err := ctx.NewNode(nil, "LockableNode", listener, NewLockableExt(reqs))
fatalErr(t, err)
return l, listener
}
l2, _ := NewLockable(nil)
l3, _ := NewLockable(nil)
l4, _ := NewLockable(nil)
l5, _ := NewLockable(nil)
l0, l0_listener := NewLockable([]NodeID{l5.ID})
l1, l1_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID})
ctx.Log.Logf("test", "l0: %s", l0.ID)
ctx.Log.Logf("test", "l1: %s", l1.ID)
ctx.Log.Logf("test", "l2: %s", l2.ID)
ctx.Log.Logf("test", "l3: %s", l3.ID)
ctx.Log.Logf("test", "l4: %s", l4.ID)
ctx.Log.Logf("test", "l5: %s", l5.ID)
ctx.Log.Logf("test", "locking l0")
id_1, err := LockLockable(ctx, l0)
fatalErr(t, err)
response, _, err := WaitForResponse(l0_listener.Chan, time.Millisecond*10, id_1)
fatalErr(t, err)
ctx.Log.Logf("test", "l0 lock: %+v", response)
ctx.Log.Logf("test", "locking l1")
id_2, err := LockLockable(ctx, l1)
fatalErr(t, err)
response, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10000, id_2)
fatalErr(t, err)
ctx.Log.Logf("test", "l1 lock: %+v", response)
ctx.Log.Logf("test", "unlocking l0")
id_3, err := UnlockLockable(ctx, l0)
fatalErr(t, err)
response, _, err = WaitForResponse(l0_listener.Chan, time.Millisecond*10, id_3)
fatalErr(t, err)
ctx.Log.Logf("test", "l0 unlock: %+v", response)
ctx.Log.Logf("test", "locking l1")
id_4, err := LockLockable(ctx, l1)
fatalErr(t, err)
response, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, id_4)
fatalErr(t, err)
ctx.Log.Logf("test", "l1 lock: %+v", response)
}

@ -50,7 +50,7 @@ func (logger * ConsoleLogger) SetComponents(components []string) error {
return false
}
for c := range(logger.loggers) {
for c, _ := range(logger.loggers) {
if component_enabled(c) == false {
delete(logger.loggers, c)
}

@ -1,68 +0,0 @@
package graphvent
type Message struct {
Node NodeID
Signal Signal
}
type MessageQueue struct {
out chan<- Message
in <-chan Message
buffer []Message
write_cursor int
read_cursor int
}
func (queue *MessageQueue) ProcessIncoming(message Message) {
if (queue.write_cursor + 1) == queue.read_cursor || ((queue.write_cursor + 1) == len(queue.buffer) && queue.read_cursor == 0) {
new_buffer := make([]Message, len(queue.buffer) * 2)
copy(new_buffer, queue.buffer[queue.read_cursor:])
first_chunk := len(queue.buffer) - queue.read_cursor
copy(new_buffer[first_chunk:], queue.buffer[0:queue.write_cursor])
queue.write_cursor = len(queue.buffer) - 1
queue.read_cursor = 0
queue.buffer = new_buffer
}
queue.buffer[queue.write_cursor] = message
queue.write_cursor += 1
if queue.write_cursor >= len(queue.buffer) {
queue.write_cursor = 0
}
}
func NewMessageQueue(initial int) (chan<- Message, <-chan Message) {
in := make(chan Message, 0)
out := make(chan Message, 0)
queue := MessageQueue{
out: out,
in: in,
buffer: make([]Message, initial),
write_cursor: 0,
read_cursor: 0,
}
go func(queue *MessageQueue) {
for true {
if queue.write_cursor != queue.read_cursor {
select {
case incoming := <-queue.in:
queue.ProcessIncoming(incoming)
case queue.out <- queue.buffer[queue.read_cursor]:
queue.read_cursor += 1
if queue.read_cursor >= len(queue.buffer) {
queue.read_cursor = 0
}
}
} else {
message := <-queue.in
queue.ProcessIncoming(message)
}
}
}(&queue)
return in, out
}

@ -1,35 +0,0 @@
package graphvent
import (
"encoding/binary"
"testing"
)
func sendBatch(start, end uint64, in chan<- Message) {
for i := start; i <= end; i++ {
var id NodeID
binary.BigEndian.PutUint64(id[:], i)
in <- Message{id, nil}
}
}
func TestMessageQueue(t *testing.T) {
in, out := NewMessageQueue(10)
for i := uint64(0); i < 1000; i++ {
go sendBatch(1000*i, (1000*(i+1))-1, in)
}
seen := map[NodeID]any{}
for i := uint64(0); i < 1000*1000; i++ {
read := <-out
_, already_seen := seen[read.Node]
if already_seen {
t.Fatalf("Signal %d had duplicate NodeID %s", i, read.Node)
} else {
seen[read.Node] = nil
}
}
t.Logf("Processed 1M signals through queue")
}

File diff suppressed because it is too large Load Diff

@ -2,54 +2,27 @@ package graphvent
import (
"testing"
"time"
"crypto/rand"
"crypto/ed25519"
)
func TestNodeDB(t *testing.T) {
ctx := logTestContext(t, []string{"test", "node", "db"})
node_listener := NewListenerExt(10)
node, err := ctx.NewNode(nil, "Node", NewLockableExt(nil), node_listener)
fatalErr(t, err)
err = ctx.Stop()
fatalErr(t, err)
_, err = ctx.GetNode(node.ID)
fatalErr(t, err)
}
func TestNodeRead(t *testing.T) {
ctx := logTestContext(t, []string{"test"})
n1_pub, n1_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
n2_pub, n2_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
n1_id := KeyID(n1_pub)
n2_id := KeyID(n2_pub)
ctx.Log.Logf("test", "N1: %s", n1_id)
ctx.Log.Logf("test", "N2: %s", n2_id)
n2_listener := NewListenerExt(10)
n2, err := ctx.NewNode(n2_key, "Node", n2_listener)
fatalErr(t, err)
n1, err := ctx.NewNode(n1_key, "Node", NewListenerExt(10))
fatalErr(t, err)
read_sig := NewReadSignal([]string{"buffer"})
msgs := []Message{{n1.ID, read_sig}}
err = ctx.Send(n2, msgs)
ctx := logTestContext(t, []string{"test", "db", "node", "policy"})
node_type := NodeType("test")
err := ctx.RegisterNodeType(node_type, []ExtType{"ACL"})
fatalErr(t, err)
node := NewNode(ctx, RandID(), node_type)
node.Extensions[ACLExtType] = &ACLExt{
Delegations: NodeMap{},
}
context := NewWriteContext(ctx)
err = UpdateStates(context, node, NewACLInfo(node, []string{"test"}), func(context *StateContext) error {
ser, err := node.Serialize()
ctx.Log.Logf("test", "NODE_SER: %+v", ser)
return err
})
fatalErr(t, err)
res, err := WaitForSignal(n2_listener.Chan, 10*time.Millisecond, func(sig *ReadResultSignal) bool {
return true
})
ctx.Nodes = NodeMap{}
_, err = LoadNode(ctx, node.ID)
fatalErr(t, err)
ctx.Log.Logf("test", "READ_RESULT: %+v", res)
}

@ -0,0 +1,396 @@
package graphvent
import (
"encoding/json"
"fmt"
)
type Policy interface {
Serializable[PolicyType]
Allows(context *StateContext, principal *Node, action string, node *Node) bool
}
const RequirementOfPolicyType = PolicyType("REQUIREMENT_OF")
type RequirementOfPolicy struct {
PerNodePolicy
}
func (policy *RequirementOfPolicy) Type() PolicyType {
return RequirementOfPolicyType
}
func NewRequirementOfPolicy(nodes NodeActions) RequirementOfPolicy {
return RequirementOfPolicy{
PerNodePolicy: NewPerNodePolicy(nodes),
}
}
// Check if any of principals dependencies are in the policy
func (policy *RequirementOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
lockable_ext, err := GetExt[*LockableExt](principal)
if err != nil {
return false
}
for dep_id, _ := range(lockable_ext.Dependencies) {
for node_id, actions := range(policy.NodeActions) {
if node_id == dep_id {
if actions.Allows(action) == true {
return true
}
break
}
}
}
return false
}
const ChildOfPolicyType = PolicyType("CHILD_OF")
type ChildOfPolicy struct {
PerNodePolicy
}
func (policy *ChildOfPolicy) Type() PolicyType {
return ChildOfPolicyType
}
func (policy *ChildOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
context.Graph.Log.Logf("policy", "CHILD_OF_POLICY: %+v", policy)
thread_ext, err := GetExt[*ThreadExt](principal)
if err != nil {
return false
}
parent := thread_ext.Parent
if parent != nil {
actions, exists := policy.NodeActions[parent.ID]
if exists == false {
return false
}
for _, a := range(actions) {
if a == action {
return true
}
}
}
return false
}
type Actions []string
func (actions Actions) Allows(action string) bool {
for _, a := range(actions) {
if a == action {
return true
}
}
return false
}
type NodeActions map[NodeID]Actions
func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context, []byte)(Policy, error) {
return func(ctx *Context, data []byte)(Policy, error){
var j PerNodePolicyJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
node_actions := NodeActions{}
for id_str, actions := range(j.NodeActions) {
id, err := ParseID(id_str)
if err != nil {
return nil, err
}
_, err = LoadNode(ctx, id)
if err != nil {
return nil, err
}
node_actions[id] = actions
}
return init_fn(node_actions)
}
}
func NewChildOfPolicy(node_actions NodeActions) ChildOfPolicy {
return ChildOfPolicy{
PerNodePolicy: NewPerNodePolicy(node_actions),
}
}
const ParentOfPolicyType = PolicyType("PARENT_OF")
type ParentOfPolicy struct {
PerNodePolicy
}
func (policy *ParentOfPolicy) Type() PolicyType {
return ParentOfPolicyType
}
func (policy *ParentOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
context.Graph.Log.Logf("policy", "PARENT_OF_POLICY: %+v", policy)
for id, actions := range(policy.NodeActions) {
thread_ext, err := GetExt[*ThreadExt](context.Graph.Nodes[id])
if err != nil {
continue
}
context.Graph.Log.Logf("policy", "PARENT_OF_PARENT: %s %+v", id, thread_ext.Parent)
if thread_ext.Parent != nil {
if thread_ext.Parent.ID == principal.ID {
for _, a := range(actions) {
if a == action {
return true
}
}
}
}
}
return false
}
func NewParentOfPolicy(node_actions NodeActions) ParentOfPolicy {
return ParentOfPolicy{
PerNodePolicy: NewPerNodePolicy(node_actions),
}
}
func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy {
if node_actions == nil {
node_actions = NodeActions{}
}
return PerNodePolicy{
NodeActions: node_actions,
}
}
type PerNodePolicy struct {
NodeActions NodeActions
}
type PerNodePolicyJSON struct {
NodeActions map[string][]string `json:"node_actions"`
}
const PerNodePolicyType = PolicyType("PER_NODE")
func (policy *PerNodePolicy) Type() PolicyType {
return PerNodePolicyType
}
func (policy *PerNodePolicy) Serialize() ([]byte, error) {
node_actions := map[string][]string{}
for id, actions := range(policy.NodeActions) {
node_actions[id.String()] = actions
}
return json.MarshalIndent(&PerNodePolicyJSON{
NodeActions: node_actions,
}, "", " ")
}
func (policy *PerNodePolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
for id, actions := range(policy.NodeActions) {
if id != principal.ID {
continue
}
for _, a := range(actions) {
if a == action {
return true
}
}
}
return false
}
// Extension to allow a node to hold ACL policies
type ACLPolicyExt struct {
Policies map[PolicyType]Policy
}
// The ACL extension stores a map of nodes to delegate ACL to, and a list of policies
type ACLExt struct {
Delegations NodeMap
}
func (ext *ACLExt) Process(context *StateContext, node *Node, signal Signal) error {
return nil
}
func LoadACLExt(ctx *Context, data []byte) (Extension, error) {
var j struct {
Delegations []string `json:"delegation"`
}
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
delegations, err := RestoreNodeList(ctx, j.Delegations)
if err != nil {
return nil, err
}
return &ACLExt{
Delegations: delegations,
}, nil
}
func NodeList(nodes ...*Node) NodeMap {
m := NodeMap{}
for _, node := range(nodes) {
m[node.ID] = node
}
return m
}
func NewACLExt(delegations NodeMap) *ACLExt {
if delegations == nil {
delegations = NodeMap{}
}
return &ACLExt{
Delegations: delegations,
}
}
func (ext *ACLExt) Serialize() ([]byte, error) {
delegations := make([]string, len(ext.Delegations))
i := 0
for id, _ := range(ext.Delegations) {
delegations[i] = id.String()
i += 1
}
return json.MarshalIndent(&struct{
Delegations []string `json:"delegations"`
}{
Delegations: delegations,
}, "", " ")
}
const ACLExtType = ExtType("ACL")
func (ext *ACLExt) Type() ExtType {
return ACLExtType
}
type PolicyLoadFunc func(*Context, []byte) (Policy, error)
type PolicyInfo struct {
Load PolicyLoadFunc
}
type ACLPolicyExtContext struct {
Types map[PolicyType]PolicyInfo
}
func NewACLPolicyExtContext() *ACLPolicyExtContext {
return &ACLPolicyExtContext{
Types: map[PolicyType]PolicyInfo{
PerNodePolicyType: PolicyInfo{
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
policy := NewPerNodePolicy(nodes)
return &policy, nil
}),
},
ParentOfPolicyType: PolicyInfo{
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
policy := NewParentOfPolicy(nodes)
return &policy, nil
}),
},
ChildOfPolicyType: PolicyInfo{
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
policy := NewChildOfPolicy(nodes)
return &policy, nil
}),
},
RequirementOfPolicyType: PolicyInfo{
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
policy := NewRequirementOfPolicy(nodes)
return &policy, nil
}),
},
},
}
}
func (ext *ACLPolicyExt) Serialize() ([]byte, error) {
policies := map[string][]byte{}
for name, policy := range(ext.Policies) {
ser, err := policy.Serialize()
if err != nil {
return nil, err
}
policies[string(name)] = ser
}
return json.MarshalIndent(&struct{
Policies map[string][]byte `json:"policies"`
}{
Policies: policies,
}, "", " ")
}
func (ext *ACLPolicyExt) Process(context *StateContext, node *Node, signal Signal) error {
return nil
}
func NewACLPolicyExt(policies map[PolicyType]Policy) *ACLPolicyExt {
if policies == nil {
policies = map[PolicyType]Policy{}
}
return &ACLPolicyExt{
Policies: policies,
}
}
func LoadACLPolicyExt(ctx *Context, data []byte) (Extension, error) {
var j struct {
Policies map[string][]byte `json:"policies"`
}
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
policies := map[PolicyType]Policy{}
acl_ctx := ctx.ExtByType(ACLPolicyExtType).Data.(*ACLPolicyExtContext)
for name, ser := range(j.Policies) {
policy_def, exists := acl_ctx.Types[PolicyType(name)]
if exists == false {
return nil, fmt.Errorf("%s is not a known policy type", name)
}
policy, err := policy_def.Load(ctx, ser)
if err != nil {
return nil, err
}
policies[PolicyType(name)] = policy
}
return NewACLPolicyExt(policies), nil
}
const ACLPolicyExtType = ExtType("ACL_POLICIES")
func (ext *ACLPolicyExt) Type() ExtType {
return ACLPolicyExtType
}
// Check if the extension allows the principal to perform action on node
func (ext *ACLPolicyExt) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
context.Graph.Log.Logf("policy", "POLICY_EXT_ALLOWED: %+v", ext)
for _, policy := range(ext.Policies) {
context.Graph.Log.Logf("policy", "POLICY_CHECK_POLICY: %+v", policy)
if policy.Allows(context, principal, action, node) == true {
return true
}
}
return false
}

@ -1,744 +0,0 @@
package graphvent
import (
"crypto/sha512"
"encoding/binary"
"fmt"
"reflect"
"math"
)
type SerializedType uint64
func (t SerializedType) String() string {
return fmt.Sprintf("0x%x", uint64(t))
}
type ExtType SerializedType
func (t ExtType) String() string {
return fmt.Sprintf("0x%x", uint64(t))
}
type NodeType SerializedType
func (t NodeType) String() string {
return fmt.Sprintf("0x%x", uint64(t))
}
type SignalType SerializedType
func (t SignalType) String() string {
return fmt.Sprintf("0x%x", uint64(t))
}
type FieldTag SerializedType
func (t FieldTag) String() string {
return fmt.Sprintf("0x%x", uint64(t))
}
func NodeTypeFor(name string) NodeType {
digest := []byte("GRAPHVENT_NODE - " + name)
hash := sha512.Sum512(digest)
return NodeType(binary.BigEndian.Uint64(hash[0:8]))
}
func SerializeType(t fmt.Stringer) SerializedType {
digest := []byte(t.String())
hash := sha512.Sum512(digest)
return SerializedType(binary.BigEndian.Uint64(hash[0:8]))
}
func SerializedTypeFor[T any]() SerializedType {
return SerializeType(reflect.TypeFor[T]())
}
func ExtTypeFor[E any, T interface { *E; Extension}]() ExtType {
return ExtType(SerializedTypeFor[E]())
}
func ExtTypeOf(t reflect.Type) ExtType {
return ExtType(SerializeType(t.Elem()))
}
func SignalTypeFor[S Signal]() SignalType {
return SignalType(SerializedTypeFor[S]())
}
func Hash(base, data string) SerializedType {
digest := []byte(base + ":" + data)
hash := sha512.Sum512(digest)
return SerializedType(binary.BigEndian.Uint64(hash[0:8]))
}
func GetFieldTag(tag string) FieldTag {
return FieldTag(Hash("GRAPHVENT_FIELD_TAG", tag))
}
func TypeStack(ctx *Context, t reflect.Type, data []byte) (int, error) {
info, registered := ctx.Types[t]
if registered {
binary.BigEndian.PutUint64(data, uint64(info.Serialized))
return 8, nil
} else {
switch t.Kind() {
case reflect.Map:
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Map)))
key_written, err := TypeStack(ctx, t.Key(), data[8:])
if err != nil {
return 0, err
}
elem_written, err := TypeStack(ctx, t.Elem(), data[8 + key_written:])
if err != nil {
return 0, err
}
return 8 + key_written + elem_written, nil
case reflect.Pointer:
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Pointer)))
elem_written, err := TypeStack(ctx, t.Elem(), data[8:])
if err != nil {
return 0, err
}
return 8 + elem_written, nil
case reflect.Slice:
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Slice)))
elem_written, err := TypeStack(ctx, t.Elem(), data[8:])
if err != nil {
return 0, err
}
return 8 + elem_written, nil
case reflect.Array:
binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Array)))
binary.BigEndian.PutUint64(data[8:], uint64(t.Len()))
elem_written, err := TypeStack(ctx, t.Elem(), data[16:])
if err != nil {
return 0, err
}
return 16 + elem_written, nil
default:
return 0, fmt.Errorf("Hit %s, which is not a registered type", t.String())
}
}
}
func UnwrapStack(ctx *Context, stack []byte) (reflect.Type, []byte, error) {
first_bytes, left := split(stack, 8)
first := SerializedType(binary.BigEndian.Uint64(first_bytes))
info, registered := ctx.TypesReverse[first]
if registered {
return info.Reflect, left, nil
} else {
switch first {
case SerializeType(reflect.Map):
key_type, after_key, err := UnwrapStack(ctx, left)
if err != nil {
return nil, nil, err
}
elem_type, after_elem, err := UnwrapStack(ctx, after_key)
if err != nil {
return nil, nil, err
}
return reflect.MapOf(key_type, elem_type), after_elem, nil
case SerializeType(reflect.Pointer):
elem_type, rest, err := UnwrapStack(ctx, left)
if err != nil {
return nil, nil, err
}
return reflect.PointerTo(elem_type), rest, nil
case SerializeType(reflect.Slice):
elem_type, rest, err := UnwrapStack(ctx, left)
if err != nil {
return nil, nil, err
}
return reflect.SliceOf(elem_type), rest, nil
case SerializeType(reflect.Array):
length_bytes, left := split(left, 8)
length := int(binary.BigEndian.Uint64(length_bytes))
elem_type, rest, err := UnwrapStack(ctx, left)
if err != nil {
return nil, nil, err
}
return reflect.ArrayOf(length, elem_type), rest, nil
default:
return nil, nil, fmt.Errorf("Type stack %+v not recognized", stack)
}
}
}
func Serialize[T any](ctx *Context, value T, data []byte) (int, error) {
return SerializeValue(ctx, reflect.ValueOf(&value).Elem(), data)
}
func Deserialize[T any](ctx *Context, data []byte) (T, error) {
reflect_type := reflect.TypeFor[T]()
var zero T
value, left, err := DeserializeValue(ctx, data, reflect_type)
if err != nil {
return zero, err
} else if len(left) != 0 {
return zero, fmt.Errorf("%d/%d bytes left after deserializing %+v", len(left), len(data), value)
} else if value.Type() != reflect_type {
return zero, fmt.Errorf("Deserialized type %s does not match %s", value.Type(), reflect_type)
}
return value.Interface().(T), nil
}
func SerializedSize(ctx *Context, value reflect.Value) (int, error) {
var sizefn SerializedSizeFn = nil
info, registered := ctx.Types[value.Type()]
if registered {
sizefn = info.SerializedSize
}
if sizefn == nil {
switch value.Type().Kind() {
case reflect.Bool:
return 1, nil
case reflect.Int8:
return 1, nil
case reflect.Int16:
return 2, nil
case reflect.Int32:
return 4, nil
case reflect.Int64:
fallthrough
case reflect.Int:
return 8, nil
case reflect.Uint8:
return 1, nil
case reflect.Uint16:
return 2, nil
case reflect.Uint32:
return 4, nil
case reflect.Uint64:
fallthrough
case reflect.Uint:
return 8, nil
case reflect.Float32:
return 4, nil
case reflect.Float64:
return 8, nil
case reflect.String:
return 8 + value.Len(), nil
case reflect.Pointer:
if value.IsNil() {
return 1, nil
} else {
elem_len, err := SerializedSize(ctx, value.Elem())
if err != nil {
return 0, err
} else {
return 1 + elem_len, nil
}
}
case reflect.Slice:
if value.IsNil() {
return 1, nil
} else {
elem_total := 0
for i := 0; i < value.Len(); i++ {
elem_len, err := SerializedSize(ctx, value.Index(i))
if err != nil {
return 0, err
}
elem_total += elem_len
}
return 9 + elem_total, nil
}
case reflect.Array:
total := 0
for i := 0; i < value.Len(); i++ {
elem_len, err := SerializedSize(ctx, value.Index(i))
if err != nil {
return 0, err
}
total += elem_len
}
return total, nil
case reflect.Map:
if value.IsNil() {
return 1, nil
} else {
key := reflect.New(value.Type().Key()).Elem()
val := reflect.New(value.Type().Elem()).Elem()
iter := value.MapRange()
total := 0
for iter.Next() {
key.SetIterKey(iter)
k, err := SerializedSize(ctx, key)
if err != nil {
return 0, err
}
total += k
val.SetIterValue(iter)
v, err := SerializedSize(ctx, val)
if err != nil {
return 0, err
}
total += v
}
return 9 + total, nil
}
case reflect.Struct:
if registered == false {
return 0, fmt.Errorf("Can't serialize unregistered struct %s", value.Type())
} else {
field_total := 0
for _, field_info := range(info.Fields) {
field_size, err := SerializedSize(ctx, value.FieldByIndex(field_info.Index))
if err != nil {
return 0, err
}
field_total += 8
field_total += field_size
}
return 8 + field_total, nil
}
case reflect.Interface:
// TODO get size of TypeStack instead of just using 128
elem_size, err := SerializedSize(ctx, value.Elem())
if err != nil {
return 0, err
}
return 128 + elem_size, nil
default:
return 0, fmt.Errorf("Don't know how to serialize %s", value.Type())
}
} else {
return sizefn(ctx, value)
}
}
func SerializeValue(ctx *Context, value reflect.Value, data []byte) (int, error) {
var serialize SerializeFn = nil
info, registered := ctx.Types[value.Type()]
if registered {
serialize = info.Serialize
}
if serialize == nil {
switch value.Type().Kind() {
case reflect.Bool:
if value.Bool() {
data[0] = 0xFF
} else {
data[0] = 0x00
}
return 1, nil
case reflect.Int8:
data[0] = byte(value.Int())
return 1, nil
case reflect.Int16:
binary.BigEndian.PutUint16(data, uint16(value.Int()))
return 2, nil
case reflect.Int32:
binary.BigEndian.PutUint32(data, uint32(value.Int()))
return 4, nil
case reflect.Int64:
fallthrough
case reflect.Int:
binary.BigEndian.PutUint64(data, uint64(value.Int()))
return 8, nil
case reflect.Uint8:
data[0] = byte(value.Uint())
return 1, nil
case reflect.Uint16:
binary.BigEndian.PutUint16(data, uint16(value.Uint()))
return 2, nil
case reflect.Uint32:
binary.BigEndian.PutUint32(data, uint32(value.Uint()))
return 4, nil
case reflect.Uint64:
fallthrough
case reflect.Uint:
binary.BigEndian.PutUint64(data, value.Uint())
return 8, nil
case reflect.Float32:
binary.BigEndian.PutUint32(data, math.Float32bits(float32(value.Float())))
return 4, nil
case reflect.Float64:
binary.BigEndian.PutUint64(data, math.Float64bits(value.Float()))
return 8, nil
case reflect.String:
binary.BigEndian.PutUint64(data, uint64(value.Len()))
copy(data[8:], []byte(value.String()))
return 8 + value.Len(), nil
case reflect.Pointer:
if value.IsNil() {
data[0] = 0x00
return 1, nil
} else {
data[0] = 0x01
written, err := SerializeValue(ctx, value.Elem(), data[1:])
if err != nil {
return 0, err
}
return 1 + written, nil
}
case reflect.Slice:
if value.IsNil() {
data[0] = 0x00
return 8, nil
} else {
data[0] = 0x01
binary.BigEndian.PutUint64(data[1:], uint64(value.Len()))
total_written := 0
for i := 0; i < value.Len(); i++ {
written, err := SerializeValue(ctx, value.Index(i), data[9+total_written:])
if err != nil {
return 0, err
}
total_written += written
}
return 9 + total_written, nil
}
case reflect.Array:
total_written := 0
for i := 0; i < value.Len(); i++ {
written, err := SerializeValue(ctx, value.Index(i), data[total_written:])
if err != nil {
return 0, err
}
total_written += written
}
return total_written, nil
case reflect.Map:
if value.IsNil() {
data[0] = 0x00
return 1, nil
} else {
data[0] = 0x01
binary.BigEndian.PutUint64(data[1:], uint64(value.Len()))
key := reflect.New(value.Type().Key()).Elem()
val := reflect.New(value.Type().Elem()).Elem()
iter := value.MapRange()
total_written := 0
for iter.Next() {
key.SetIterKey(iter)
val.SetIterValue(iter)
k, err := SerializeValue(ctx, key, data[9+total_written:])
if err != nil {
return 0, err
}
total_written += k
v, err := SerializeValue(ctx, val, data[9+total_written:])
if err != nil {
return 0, err
}
total_written += v
}
return 9 + total_written, nil
}
case reflect.Struct:
if registered == false {
return 0, fmt.Errorf("Cannot serialize unregistered struct %s", value.Type())
} else {
binary.BigEndian.PutUint64(data, uint64(len(info.Fields)))
total_written := 0
for field_tag, field_info := range(info.Fields) {
binary.BigEndian.PutUint64(data[8+total_written:], uint64(field_tag))
total_written += 8
written, err := SerializeValue(ctx, value.FieldByIndex(field_info.Index), data[8+total_written:])
if err != nil {
return 0, err
}
total_written += written
}
return 8 + total_written, nil
}
case reflect.Interface:
type_written, err := TypeStack(ctx, value.Elem().Type(), data)
elem_written, err := SerializeValue(ctx, value.Elem(), data[type_written:])
if err != nil {
return 0, err
}
return type_written + elem_written, nil
default:
return 0, fmt.Errorf("Don't know how to serialize %s", value.Type())
}
} else {
return serialize(ctx, value, data)
}
}
func split(data []byte, n int) ([]byte, []byte) {
return data[:n], data[n:]
}
func DeserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, []byte, error) {
var deserialize DeserializeFn = nil
info, registered := ctx.Types[t]
if registered {
deserialize = info.Deserialize
}
if deserialize == nil {
switch t.Kind() {
case reflect.Bool:
used, left := split(data, 1)
value := reflect.New(t).Elem()
value.SetBool(used[0] != 0x00)
return value, left, nil
case reflect.Int8:
used, left := split(data, 1)
value := reflect.New(t).Elem()
value.SetInt(int64(used[0]))
return value, left, nil
case reflect.Int16:
used, left := split(data, 2)
value := reflect.New(t).Elem()
value.SetInt(int64(binary.BigEndian.Uint16(used)))
return value, left, nil
case reflect.Int32:
used, left := split(data, 4)
value := reflect.New(t).Elem()
value.SetInt(int64(binary.BigEndian.Uint32(used)))
return value, left, nil
case reflect.Int64:
fallthrough
case reflect.Int:
used, left := split(data, 8)
value := reflect.New(t).Elem()
value.SetInt(int64(binary.BigEndian.Uint64(used)))
return value, left, nil
case reflect.Uint8:
used, left := split(data, 1)
value := reflect.New(t).Elem()
value.SetUint(uint64(used[0]))
return value, left, nil
case reflect.Uint16:
used, left := split(data, 2)
value := reflect.New(t).Elem()
value.SetUint(uint64(binary.BigEndian.Uint16(used)))
return value, left, nil
case reflect.Uint32:
used, left := split(data, 4)
value := reflect.New(t).Elem()
value.SetUint(uint64(binary.BigEndian.Uint32(used)))
return value, left, nil
case reflect.Uint64:
fallthrough
case reflect.Uint:
used, left := split(data, 8)
value := reflect.New(t).Elem()
value.SetUint(binary.BigEndian.Uint64(used))
return value, left, nil
case reflect.Float32:
used, left := split(data, 4)
value := reflect.New(t).Elem()
value.SetFloat(float64(math.Float32frombits(binary.BigEndian.Uint32(used))))
return value, left, nil
case reflect.Float64:
used, left := split(data, 8)
value := reflect.New(t).Elem()
value.SetFloat(math.Float64frombits(binary.BigEndian.Uint64(used)))
return value, left, nil
case reflect.String:
length, after_len := split(data, 8)
used, left := split(after_len, int(binary.BigEndian.Uint64(length)))
value := reflect.New(t).Elem()
value.SetString(string(used))
return value, left, nil
case reflect.Pointer:
flags, after_flags := split(data, 1)
value := reflect.New(t).Elem()
if flags[0] == 0x00 {
value.SetZero()
return value, after_flags, nil
} else {
elem_value, after_elem, err := DeserializeValue(ctx, after_flags, t.Elem())
if err != nil {
return reflect.Value{}, nil, err
}
value.Set(elem_value.Addr())
return value, after_elem, nil
}
case reflect.Slice:
nil_byte := data[0]
data = data[1:]
if nil_byte == 0x00 {
return reflect.New(t).Elem(), data, nil
} else {
len_bytes, left := split(data, 8)
length := int(binary.BigEndian.Uint64(len_bytes))
value := reflect.MakeSlice(t, length, length)
for i := 0; i < length; i++ {
var elem_value reflect.Value
var err error
elem_value, left, err = DeserializeValue(ctx, left, t.Elem())
if err != nil {
return reflect.Value{}, nil, err
}
value.Index(i).Set(elem_value)
}
return value, left, nil
}
case reflect.Array:
value := reflect.New(t).Elem()
left := data
for i := 0; i < t.Len(); i++ {
var elem_value reflect.Value
var err error
elem_value, left, err = DeserializeValue(ctx, left, t.Elem())
if err != nil {
return reflect.Value{}, nil, err
}
value.Index(i).Set(elem_value)
}
return value, left, nil
case reflect.Map:
flags, after_flags := split(data, 1)
if flags[0] == 0x00 {
return reflect.New(t).Elem(), after_flags, nil
} else {
len_bytes, left := split(after_flags, 8)
length := int(binary.BigEndian.Uint64(len_bytes))
value := reflect.MakeMapWithSize(t, length)
for i := 0; i < length; i++ {
var key_value reflect.Value
var val_value reflect.Value
var err error
key_value, left, err = DeserializeValue(ctx, left, t.Key())
if err != nil {
return reflect.Value{}, nil, err
}
val_value, left, err = DeserializeValue(ctx, left, t.Elem())
if err != nil {
return reflect.Value{}, nil, err
}
value.SetMapIndex(key_value, val_value)
}
return value, left, nil
}
case reflect.Struct:
info, mapped := ctx.Types[t]
if mapped {
value := reflect.New(t).Elem()
num_field_bytes, left := split(data, 8)
num_fields := int(binary.BigEndian.Uint64(num_field_bytes))
for i := 0; i < num_fields; i++ {
var tag_bytes []byte
tag_bytes, left = split(left, 8)
field_tag := FieldTag(binary.BigEndian.Uint64(tag_bytes))
field_info, mapped := info.Fields[field_tag]
if mapped {
var field_val reflect.Value
var err error
field_val, left, err = DeserializeValue(ctx, left, field_info.Type)
if err != nil {
return reflect.Value{}, nil, err
}
value.FieldByIndex(field_info.Index).Set(field_val)
} else {
return reflect.Value{}, nil, fmt.Errorf("Unknown field %s on struct %s", field_tag, t)
}
}
if info.PostDeserializeIndex != -1 {
post_deserialize_method := value.Addr().Method(info.PostDeserializeIndex)
post_deserialize_method.Call([]reflect.Value{reflect.ValueOf(ctx)})
}
return value, left, nil
} else {
return reflect.Value{}, nil, fmt.Errorf("Cannot deserialize unregistered struct %s", t)
}
case reflect.Interface:
elem_type, rest, err := UnwrapStack(ctx, data)
if err != nil {
return reflect.Value{}, nil, err
}
elem_val, left, err := DeserializeValue(ctx, rest, elem_type)
if err != nil {
return reflect.Value{}, nil, err
}
val := reflect.New(t).Elem()
val.Set(elem_val)
return val, left, nil
default:
return reflect.Value{}, nil, fmt.Errorf("Don't know how to deserialize %s", t)
}
} else {
return deserialize(ctx, data)
}
}

@ -1,176 +0,0 @@
package graphvent
import (
"testing"
"reflect"
"github.com/google/uuid"
)
func testTypeStack[T any](t *testing.T, ctx *Context) {
buffer := [1024]byte{}
reflect_type := reflect.TypeFor[T]()
written, err := TypeStack(ctx, reflect_type, buffer[:])
fatalErr(t, err)
stack := buffer[:written]
ctx.Log.Logf("test", "TypeStack[%s]: %+v", reflect_type, stack)
unwrapped_type, rest, err := UnwrapStack(ctx, stack)
fatalErr(t, err)
if len(rest) != 0 {
t.Errorf("Types remaining after unwrapping %s: %+v", unwrapped_type, stack)
}
if unwrapped_type != reflect_type {
t.Errorf("Unwrapped type[%+v] doesn't match original[%+v]", unwrapped_type, reflect_type)
}
ctx.Log.Logf("test", "Unwrapped type[%s]: %s", reflect_type, reflect_type)
}
func TestSerializeTypes(t *testing.T) {
ctx := logTestContext(t, []string{"test"})
testTypeStack[int](t, ctx)
testTypeStack[map[int]string](t, ctx)
testTypeStack[string](t, ctx)
testTypeStack[*string](t, ctx)
testTypeStack[*map[string]*map[*string]int](t, ctx)
testTypeStack[[5]int](t, ctx)
testTypeStack[uuid.UUID](t, ctx)
testTypeStack[NodeID](t, ctx)
}
func testSerializeCompare[T comparable](t *testing.T, ctx *Context, value T) {
buffer := [1024]byte{}
written, err := Serialize(ctx, value, buffer[:])
fatalErr(t, err)
serialized := buffer[:written]
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
deserialized, err := Deserialize[T](ctx, serialized)
fatalErr(t, err)
if value != deserialized {
t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized)
}
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
}
func testSerializeList[L []T, T comparable](t *testing.T, ctx *Context, value L) {
buffer := [1024]byte{}
written, err := Serialize(ctx, value, buffer[:])
fatalErr(t, err)
serialized := buffer[:written]
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[L](), value, serialized)
deserialized, err := Deserialize[L](ctx, serialized)
fatalErr(t, err)
for i, item := range(value) {
if item != deserialized[i] {
t.Fatalf("Deserialized list %+v does not match original %+v", value, deserialized)
}
}
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
}
func testSerializePointer[P interface {*T}, T comparable](t *testing.T, ctx *Context, value P) {
buffer := [1024]byte{}
written, err := Serialize(ctx, value, buffer[:])
fatalErr(t, err)
serialized := buffer[:written]
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[P](), value, serialized)
deserialized, err := Deserialize[P](ctx, serialized)
fatalErr(t, err)
if value == nil && deserialized == nil {
ctx.Log.Logf("test", "Deserialized nil")
} else if value == nil {
t.Fatalf("Non-nil value[%+v] returned for nil value", deserialized)
} else if deserialized == nil {
t.Fatalf("Nil value returned for non-nil value[%+v]", value)
} else if *deserialized != *value {
t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized)
} else {
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", *value, *deserialized)
}
}
func testSerialize[T any](t *testing.T, ctx *Context, value T) {
buffer := [1024]byte{}
written, err := Serialize(ctx, value, buffer[:])
fatalErr(t, err)
serialized := buffer[:written]
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
deserialized, err := Deserialize[T](ctx, serialized)
fatalErr(t, err)
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
}
func TestSerializeValues(t *testing.T) {
ctx := logTestContext(t, []string{"test"})
testSerialize(t, ctx, Extension(NewLockableExt(nil)))
testSerializeCompare[int8](t, ctx, -64)
testSerializeCompare[int16](t, ctx, -64)
testSerializeCompare[int32](t, ctx, -64)
testSerializeCompare[int64](t, ctx, -64)
testSerializeCompare[int](t, ctx, -64)
testSerializeCompare[uint8](t, ctx, 64)
testSerializeCompare[uint16](t, ctx, 64)
testSerializeCompare[uint32](t, ctx, 64)
testSerializeCompare[uint64](t, ctx, 64)
testSerializeCompare[uint](t, ctx, 64)
testSerializeCompare[string](t, ctx, "test")
a := 12
testSerializePointer[*int](t, ctx, &a)
b := "test"
testSerializePointer[*string](t, ctx, nil)
testSerializePointer[*string](t, ctx, &b)
testSerializeList(t, ctx, []int{1, 2, 3, 4, 5})
testSerializeCompare[bool](t, ctx, true)
testSerializeCompare[bool](t, ctx, false)
testSerializeCompare[int](t, ctx, -1)
testSerializeCompare[uint](t, ctx, 1)
testSerializeCompare[NodeID](t, ctx, RandID())
testSerializeCompare[*int](t, ctx, nil)
testSerializeCompare(t, ctx, "string")
testSerialize(t, ctx, map[string]string{
"Test": "Test",
"key": "String",
"": "",
})
testSerialize[map[string]string](t, ctx, nil)
testSerialize(t, ctx, NewListenerExt(10))
node, err := ctx.NewNode(nil, "Node")
fatalErr(t, err)
testSerialize(t, ctx, node)
}

@ -1,261 +1,111 @@
package graphvent
import (
"fmt"
"time"
"github.com/google/uuid"
"encoding/json"
)
type TimeoutSignal struct {
ResponseHeader
}
func NewTimeoutSignal(req_id uuid.UUID) *TimeoutSignal {
return &TimeoutSignal{
NewResponseHeader(req_id),
}
}
func (signal TimeoutSignal) String() string {
return fmt.Sprintf("TimeoutSignal(%s)", &signal.ResponseHeader)
}
type SignalHeader struct {
Id uuid.UUID `gv:"id"`
}
func (signal SignalHeader) ID() uuid.UUID {
return signal.Id
}
func (header SignalHeader) String() string {
return fmt.Sprintf("%s", header.Id)
}
type SignalDirection int
const (
Up SignalDirection = iota
Down
Direct
)
type ResponseSignal interface {
Signal
ResponseID() uuid.UUID
}
type SignalType string
type ResponseHeader struct {
SignalHeader
ReqID uuid.UUID `gv:"req_id"`
type Signal interface {
Serializable[SignalType]
Direction() SignalDirection
}
func (header ResponseHeader) ResponseID() uuid.UUID {
return header.ReqID
type BaseSignal struct {
SignalDirection SignalDirection `json:"direction"`
SignalType SignalType `json:"type"`
}
func (header ResponseHeader) String() string {
return fmt.Sprintf("%s for %s", header.Id, header.ReqID)
func (signal BaseSignal) Type() SignalType {
return signal.SignalType
}
type Signal interface {
fmt.Stringer
ID() uuid.UUID
func (signal BaseSignal) Direction() SignalDirection {
return signal.SignalDirection
}
func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UUID) (ResponseSignal, []Signal, error) {
signals := []Signal{}
var timeout_channel <- chan time.Time
if timeout > 0 {
timeout_channel = time.After(timeout)
}
for true {
select {
case signal := <- listener:
if signal == nil {
return nil, signals, fmt.Errorf("LISTENER_CLOSED")
}
resp_signal, ok := signal.(ResponseSignal)
if ok == true && resp_signal.ResponseID() == req_id {
return resp_signal, signals, nil
} else {
signals = append(signals, signal)
}
case <-timeout_channel:
return nil, signals, fmt.Errorf("LISTENER_TIMEOUT")
}
}
return nil, signals, fmt.Errorf("UNREACHABLE")
func (signal BaseSignal) Serialize() ([]byte, error) {
return json.MarshalIndent(signal, "", " ")
}
//TODO: Add []Signal return as well for other signals
func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, check func(S)bool) (S, error) {
var zero S
var timeout_channel <- chan time.Time
if timeout > 0 {
timeout_channel = time.After(timeout)
}
for true {
select {
case signal := <- listener:
if signal == nil {
return zero, fmt.Errorf("LISTENER_CLOSED")
}
sig, ok := signal.(S)
if ok == true {
if check(sig) == true {
return sig, nil
}
}
case <-timeout_channel:
return zero, fmt.Errorf("LISTENER_TIMEOUT")
}
func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal {
signal := BaseSignal{
SignalDirection: direction,
SignalType: signal_type,
}
return zero, fmt.Errorf("LOOP_ENDED")
return signal
}
func NewSignalHeader() SignalHeader {
return SignalHeader{
uuid.New(),
}
func NewDownSignal(signal_type SignalType) BaseSignal {
return NewBaseSignal(signal_type, Down)
}
func NewResponseHeader(req_id uuid.UUID) ResponseHeader {
return ResponseHeader{
NewSignalHeader(),
req_id,
}
func NewUpSignal(signal_type SignalType) BaseSignal {
return NewBaseSignal(signal_type, Up)
}
type SuccessSignal struct {
ResponseHeader
func NewDirectSignal(signal_type SignalType) BaseSignal {
return NewBaseSignal(signal_type, Direct)
}
func (signal SuccessSignal) String() string {
return fmt.Sprintf("SuccessSignal(%s)", signal.ResponseHeader)
}
var AbortSignal = NewBaseSignal("abort", Down)
var StopSignal = NewBaseSignal("stop", Down)
func NewSuccessSignal(req_id uuid.UUID) *SuccessSignal {
return &SuccessSignal{
NewResponseHeader(req_id),
}
type IDSignal struct {
BaseSignal
ID NodeID `json:"id"`
}
type ErrorSignal struct {
ResponseHeader
Error string
}
func (signal ErrorSignal) String() string {
return fmt.Sprintf("ErrorSignal(%s, %s)", signal.ResponseHeader, signal.Error)
}
func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) *ErrorSignal {
return &ErrorSignal{
NewResponseHeader(req_id),
fmt.Sprintf(fmt_string, args...),
func (signal IDSignal) String() string {
ser, err := json.Marshal(signal)
if err != nil {
return "STATE_SER_ERR"
}
return string(ser)
}
type ACLTimeoutSignal struct {
ResponseHeader
}
func NewACLTimeoutSignal(req_id uuid.UUID) *ACLTimeoutSignal {
sig := &ACLTimeoutSignal{
NewResponseHeader(req_id),
func NewIDSignal(signal_type SignalType, direction SignalDirection, id NodeID) IDSignal {
return IDSignal{
BaseSignal: NewBaseSignal(signal_type, direction),
ID: id,
}
return sig
}
type StatusSignal struct {
SignalHeader
Source NodeID `gv:"source"`
Fields []string `gv:"fields"`
}
func (signal StatusSignal) String() string {
return fmt.Sprintf("StatusSignal(%s: %+v)", signal.Source, signal.Fields)
}
func NewStatusSignal(source NodeID, fields []string) *StatusSignal {
return &StatusSignal{
NewSignalHeader(),
source,
fields,
}
}
type LinkSignal struct {
SignalHeader
NodeID NodeID
Action string
}
const (
LinkActionBase = "LINK_ACTION"
LinkActionAdd = "ADD"
)
func NewLinkSignal(action string, id NodeID) Signal {
return &LinkSignal{
NewSignalHeader(),
id,
action,
}
}
type LockSignal struct {
SignalHeader
}
func (signal LockSignal) String() string {
return fmt.Sprintf("LockSignal(%s)", signal.SignalHeader)
IDSignal
Status string `json:"status"`
}
func NewLockSignal() *LockSignal {
return &LockSignal{
NewSignalHeader(),
func (signal StatusSignal) String() string {
ser, err := json.Marshal(signal)
if err != nil {
return "STATE_SER_ERR"
}
return string(ser)
}
type UnlockSignal struct {
SignalHeader
}
func (signal UnlockSignal) String() string {
return fmt.Sprintf("UnlockSignal(%s)", signal.SignalHeader)
}
func NewUnlockSignal() *UnlockSignal {
return &UnlockSignal{
NewSignalHeader(),
func NewStatusSignal(status string, source NodeID) StatusSignal {
return StatusSignal{
IDSignal: NewIDSignal("status", Up, source),
Status: status,
}
}
type ReadSignal struct {
SignalHeader
Fields []string `json:"extensions"`
}
func (signal ReadSignal) String() string {
return fmt.Sprintf("ReadSignal(%s, %+v)", signal.SignalHeader, signal.Fields)
type StartChildSignal struct {
IDSignal
Action string `json:"action"`
}
func NewReadSignal(fields []string) *ReadSignal {
return &ReadSignal{
NewSignalHeader(),
fields,
func NewStartChildSignal(child_id NodeID, action string) StartChildSignal {
return StartChildSignal{
IDSignal: NewIDSignal("start_child", Direct, child_id),
Action: action,
}
}
type ReadResultSignal struct {
ResponseHeader
NodeID NodeID
NodeType NodeType
Fields map[string]any
}
func (signal ReadResultSignal) String() string {
return fmt.Sprintf("ReadResultSignal(%s, %s, %+v)", signal.ResponseHeader, signal.NodeID, signal.Fields)
}
func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, fields map[string]any) *ReadResultSignal {
return &ReadResultSignal{
NewResponseHeader(req_id),
node_id,
node_type,
fields,
}
}

@ -0,0 +1,736 @@
package graphvent
import (
"fmt"
"time"
"sync"
"errors"
"encoding/json"
"crypto/sha512"
"encoding/binary"
)
type ThreadAction func(*Context, *Node, *ThreadExt)(string, error)
type ThreadActions map[string]ThreadAction
type ThreadHandler func(*Context, *Node, *ThreadExt, Signal)(string, error)
type ThreadHandlers map[SignalType]ThreadHandler
type InfoType string
func (t InfoType) String() string {
return string(t)
}
type Info interface {
Serializable[InfoType]
}
// Data required by a parent thread to restore it's children
type ParentInfo struct {
Start bool `json:"start"`
StartAction string `json:"start_action"`
RestoreAction string `json:"restore_action"`
}
const ParentInfoType = InfoType("PARENT")
func (info *ParentInfo) Type() InfoType {
return ParentInfoType
}
func (info *ParentInfo) Serialize() ([]byte, error) {
return json.MarshalIndent(info, "", " ")
}
type QueuedAction struct {
Timeout time.Time `json:"time"`
Action string `json:"action"`
}
type ThreadType string
func (thread ThreadType) Hash() uint64 {
hash := sha512.Sum512([]byte(fmt.Sprintf("THREAD: %s", string(thread))))
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
}
type ThreadInfo struct {
Actions ThreadActions
Handlers ThreadHandlers
}
type InfoLoadFunc func([]byte)(Info, error)
type ThreadExtContext struct {
Types map[ThreadType]ThreadInfo
Loads map[InfoType]InfoLoadFunc
}
const BaseThreadType = ThreadType("BASE")
func NewThreadExtContext() *ThreadExtContext {
return &ThreadExtContext{
Types: map[ThreadType]ThreadInfo{
BaseThreadType: ThreadInfo{
Actions: BaseThreadActions,
Handlers: BaseThreadHandlers,
},
},
Loads: map[InfoType]InfoLoadFunc{
ParentInfoType: func(data []byte) (Info, error) {
var info ParentInfo
err := json.Unmarshal(data, &info)
if err != nil {
return nil, err
}
return &info, nil
},
},
}
}
func (ctx *ThreadExtContext) RegisterThreadType(thread_type ThreadType, actions ThreadActions, handlers ThreadHandlers) error {
if actions == nil || handlers == nil {
return fmt.Errorf("Cannot register ThreadType %s with nil actions or handlers", thread_type)
}
_, exists := ctx.Types[thread_type]
if exists == true {
return fmt.Errorf("ThreadType %s already registered in ThreadExtContext, cannot register again", thread_type)
}
ctx.Types[thread_type] = ThreadInfo{
Actions: actions,
Handlers: handlers,
}
return nil
}
func (ctx *ThreadExtContext) RegisterInfoType(info_type InfoType, load_fn InfoLoadFunc) error {
if load_fn == nil {
return fmt.Errorf("Cannot register %s with nil load_fn", info_type)
}
_, exists := ctx.Loads[info_type]
if exists == true {
return fmt.Errorf("InfoType %s is already registered in ThreadExtContext, cannot register again", info_type)
}
ctx.Loads[info_type] = load_fn
return nil
}
type ThreadExt struct {
Actions ThreadActions
Handlers ThreadHandlers
ThreadType ThreadType
SignalChan chan Signal
TimeoutChan <-chan time.Time
ChildWaits sync.WaitGroup
ActiveLock sync.Mutex
Active bool
State string
Parent *Node
Children map[NodeID]ChildInfo
ActionQueue []QueuedAction
NextAction *QueuedAction
}
type ThreadExtJSON struct {
State string `json:"state"`
Type string `json:"type"`
Parent string `json:"parent"`
Children map[string]map[string][]byte `json:"children"`
ActionQueue []QueuedAction
}
func (ext *ThreadExt) Serialize() ([]byte, error) {
children := map[string]map[string][]byte{}
for id, child := range(ext.Children) {
id_str := id.String()
children[id_str] = map[string][]byte{}
for info_type, info := range(child.Infos) {
var err error
children[id_str][string(info_type)], err = info.Serialize()
if err != nil {
return nil, err
}
}
}
return json.MarshalIndent(&ThreadExtJSON{
State: ext.State,
Type: string(ext.ThreadType),
Parent: SaveNode(ext.Parent),
Children: children,
ActionQueue: ext.ActionQueue,
}, "", " ")
}
func NewThreadExt(ctx*Context, thread_type ThreadType, parent *Node, children map[NodeID]ChildInfo, state string, action_queue []QueuedAction) (*ThreadExt, error) {
if children == nil {
children = map[NodeID]ChildInfo{}
}
if action_queue == nil {
action_queue = []QueuedAction{}
}
thread_ctx, err := GetCtx[*ThreadExt, *ThreadExtContext](ctx)
if err != nil {
return nil, err
}
type_info, exists := thread_ctx.Types[thread_type]
if exists == false {
return nil, fmt.Errorf("Tried to load thread type %s which is not in context", thread_type)
}
next_action, timeout_chan := SoonestAction(action_queue)
return &ThreadExt{
ThreadType: thread_type,
Actions: type_info.Actions,
Handlers: type_info.Handlers,
SignalChan: make(chan Signal, THREAD_BUFFER_SIZE),
TimeoutChan: timeout_chan,
Active: false,
State: state,
Parent: parent,
Children: children,
ActionQueue: action_queue,
NextAction: next_action,
}, nil
}
const THREAD_BUFFER_SIZE int = 1024
func LoadThreadExt(ctx *Context, data []byte) (Extension, error) {
var j ThreadExtJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
ctx.Log.Logf("db", "DB_LOAD_THREAD_EXT_JSON: %+v", j)
parent, err := RestoreNode(ctx, j.Parent)
if err != nil {
return nil, err
}
thread_ctx, err := GetCtx[*ThreadExt, *ThreadExtContext](ctx)
if err != nil {
return nil, err
}
children := map[NodeID]ChildInfo{}
for id_str, infos := range(j.Children) {
child_node, err := RestoreNode(ctx, id_str)
if err != nil {
return nil, err
}
child_infos := map[InfoType]Info{}
for info_type_str, info_data := range(infos) {
info_type := InfoType(info_type_str)
info_load, exists := thread_ctx.Loads[info_type]
if exists == false {
return nil, fmt.Errorf("%s is not a known InfoType in ThreacExrContxt", info_type)
}
info, err := info_load(info_data)
if err != nil {
return nil, err
}
child_infos[info_type] = info
}
children[child_node.ID] = ChildInfo{
Child: child_node,
Infos: child_infos,
}
}
return NewThreadExt(ctx, ThreadType(j.Type), parent, children, j.State, j.ActionQueue)
}
const ThreadExtType = ExtType("THREAD")
func (ext *ThreadExt) Type() ExtType {
return ThreadExtType
}
func (ext *ThreadExt) QueueAction(end time.Time, action string) {
ext.ActionQueue = append(ext.ActionQueue, QueuedAction{end, action})
ext.NextAction, ext.TimeoutChan = SoonestAction(ext.ActionQueue)
}
func (ext *ThreadExt) ClearActionQueue() {
ext.ActionQueue = []QueuedAction{}
ext.NextAction = nil
ext.TimeoutChan = nil
}
func SoonestAction(actions []QueuedAction) (*QueuedAction, <-chan time.Time) {
var soonest_action *QueuedAction
var soonest_time time.Time
for _, action := range(actions) {
if action.Timeout.Compare(soonest_time) == -1 || soonest_action == nil {
soonest_action = &action
soonest_time = action.Timeout
}
}
if soonest_action != nil {
return soonest_action, time.After(time.Until(soonest_action.Timeout))
} else {
return nil, nil
}
}
func (ext *ThreadExt) ChildList() []*Node {
ret := make([]*Node, len(ext.Children))
i := 0
for _, info := range(ext.Children) {
ret[i] = info.Child
i += 1
}
return ret
}
// Assumed that thread is already locked for signal
func (ext *ThreadExt) Process(context *StateContext, node *Node, signal Signal) error {
context.Graph.Log.Logf("signal", "THREAD_PROCESS: %s", node.ID)
var err error
switch signal.Direction() {
case Up:
err = UseStates(context, node, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error {
if ext.Parent != nil {
if ext.Parent.ID != node.ID {
return SendSignal(context, ext.Parent, node, signal)
}
}
return nil
})
case Down:
err = UseStates(context, node, NewACLInfo(node, []string{"children"}), func(context *StateContext) error {
for _, info := range(ext.Children) {
err := SendSignal(context, info.Child, node, signal)
if err != nil {
return err
}
}
return nil
})
case Direct:
err = nil
default:
return fmt.Errorf("Invalid signal direction %d", signal.Direction())
}
ext.SignalChan <- signal
return err
}
func UnlinkThreads(context *StateContext, principal *Node, thread *Node, child *Node) error {
thread_ext, err := GetExt[*ThreadExt](thread)
if err != nil {
return err
}
child_ext, err := GetExt[*ThreadExt](child)
if err != nil {
return err
}
return UpdateStates(context, principal, ACLMap{
thread.ID: ACLInfo{thread, []string{"children"}},
child.ID: ACLInfo{child, []string{"parent"}},
}, func(context *StateContext) error {
_, is_child := thread_ext.Children[child.ID]
if is_child == false {
return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID, thread.ID)
}
delete(thread_ext.Children, child.ID)
child_ext.Parent = nil
return nil
})
}
func checkIfChild(context *StateContext, id NodeID, cur *ThreadExt) (bool, error) {
for _, info := range(cur.Children) {
child := info.Child
if child.ID == id {
return true, nil
}
child_ext, err := GetExt[*ThreadExt](child)
if err != nil {
return false, err
}
var is_child bool
err = UpdateStates(context, child, NewACLInfo(child, []string{"children"}), func(context *StateContext) error {
is_child, err = checkIfChild(context, id, child_ext)
return err
})
if err != nil {
return false, err
}
if is_child {
return true, nil
}
}
return false, nil
}
// Links child to parent with info as the associated info
// Continues the write context with princ, getting children for thread and parent for child
func LinkThreads(context *StateContext, principal *Node, thread *Node, info ChildInfo) error {
if context == nil || principal == nil || thread == nil || info.Child == nil {
return fmt.Errorf("invalid input")
}
child := info.Child
if thread.ID == child.ID {
return fmt.Errorf("Will not link %s as a child of itself", thread.ID)
}
thread_ext, err := GetExt[*ThreadExt](thread)
if err != nil {
return err
}
child_ext, err := GetExt[*ThreadExt](child)
if err != nil {
return err
}
return UpdateStates(context, principal, ACLMap{
child.ID: ACLInfo{Node: child, Resources: []string{"parent"}},
thread.ID: ACLInfo{Node: thread, Resources: []string{"children"}},
}, func(context *StateContext) error {
if child_ext.Parent != nil {
return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID)
}
is_child, err := checkIfChild(context, thread.ID, child_ext)
if err != nil {
return err
} else if is_child == true {
return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID, child.ID)
}
is_child, err = checkIfChild(context, child.ID, thread_ext)
if err != nil {
return err
} else if is_child == true {
return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID, child.ID)
}
// TODO check for info types
thread_ext.Children[child.ID] = info
child_ext.Parent = thread
return nil
})
}
type ChildInfo struct {
Child *Node
Infos map[InfoType]Info
}
func NewChildInfo(child *Node, infos map[InfoType]Info) ChildInfo {
if infos == nil {
infos = map[InfoType]Info{}
}
return ChildInfo{
Child: child,
Infos: infos,
}
}
func (ext *ThreadExt) SetActive(active bool) error {
ext.ActiveLock.Lock()
defer ext.ActiveLock.Unlock()
if ext.Active == true && active == true {
return fmt.Errorf("alreday active, cannot set active")
} else if ext.Active == false && active == false {
return fmt.Errorf("already inactive, canot set inactive")
}
ext.Active = active
return nil
}
func (ext *ThreadExt) SetState(state string) error {
ext.State = state
return nil
}
// Requires the read permission of threads children
func FindChild(context *StateContext, principal *Node, thread *Node, id NodeID) (*Node, error) {
if thread == nil {
panic("cannot recurse through nil")
}
if id == thread.ID {
return thread, nil
}
thread_ext, err := GetExt[*ThreadExt](thread)
if err != nil {
return nil, err
}
var found *Node = nil
err = UseStates(context, principal, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
for _, info := range(thread_ext.Children) {
found, err = FindChild(context, principal, info.Child, id)
if err != nil {
return err
}
if found != nil {
return nil
}
}
return nil
})
return found, err
}
func ChildGo(ctx * Context, thread_ext *ThreadExt, child *Node, first_action string) {
thread_ext.ChildWaits.Add(1)
go func(child *Node) {
defer thread_ext.ChildWaits.Done()
err := ThreadLoop(ctx, child, first_action)
if err != nil {
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %s", child.ID, err)
} else {
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_DONE: %s", child.ID)
}
}(child)
}
// Main Loop for Threads, starts a write context, so cannot be called from a write or read context
func ThreadLoop(ctx * Context, thread *Node, first_action string) error {
thread_ext, err := GetExt[*ThreadExt](thread)
if err != nil {
return err
}
ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID, first_action)
err = thread_ext.SetActive(true)
if err != nil {
ctx.Log.Logf("thread", "THREAD_LOOP_START_ERR: %e", err)
return err
}
next_action := first_action
for next_action != "" {
action, exists := thread_ext.Actions[next_action]
if exists == false {
error_str := fmt.Sprintf("%s is not a valid action", next_action)
return errors.New(error_str)
}
ctx.Log.Logf("thread", "THREAD_ACTION: %s - %s", thread.ID, next_action)
next_action, err = action(ctx, thread, thread_ext)
if err != nil {
return err
}
}
err = thread_ext.SetActive(false)
if err != nil {
ctx.Log.Logf("thread", "THREAD_LOOP_STOP_ERR: %e", err)
return err
}
ctx.Log.Logf("thread", "THREAD_LOOP_DONE: %s", thread.ID)
return nil
}
func ThreadChildLinked(ctx *Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %s - %+v", thread.ID, signal)
context := NewWriteContext(ctx)
err := UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
sig, ok := signal.(IDSignal)
if ok == false {
ctx.Log.Logf("thread", "THREAD_NODE_LINKED_BAD_CAST")
return nil
}
info, exists := thread_ext.Children[sig.ID]
if exists == false {
ctx.Log.Logf("thread", "THREAD_NODE_LINKED: %s is not a child of %s", sig.ID)
return nil
}
parent_info, exists := info.Infos["parent"].(*ParentInfo)
if exists == false {
panic("ran ThreadChildLinked from a thread that doesn't require 'parent' child info. library party foul")
}
if parent_info.Start == true {
ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction)
}
return nil
})
if err != nil {
} else {
}
return "wait", nil
}
// Helper function to start a child from a thread during a signal handler
// Starts a write context, so cannot be called from either a write or read context
func ThreadStartChild(ctx *Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
sig, ok := signal.(StartChildSignal)
if ok == false {
return "wait", nil
}
context := NewWriteContext(ctx)
return "wait", UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
info, exists:= thread_ext.Children[sig.ID]
if exists == false {
return fmt.Errorf("%s is not a child of %s", sig.ID, thread.ID)
}
parent_info, exists := info.Infos["parent"].(*ParentInfo)
if exists == false {
return fmt.Errorf("Called ThreadStartChild from a thread that doesn't require parent child info")
}
parent_info.Start = true
ChildGo(ctx, thread_ext, info.Child, sig.Action)
return nil
})
}
// Helper function to restore threads that should be running from a parents restore action
// Starts a write context, so cannot be called from either a write or read context
func ThreadRestore(ctx * Context, thread *Node, thread_ext *ThreadExt, start bool) error {
context := NewWriteContext(ctx)
return UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
return UpdateStates(context, thread, ACLList(thread_ext.ChildList(), []string{"state"}), func(context *StateContext) error {
for _, info := range(thread_ext.Children) {
child_ext, err := GetExt[*ThreadExt](info.Child)
if err != nil {
return err
}
parent_info := info.Infos[ParentInfoType].(*ParentInfo)
if parent_info.Start == true && child_ext.State != "finished" {
ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID, info.Child.ID)
if start == true {
ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction)
} else {
ChildGo(ctx, thread_ext, info.Child, parent_info.RestoreAction)
}
}
}
return nil
})
})
}
// Helper function to be called during a threads start action, sets the thread state to started
// Starts a write context, so cannot be called from either a write or read context
// Returns "wait", nil on success, so the first return value can be ignored safely
func ThreadStart(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
context := NewWriteContext(ctx)
err := UpdateStates(context, thread, NewACLInfo(thread, []string{"state"}), func(context *StateContext) error {
err := LockLockables(context, map[NodeID]*Node{thread.ID: thread}, thread)
if err != nil {
return err
}
return thread_ext.SetState("started")
})
if err != nil {
return "", err
}
context = NewReadContext(ctx)
return "wait", SendSignal(context, thread, thread, NewStatusSignal("started", thread.ID))
}
func ThreadWait(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
ctx.Log.Logf("thread", "THREAD_WAIT: %s - %+v", thread.ID, thread_ext.ActionQueue)
for {
select {
case signal := <- thread_ext.SignalChan:
ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID, signal)
signal_fn, exists := thread_ext.Handlers[signal.Type()]
if exists == true {
ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID, signal.Type())
return signal_fn(ctx, thread, thread_ext, signal)
} else {
ctx.Log.Logf("thread", "THREAD_NOHANDLER: %s - %s", thread.ID, signal.Type())
}
case <- thread_ext.TimeoutChan:
timeout_action := ""
context := NewWriteContext(ctx)
err := UpdateStates(context, thread, NewACLMap(NewACLInfo(thread, []string{"timeout"})), func(context *StateContext) error {
timeout_action = thread_ext.NextAction.Action
thread_ext.NextAction, thread_ext.TimeoutChan = SoonestAction(thread_ext.ActionQueue)
return nil
})
if err != nil {
ctx.Log.Logf("thread", "THREAD_TIMEOUT_ERR: %s - %e", thread.ID, err)
}
ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID, timeout_action)
return timeout_action, nil
}
}
}
func ThreadFinish(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) {
context := NewWriteContext(ctx)
return "", UpdateStates(context, thread, NewACLInfo(thread, []string{"state"}), func(context *StateContext) error {
err := thread_ext.SetState("finished")
if err != nil {
return err
}
return UnlockLockables(context, map[NodeID]*Node{thread.ID: thread}, thread)
})
}
var ThreadAbortedError = errors.New("Thread aborted by signal")
// Default thread action function for "abort", sends a signal and returns a ThreadAbortedError
func ThreadAbort(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
context := NewReadContext(ctx)
err := SendSignal(context, thread, thread, NewStatusSignal("aborted", thread.ID))
if err != nil {
return "", err
}
return "", ThreadAbortedError
}
// Default thread action for "stop", sends a signal and returns no error
func ThreadStop(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
context := NewReadContext(ctx)
err := SendSignal(context, thread, thread, NewStatusSignal("stopped", thread.ID))
return "finish", err
}
// Default thread actions
var BaseThreadActions = ThreadActions{
"wait": ThreadWait,
"start": ThreadStart,
"finish": ThreadFinish,
}
// Default thread signal handlers
var BaseThreadHandlers = ThreadHandlers{
"abort": ThreadAbort,
"stop": ThreadStop,
}

@ -0,0 +1,120 @@
package graphvent
import (
"time"
"fmt"
"encoding/json"
"crypto/ecdsa"
"crypto/x509"
)
type ECDHExt struct {
Granted time.Time
Pubkey *ecdsa.PublicKey
Shared []byte
}
type ECDHExtJSON struct {
Granted time.Time `json:"granted"`
Pubkey []byte `json:"pubkey"`
Shared []byte `json:"shared"`
}
func (ext *ECDHExt) Process(context *StateContext, node *Node, signal Signal) error {
return nil
}
const ECDHExtType = ExtType("ECDH")
func (ext *ECDHExt) Type() ExtType {
return ECDHExtType
}
func (ext *ECDHExt) Serialize() ([]byte, error) {
pubkey, err := x509.MarshalPKIXPublicKey(ext.Pubkey)
if err != nil {
return nil, err
}
return json.MarshalIndent(&ECDHExtJSON{
Granted: ext.Granted,
Pubkey: pubkey,
Shared: ext.Shared,
}, "", " ")
}
func LoadECDHExt(ctx *Context, data []byte) (Extension, error) {
var j ECDHExtJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
pub, err := x509.ParsePKIXPublicKey(j.Pubkey)
if err != nil {
return nil, err
}
var pubkey *ecdsa.PublicKey
switch pub.(type) {
case *ecdsa.PublicKey:
pubkey = pub.(*ecdsa.PublicKey)
default:
return nil, fmt.Errorf("Invalid key type: %+v", pub)
}
extension := ECDHExt{
Granted: j.Granted,
Pubkey: pubkey,
Shared: j.Shared,
}
return &extension, nil
}
type GroupExt struct {
Members NodeMap
}
const GroupExtType = ExtType("GROUP")
func (ext *GroupExt) Type() ExtType {
return GroupExtType
}
func (ext *GroupExt) Serialize() ([]byte, error) {
return json.MarshalIndent(&struct{
Members []string `json:"members"`
}{
Members: SaveNodeList(ext.Members),
}, "", " ")
}
func NewGroupExt(members NodeMap) *GroupExt {
if members == nil {
members = NodeMap{}
}
return &GroupExt{
Members: members,
}
}
func LoadGroupExt(ctx *Context, data []byte) (Extension, error) {
var j struct {
Members []string `json:"members"`
}
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
members, err := RestoreNodeList(ctx, j.Members)
if err != nil {
return nil, err
}
return NewGroupExt(members), nil
}
func (ext *GroupExt) Process(context *StateContext, node *Node, signal Signal) error {
return nil
}