Compare commits

..

20 Commits

Author SHA1 Message Date
noah metz 8cb97d2350 update 2024-03-31 18:58:27 -07:00
noah metz c29981da20 Updated MessageQueue 2024-03-31 17:02:30 -07:00
noah metz 810e17990c Made default queue size 2; not sure why that sped the test back up to pre-change speeds 2024-03-31 15:24:34 -07:00
noah metz d0d07d464d Remove debug print 2024-03-31 15:19:08 -07:00
noah metz 11e7df2bde Changed from a static channel queue to a dynamic queue for nodes 2024-03-31 15:18:47 -07:00
noah metz 3eee736f97 Moved SendMsg and RecvMsg to one object 2024-03-30 23:36:50 -07:00
noah metz 7e157068d7 Removed the database update on every signal processed (still need to find a replacement), updated graphiql cmd, and made lockable more efficient at high numbers of requirements 2024-03-30 22:57:18 -07:00
noah metz b2d84b2453 Moved db from badger to an interface 2024-03-30 14:42:06 -07:00
noah metz 66d5e3f260 Changed serialization to not allocate any memory; it now expects to be passed a buffer large enough to serialize the type 2024-03-28 20:23:22 -07:00
noah metz 1eff534e1a Fixes and optimizations 2024-03-28 19:28:07 -07:00
noah metz 3d28c703db Removed event.go 2024-03-25 18:50:53 -06:00
noah metz a4115a4f99 Updated gql subscriptions to send fewer messages 2024-03-25 18:49:52 -06:00
noah metz ab76f09923 Got GQL subscriptions working for lockable_state 2024-03-23 03:23:00 -06:00
noah metz 6850031e80 Removed log lines, and fixed lock fail logic 2024-03-23 02:51:46 -06:00
noah metz 0b93c90aa9 Removed unneeded lockable field 2024-03-23 02:23:42 -06:00
noah metz 2db4655670 Rewrote lockable.go 2024-03-23 02:21:27 -06:00
noah metz d7b07df798 Reorganized to cleanup 2024-03-21 14:22:34 -06:00
noah metz 0bced58fd1 Fixed GQL issues, started docs 2024-03-21 14:13:54 -06:00
noah metz 8f9a759b26 Added GQL enum 2024-03-17 14:25:34 -06:00
noah metz c0407b094c Enabled GQL tests and got them working to a bare minimum 2024-03-10 16:31:14 -06:00
21 changed files with 1923 additions and 1702 deletions
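
The serialization rework described above (serialization no longer allocates; it expects a buffer big enough for the type) shows up in db.go below as Serialize(ctx, value, buf) returning how many bytes were written into a caller-owned buffer. A minimal, self-contained sketch of that pattern, illustrative only and not graphvent's actual API:

package main

import (
	"encoding/binary"
	"fmt"
)

// putUint64 writes v into buf and reports how many bytes it used.
// The caller owns the memory, so nothing is allocated per call.
func putUint64(v uint64, buf []byte) (int, error) {
	if len(buf) < 8 {
		return 0, fmt.Errorf("buffer too small: %d bytes", len(buf))
	}
	binary.BigEndian.PutUint64(buf, v)
	return 8, nil
}

func main() {
	var buf [64]byte // one reusable buffer shared by every write
	cur := 0
	for _, v := range []uint64{1, 2, 3} {
		written, err := putUint64(v, buf[cur:])
		if err != nil {
			panic(err)
		}
		cur += written // pack values back to back, as db.go does with db.buffer
	}
	fmt.Println(buf[:cur])
}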

@ -17,7 +17,9 @@ func main() {
 db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true))
 check(err)
-ctx, err := gv.NewContext(db, gv.NewConsoleLogger([]string{"test", "signal"}))
+ctx, err := gv.NewContext(&gv.BadgerDB{
+  DB: db,
+}, gv.NewConsoleLogger([]string{"test"}))
 check(err)
 gql_ext, err := gv.NewGQLExt(ctx, ":8080", nil, nil)
@ -25,13 +27,16 @@ func main() {
 listener_ext := gv.NewListenerExt(1000)
-n1, err := gv.NewNode(ctx, nil, "Lockable", 1000, gv.NewLockableExt(nil))
+n1, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt(nil))
 check(err)
-n2, err := gv.NewNode(ctx, nil, "Lockable", 1000, gv.NewLockableExt([]gv.NodeID{n1.ID}))
+n2, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt([]gv.NodeID{n1.ID}))
 check(err)
-_, err = gv.NewNode(ctx, nil, "Lockable", 1000, gql_ext, listener_ext, gv.NewLockableExt([]gv.NodeID{n2.ID}))
+n3, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt(nil))
+check(err)
+_, err = gv.NewNode(ctx, nil, "LockableNode", 1000, gql_ext, listener_ext, gv.NewLockableExt([]gv.NodeID{n2.ID, n3.ID}))
 check(err)
 for true {

File diff suppressed because it is too large

193
db.go

@ -3,153 +3,226 @@ package graphvent
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"reflect"
"sync"
badger "github.com/dgraph-io/badger/v3" badger "github.com/dgraph-io/badger/v3"
) )
func WriteNodeInit(ctx *Context, node *Node) error { type Database interface {
WriteNodeInit(*Context, *Node) error
WriteNodeChanges(*Context, *Node, map[ExtType]Changes) error
LoadNode(*Context, NodeID) (*Node, error)
}
const WRITE_BUFFER_SIZE = 1000000
type BadgerDB struct {
*badger.DB
sync.Mutex
buffer [WRITE_BUFFER_SIZE]byte
}
func (db *BadgerDB) WriteNodeInit(ctx *Context, node *Node) error {
if node == nil { if node == nil {
return fmt.Errorf("Cannot serialize nil *Node") return fmt.Errorf("Cannot serialize nil *Node")
} }
return ctx.DB.Update(func(tx *badger.Txn) error { return db.Update(func(tx *badger.Txn) error {
db.Lock()
defer db.Unlock()
// Get the base key bytes // Get the base key bytes
id_ser, err := node.ID.MarshalBinary() id_ser, err := node.ID.MarshalBinary()
if err != nil { if err != nil {
return err return err
} }
cur := 0
// Write Node value // Write Node value
node_val, err := Serialize(ctx, node) written, err := Serialize(ctx, node, db.buffer[cur:])
if err != nil { if err != nil {
return err return err
} }
err = tx.Set(id_ser, node_val)
err = tx.Set(id_ser, db.buffer[cur:cur+written])
if err != nil { if err != nil {
return err return err
} }
cur += written
// Write empty signal queue // Write empty signal queue
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...) sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
sigqueue_val, err := Serialize(ctx, node.SignalQueue) written, err = Serialize(ctx, node.SignalQueue, db.buffer[cur:])
if err != nil { if err != nil {
return err return err
} }
err = tx.Set(sigqueue_id, sigqueue_val)
err = tx.Set(sigqueue_id, db.buffer[cur:cur+written])
if err != nil { if err != nil {
return err return err
} }
cur += written
// Write node extension list // Write node extension list
ext_list := []ExtType{} ext_list := []ExtType{}
for ext_type := range(node.Extensions) { for ext_type := range(node.Extensions) {
ext_list = append(ext_list, ext_type) ext_list = append(ext_list, ext_type)
} }
ext_list_val, err := Serialize(ctx, ext_list) written, err = Serialize(ctx, ext_list, db.buffer[cur:])
if err != nil { if err != nil {
return err return err
} }
ext_list_id := append(id_ser, []byte(" - EXTLIST")...) ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
err = tx.Set(ext_list_id, ext_list_val) err = tx.Set(ext_list_id, db.buffer[cur:cur+written])
if err != nil { if err != nil {
return err return err
} }
cur += written
// For each extension: // For each extension:
for ext_type, ext := range(node.Extensions) { for ext_type, ext := range(node.Extensions) {
// Write each extension's current value ext_info, exists := ctx.Extensions[ext_type]
if exists == false {
return fmt.Errorf("Cannot serialize node with unknown extension %s", reflect.TypeOf(ext))
}
ext_value := reflect.ValueOf(ext).Elem()
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type)) ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
ext_val, err := Serialize(ctx, ext)
if err != nil { // Write each field to a seperate key
return err for field_tag, field_info := range(ext_info.Fields) {
field_value := ext_value.FieldByIndex(field_info.Index)
field_id := make([]byte, len(ext_id) + 8)
tmp := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(field_tag))))
copy(field_id, tmp)
written, err := SerializeValue(ctx, field_value, db.buffer[cur:])
if err != nil {
return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err)
}
err = tx.Set(field_id, db.buffer[cur:cur+written])
if err != nil {
return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err)
}
cur += written
} }
err = tx.Set(ext_id, ext_val)
} }
return nil return nil
}) })
} }
func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error { func (db *BadgerDB) WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error {
return ctx.DB.Update(func(tx *badger.Txn) error { return db.Update(func(tx *badger.Txn) error {
db.Lock()
defer db.Unlock()
// Get the base key bytes // Get the base key bytes
id_ser, err := node.ID.MarshalBinary() id_bytes := ([16]byte)(node.ID)
if err != nil {
return err cur := 0
}
// Write the signal queue if it needs to be written // Write the signal queue if it needs to be written
if node.writeSignalQueue { if node.writeSignalQueue {
node.writeSignalQueue = false node.writeSignalQueue = false
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...) sigqueue_id := append(id_bytes[:], []byte(" - SIGQUEUE")...)
sigqueue_val, err := Serialize(ctx, node.SignalQueue) written, err := Serialize(ctx, node.SignalQueue, db.buffer[cur:])
if err != nil { if err != nil {
return err return fmt.Errorf("SignalQueue Serialize Error: %+v, %w", node.SignalQueue, err)
} }
err = tx.Set(sigqueue_id, sigqueue_val) err = tx.Set(sigqueue_id, db.buffer[cur:cur+written])
if err != nil { if err != nil {
return err return fmt.Errorf("SignalQueue set error: %+v, %w", node.SignalQueue, err)
} }
cur += written
} }
// For each ext in changes // For each ext in changes
for ext_type := range(changes) { for ext_type, changes := range(changes) {
// Write each ext ext_info, exists := ctx.Extensions[ext_type]
if exists == false {
return fmt.Errorf("%s is not an extension in ctx", ext_type)
}
ext, exists := node.Extensions[ext_type] ext, exists := node.Extensions[ext_type]
if exists == false { if exists == false {
return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID) return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID)
} }
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type)) ext_id := binary.BigEndian.AppendUint64(id_bytes[:], uint64(ext_type))
ext_ser, err := Serialize(ctx, ext) ext_value := reflect.ValueOf(ext)
if err != nil {
return err
}
err = tx.Set(ext_id, ext_ser) // Write each field
if err != nil { for _, tag := range(changes) {
return err field_info, exists := ext_info.Fields[tag]
if exists == false {
return fmt.Errorf("Cannot serialize field %s of extension %s, does not exist", tag, ext_type)
}
field_value := ext_value.FieldByIndex(field_info.Index)
field_id := make([]byte, len(ext_id) + 8)
tmp := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(tag))))
copy(field_id, tmp)
written, err := SerializeValue(ctx, field_value, db.buffer[cur:])
if err != nil {
return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err)
}
err = tx.Set(field_id, db.buffer[cur:cur+written])
if err != nil {
return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err)
}
cur += written
} }
} }
return nil return nil
}) })
} }
func LoadNode(ctx *Context, id NodeID) (*Node, error) { func (db *BadgerDB) LoadNode(ctx *Context, id NodeID) (*Node, error) {
var node *Node = nil var node *Node = nil
err := ctx.DB.View(func(tx *badger.Txn) error {
err := db.View(func(tx *badger.Txn) error {
// Get the base key bytes // Get the base key bytes
id_ser, err := id.MarshalBinary() id_ser, err := id.MarshalBinary()
if err != nil { if err != nil {
return err return fmt.Errorf("Failed to serialize node_id: %w", err)
} }
// Get the node value // Get the node value
node_item, err := tx.Get(id_ser) node_item, err := tx.Get(id_ser)
if err != nil { if err != nil {
return err return fmt.Errorf("Failed to get node_item: %w", NodeNotFoundError)
} }
err = node_item.Value(func(val []byte) error { err = node_item.Value(func(val []byte) error {
ctx.Log.Logf("db", "DESERIALIZE_NODE(%d bytes): %+v", len(val), val)
node, err = Deserialize[*Node](ctx, val) node, err = Deserialize[*Node](ctx, val)
return err return err
}) })
if err != nil { if err != nil {
return nil return fmt.Errorf("Failed to deserialize Node %s - %w", id, err)
} }
// Get the signal queue // Get the signal queue
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...) sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
sigqueue_item, err := tx.Get(sigqueue_id) sigqueue_item, err := tx.Get(sigqueue_id)
if err != nil { if err != nil {
return err return fmt.Errorf("Failed to get sigqueue_id: %w", err)
} }
err = sigqueue_item.Value(func(val []byte) error { err = sigqueue_item.Value(func(val []byte) error {
node.SignalQueue, err = Deserialize[[]QueuedSignal](ctx, val) node.SignalQueue, err = Deserialize[[]QueuedSignal](ctx, val)
return err return err
}) })
if err != nil { if err != nil {
return err return fmt.Errorf("Failed to deserialize []QueuedSignal for %s: %w", id, err)
} }
// Get the extension list // Get the extension list
@ -168,20 +241,34 @@ func LoadNode(ctx *Context, id NodeID) (*Node, error) {
// Get the extensions // Get the extensions
for _, ext_type := range(ext_list) { for _, ext_type := range(ext_list) {
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type)) ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
ext_item, err := tx.Get(ext_id) ext_info, exists := ctx.Extensions[ext_type]
if err != nil { if exists == false {
return err return fmt.Errorf("Extension %s not in context", ext_type)
} }
var ext Extension ext := reflect.New(ext_info.Type)
err = ext_item.Value(func(val []byte) error { for field_tag, field_info := range(ext_info.Fields) {
ext, err = Deserialize[Extension](ctx, val) field_id := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(field_tag))))
return err field_item, err := tx.Get(field_id)
}) if err != nil {
if err != nil { return fmt.Errorf("Failed to find key for %s:%s(%x) - %w", ext_type, field_tag, field_id, err)
return err }
err = field_item.Value(func(val []byte) error {
value, _, err := DeserializeValue(ctx, val, field_info.Type)
if err != nil {
return err
}
ext.Elem().FieldByIndex(field_info.Index).Set(value)
return nil
})
if err != nil {
return err
}
} }
node.Extensions[ext_type] = ext
node.Extensions[ext_type] = ext.Interface().(Extension)
} }
return nil return nil
@ -189,6 +276,8 @@ func LoadNode(ctx *Context, id NodeID) (*Node, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} else if node == nil {
return nil, fmt.Errorf("Tried to return nil *Node from BadgerDB.LoadNode without error")
} }
return node, nil return node, nil
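
The per-field keys written above concatenate the node ID with the extension type and the field tag, each appended as a big-endian uint64, and the result is copied into a fresh slice (presumably so repeated appends onto ext_id cannot share a backing array and overwrite each other). A small stand-alone sketch of that key layout, with illustrative names rather than the repository's own:

package main

import (
	"encoding/binary"
	"fmt"
)

// fieldKey builds a key shaped like the ones above:
// 16-byte node id | 8-byte extension type | 8-byte field tag.
func fieldKey(nodeID [16]byte, extType, fieldTag uint64) []byte {
	key := make([]byte, 0, 32)
	key = append(key, nodeID[:]...)
	key = binary.BigEndian.AppendUint64(key, extType)
	key = binary.BigEndian.AppendUint64(key, fieldTag)
	return key
}

func main() {
	var id [16]byte
	copy(id[:], "example-node-id!")
	fmt.Printf("%x\n", fieldKey(id, 0xA1, 0xB2)) // 32 bytes: id | ext | field
}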

@ -1,156 +0,0 @@
package graphvent
import (
"time"
"fmt"
)
type EventCommand string
type EventState string
type EventExt struct {
Name string `gv:"name"`
State EventState `gv:"state"`
StateStart time.Time `gv:"state_start"`
Parent NodeID `gv:"parent" node:"Base"`
}
func (ext *EventExt) Load(ctx *Context, node *Node) error {
return nil
}
func (ext *EventExt) Unload(ctx *Context, node *Node) {
}
func NewEventExt(parent NodeID, name string) *EventExt {
return &EventExt{
Name: name,
State: "init",
Parent: parent,
}
}
type EventStateSignal struct {
SignalHeader
Source NodeID `gv:"source"`
State EventState `gv:"state"`
Time time.Time `gv:"time"`
}
func (signal EventStateSignal) String() string {
return fmt.Sprintf("EventStateSignal(%s, %s, %s, %+v)", signal.SignalHeader, signal.Source, signal.State, signal.Time)
}
func NewEventStateSignal(source NodeID, state EventState, t time.Time) *EventStateSignal {
return &EventStateSignal{
SignalHeader: NewSignalHeader(),
Source: source,
State: state,
Time: t,
}
}
type EventControlSignal struct {
SignalHeader
Command EventCommand `gv:"command"`
}
func (signal EventControlSignal) String() string {
return fmt.Sprintf("EventControlSignal(%s, %s)", signal.SignalHeader, signal.Command)
}
func NewEventControlSignal(command EventCommand) *EventControlSignal {
return &EventControlSignal{
NewSignalHeader(),
command,
}
}
func (ext *EventExt) UpdateState(node *Node, changes Changes, state EventState, state_start time.Time) {
if ext.State != state {
ext.StateStart = state_start
changes.Add("state")
ext.State = state
node.QueueSignal(time.Now(), NewEventStateSignal(node.ID, ext.State, time.Now()))
}
}
func (ext *EventExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) {
var messages []SendMsg = nil
var changes = Changes{}
return messages, changes
}
type TestEventExt struct {
Length time.Duration
}
func (ext *TestEventExt) Load(ctx *Context, node *Node) error {
return nil
}
func (ext *TestEventExt) Unload(ctx *Context, node *Node) {
}
type EventCommandMap map[EventCommand]map[EventState]EventState
var test_event_commands = EventCommandMap{
"ready?": {
"init": "ready",
},
"start": {
"ready": "running",
},
"abort": {
"ready": "init",
},
"stop": {
"running": "stopped",
},
"finish": {
"running": "done",
},
}
func (ext *TestEventExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) {
var messages []SendMsg = nil
var changes = Changes{}
switch sig := signal.(type) {
case *EventControlSignal:
event_ext, err := GetExt[EventExt](node)
if err != nil {
messages = append(messages, SendMsg{source, NewErrorSignal(sig.Id, "not_event")})
} else {
ctx.Log.Logf("event", "%s got %s EventControlSignal while in %s", node.ID, sig.Command, event_ext.State)
new_state, error_signal := event_ext.ValidateEventCommand(sig, test_event_commands)
if error_signal != nil {
messages = append(messages, SendMsg{source, error_signal})
} else {
switch sig.Command {
case "start":
node.QueueSignal(time.Now().Add(ext.Length), NewEventControlSignal("finish"))
}
event_ext.UpdateState(node, changes, new_state, time.Now())
messages = append(messages, SendMsg{source, NewSuccessSignal(sig.Id)})
}
}
}
return messages, changes
}
func(ext *EventExt) ValidateEventCommand(signal *EventControlSignal, commands EventCommandMap) (EventState, *ErrorSignal) {
transitions, command_mapped := commands[signal.Command]
if command_mapped == false {
return "", NewErrorSignal(signal.Id, "unknown command %s", signal.Command)
} else {
new_state, valid_transition := transitions[ext.State]
if valid_transition == false {
return "", NewErrorSignal(signal.Id, "invalid command state %s(%s)", signal.Command, ext.State)
} else {
return new_state, nil
}
}
}

@ -1,86 +0,0 @@
package graphvent
import (
"crypto/ed25519"
"testing"
"time"
"crypto/rand"
)
func TestEvent(t *testing.T) {
ctx := logTestContext(t, []string{"event", "listener", "listener_debug"})
err := RegisterExtension[TestEventExt](ctx, nil)
fatalErr(t, err)
err = RegisterObject[TestEventExt](ctx)
fatalErr(t, err)
event_public, event_private, err := ed25519.GenerateKey(rand.Reader)
event_listener := NewListenerExt(100)
event, err := NewNode(ctx, event_private, "Base", 100, NewEventExt(KeyID(event_public), "Test Event"), &TestEventExt{time.Second}, event_listener)
fatalErr(t, err)
response, signals := testSend(t, ctx, NewEventControlSignal("ready?"), event, event)
switch resp := response.(type) {
case *SuccessSignal:
case *ErrorSignal:
t.Fatalf("Error response %+v", resp.Error)
default:
t.Fatalf("Unexpected response %+v", resp)
}
var state_signal *EventStateSignal = nil
for _, signal := range(signals) {
event_state, is_event_state := signal.(*EventStateSignal)
if is_event_state == true && event_state.Source == event.ID && event_state.State == "ready" {
state_signal = event_state
break
}
}
if state_signal == nil {
state_signal, err = WaitForSignal(event_listener.Chan, 10*time.Millisecond, func(sig *EventStateSignal) bool {
return sig.Source == event.ID && sig.State == "ready"
})
fatalErr(t, err)
}
response, signals = testSend(t, ctx, NewEventControlSignal("start"), event, event)
switch resp := response.(type) {
case *SuccessSignal:
case *ErrorSignal:
t.Fatalf("Error response %+v", resp.Error)
default:
t.Fatalf("Unexpected response %+v", resp)
}
state_signal = nil
for _, signal := range(signals) {
event_state, is_event_state := signal.(*EventStateSignal)
if is_event_state == true && event_state.Source == event.ID && event_state.State == "running" {
state_signal = event_state
break
}
}
if state_signal == nil {
state_signal, err = WaitForSignal(event_listener.Chan, 10*time.Millisecond, func(sig *EventStateSignal) bool {
return sig.Source == event.ID && sig.State == "running"
})
fatalErr(t, err)
}
_, err = WaitForSignal(event_listener.Chan, time.Second * 2, func(sig *EventStateSignal) bool {
return sig.Source == event.ID && sig.State == "done"
})
fatalErr(t, err)
response, signals = testSend(t, ctx, NewEventControlSignal("start"), event, event)
switch resp := response.(type) {
case *SuccessSignal:
t.Fatalf("Success response starting finished TestEventExt")
case *ErrorSignal:
default:
t.Fatalf("Unexpected response %+v", resp)
}
}

@ -1,13 +1,12 @@
 package graphvent
-import (
-)
+type Tag string
+type Changes []Tag

 // Extensions are data attached to nodes that process signals
 type Extension interface {
   // Called to process incoming signals, returning changes and messages to send
-  Process(*Context, *Node, NodeID, Signal) ([]SendMsg, Changes)
+  Process(*Context, *Node, NodeID, Signal) ([]Message, Changes)

   // Called when the node is loaded into a context(creation or move), so extension data can be initialized
   Load(*Context, *Node) error
@ -15,10 +14,3 @@ type Extension interface {
   // Called when the node is unloaded from a context(deletion or move), so extension data can be cleaned up
   Unload(*Context, *Node)
 }
-
-// Changes are lists of modifications made to extensions to be communicated
-type Changes []string
-
-func (changes *Changes) Add(fields ...string) {
-  new_changes := append(*changes, fields...)
-  changes = &new_changes
-}
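
For reference, a minimal extension satisfying the interface above could look like the following sketch; the type names come from this diff, but the code is illustrative rather than taken from the repository:

// NoopExt is a hypothetical extension that accepts every signal and does nothing.
type NoopExt struct{}

// Process ignores the signal and reports no changes and no outgoing messages.
func (ext *NoopExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]Message, Changes) {
	return nil, nil
}

// Load is called when the node is loaded; there is no state to initialize.
func (ext *NoopExt) Load(ctx *Context, node *Node) error {
	return nil
}

// Unload is called when the node is unloaded; there is no state to clean up.
func (ext *NoopExt) Unload(ctx *Context, node *Node) {}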

@ -249,13 +249,14 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW
for header, value := range(r.Header) { for header, value := range(r.Header) {
header_map[header] = value header_map[header] = value
} }
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
resolve_context, err := NewResolveContext(ctx, server, gql_ext) resolve_context, err := NewResolveContext(ctx, server, gql_ext)
if err != nil { if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
json.NewEncoder(w).Encode(GQLUnauthorized("")) json.NewEncoder(w).Encode(GQLUnauthorized(""))
return return
} else {
ctx.Log.Logf("gql", "New Query: %s", resolve_context.ID)
} }
req_ctx := context.Background() req_ctx := context.Background()
@ -304,7 +305,6 @@ func sendOneResultAndClose(res *graphql.Result) chan *graphql.Result {
return resultChannel return resultChannel
} }
func getOperationTypeOfReq(p graphql.Params) string{ func getOperationTypeOfReq(p graphql.Params) string{
source := source.NewSource(&source.Source{ source := source.NewSource(&source.Source{
Body: []byte(p.RequestString), Body: []byte(p.RequestString),
@ -330,18 +330,6 @@ func getOperationTypeOfReq(p graphql.Params) string{
return "END_OF_FUNCTION" return "END_OF_FUNCTION"
} }
func GQLWSDo(ctx * Context, p graphql.Params) chan *graphql.Result {
operation := getOperationTypeOfReq(p)
ctx.Log.Logf("gqlws", "GQLWSDO_OPERATION: %s - %+v", operation, p.RequestString)
if operation == ast.OperationTypeSubscription {
return graphql.Subscribe(p)
}
res := graphql.Do(p)
return sendOneResultAndClose(res)
}
func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) { func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r * http.Request) { return func(w http.ResponseWriter, r * http.Request) {
ctx.Log.Logf("gqlws_new", "HANDLING %s",r.RemoteAddr) ctx.Log.Logf("gqlws_new", "HANDLING %s",r.RemoteAddr)
@ -351,11 +339,12 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon
header_map[header] = value header_map[header] = value
} }
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
resolve_context, err := NewResolveContext(ctx, server, gql_ext) resolve_context, err := NewResolveContext(ctx, server, gql_ext)
if err != nil { if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
return return
} else {
ctx.Log.Logf("gql", "New Subscription: %s", resolve_context.ID)
} }
req_ctx := context.Background() req_ctx := context.Background()
@ -429,11 +418,14 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon
params.VariableValues = msg.Payload.Variables params.VariableValues = msg.Payload.Variables
} }
res_chan := GQLWSDo(ctx, params) var res_chan chan *graphql.Result
if res_chan == nil { operation := getOperationTypeOfReq(params)
ctx.Log.Logf("gqlws", "res_chan is nil")
if operation == ast.OperationTypeSubscription {
res_chan = graphql.Subscribe(params)
} else { } else {
ctx.Log.Logf("gqlws", "res_chan: %+v", res_chan) res := graphql.Do(params)
res_chan = sendOneResultAndClose(res)
} }
go func(res_chan chan *graphql.Result) { go func(res_chan chan *graphql.Result) {
@ -509,7 +501,7 @@ type Field struct {
type NodeResult struct { type NodeResult struct {
NodeID NodeID NodeID NodeID
NodeType NodeType NodeType NodeType
Data map[ExtType]map[string]interface{} Data map[string]interface{}
} }
type ListField struct { type ListField struct {
@ -526,6 +518,7 @@ type SelfField struct {
type SubscriptionInfo struct { type SubscriptionInfo struct {
ID uuid.UUID ID uuid.UUID
NodeCache *map[NodeID]NodeResult
Channel chan interface{} Channel chan interface{}
} }
@ -544,11 +537,13 @@ type GQLExt struct {
State string `gv:"state"` State string `gv:"state"`
TLSKey []byte `gv:"tls_key"` TLSKey []byte `gv:"tls_key"`
TLSCert []byte `gv:"tls_cert"` TLSCert []byte `gv:"tls_cert"`
Listen string `gv:"listen"` Listen string `gv:"listen" gql:"GQLListen"`
} }
func (ext *GQLExt) Load(ctx *Context, node *Node) error { func (ext *GQLExt) Load(ctx *Context, node *Node) error {
ctx.Log.Logf("gql", "Loading GQL server extension on %s", node.ID) ctx.Log.Logf("gql", "Loading GQL server extension on %s", node.ID)
ext.resolver_response = map[uuid.UUID]chan Signal{}
ext.subscriptions = []SubscriptionInfo{}
return ext.StartGQLServer(ctx, node) return ext.StartGQLServer(ctx, node)
} }
@ -562,14 +557,7 @@ func (ext *GQLExt) Unload(ctx *Context, node *Node) {
} }
} }
func (ext *GQLExt) PostDeserialize(*Context) error { func (ext *GQLExt) AddSubscription(id uuid.UUID, ctx *ResolveContext) (chan interface{}, error) {
ext.resolver_response = map[uuid.UUID]chan Signal{}
ext.subscriptions = []SubscriptionInfo{}
return nil
}
func (ext *GQLExt) AddSubscription(id uuid.UUID) (chan interface{}, error) {
ext.subscriptions_lock.Lock() ext.subscriptions_lock.Lock()
defer ext.subscriptions_lock.Unlock() defer ext.subscriptions_lock.Unlock()
@ -583,6 +571,7 @@ func (ext *GQLExt) AddSubscription(id uuid.UUID) (chan interface{}, error) {
ext.subscriptions = append(ext.subscriptions, SubscriptionInfo{ ext.subscriptions = append(ext.subscriptions, SubscriptionInfo{
id, id,
&ctx.NodeCache,
c, c,
}) })
@ -630,10 +619,10 @@ func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal {
return response_chan return response_chan
} }
func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) { func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]Message, Changes) {
// Process ReadResultSignalType by forwarding it to the waiting resolver // Process ReadResultSignalType by forwarding it to the waiting resolver
var changes Changes = nil var changes Changes = nil
var messages []SendMsg = nil var messages []Message = nil
switch sig := signal.(type) { switch sig := signal.(type) {
case *SuccessSignal: case *SuccessSignal:
@ -645,8 +634,6 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
default: default:
ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig) ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig)
} }
} else {
ctx.Log.Logf("gql", "received success signal response %+v with no mapped resolver", sig)
} }
case *ErrorSignal: case *ErrorSignal:
@ -659,9 +646,6 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
default: default:
ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig) ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig)
} }
} else {
ctx.Log.Logf("gql", "received error signal response %+v with no mapped resolver", sig)
} }
case *ReadResultSignal: case *ReadResultSignal:
@ -669,23 +653,22 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
if response_chan != nil { if response_chan != nil {
select { select {
case response_chan <- sig: case response_chan <- sig:
ctx.Log.Logf("gql", "Forwarded to resolver, %+v", sig)
default: default:
ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig) ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig)
} }
} else {
ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig)
} }
case *StatusSignal: case *StatusSignal:
ext.subscriptions_lock.RLock() ext.subscriptions_lock.RLock()
ctx.Log.Logf("gql", "forwarding status signal from %+v to resolvers %+v", sig.Source, ext.subscriptions) for _, sub := range(ext.subscriptions) {
for _, resolver := range(ext.subscriptions) { _, cached := (*sub.NodeCache)[sig.Source]
select { if cached {
case resolver.Channel <- sig: select {
ctx.Log.Logf("gql_subscribe", "forwarded status signal to resolver: %+v", resolver.ID) case sub.Channel <- sig:
default: ctx.Log.Logf("gql", "forwarded status signal %+v to subscription: %s", sig, sub.ID)
ctx.Log.Logf("gql_subscribe", "resolver channel overflow: %+v", resolver.ID) default:
ctx.Log.Logf("gql", "subscription channel overflow: %s", sub.ID)
}
} }
} }
ext.subscriptions_lock.RUnlock() ext.subscriptions_lock.RUnlock()

@ -25,7 +25,12 @@ func ResolveNodeType(p graphql.ResolveParams) (interface{}, error) {
return uint64(node.NodeType), nil return uint64(node.NodeType), nil
} }
func GetFieldNames(ctx *Context, selection_set *ast.SelectionSet) []string { type FieldIndex struct {
Extension ExtType
Tag string
}
func GetFields(selection_set *ast.SelectionSet) []string {
names := []string{} names := []string{}
if selection_set == nil { if selection_set == nil {
return names return names
@ -34,10 +39,12 @@ func GetFieldNames(ctx *Context, selection_set *ast.SelectionSet) []string {
for _, sel := range(selection_set.Selections) { for _, sel := range(selection_set.Selections) {
switch field := sel.(type) { switch field := sel.(type) {
case *ast.Field: case *ast.Field:
if field.Name.Value == "ID" || field.Name.Value == "Type" {
continue
}
names = append(names, field.Name.Value) names = append(names, field.Name.Value)
case *ast.InlineFragment: case *ast.InlineFragment:
default: names = append(names, GetFields(field.SelectionSet)...)
ctx.Log.Logf("gql", "Unknown selection type: %s", reflect.TypeOf(field))
} }
} }
@ -45,48 +52,13 @@ func GetFieldNames(ctx *Context, selection_set *ast.SelectionSet) []string {
} }
// Returns the fields that need to be resolved // Returns the fields that need to be resolved
func GetResolveFields(id NodeID, ctx *ResolveContext, p graphql.ResolveParams) (map[ExtType][]string, error) { func GetResolveFields(p graphql.ResolveParams) []string {
node_info, mapped := ctx.Context.NodeTypes[p.Info.ReturnType.Name()] fields := []string{}
if mapped == false {
return nil, fmt.Errorf("No NodeType %s", p.Info.ReturnType.Name())
}
fields := map[ExtType][]string{}
names := []string{}
for _, field := range(p.Info.FieldASTs) { for _, field := range(p.Info.FieldASTs) {
names = append(names, GetFieldNames(ctx.Context, field.SelectionSet)...) fields = append(fields, GetFields(field.SelectionSet)...)
} }
cache, node_cached := ctx.NodeCache[id] return fields
for _, name := range(names) {
if name == "ID" || name == "Type" {
continue
}
ext_type, field_mapped := node_info.Fields[name]
if field_mapped == false {
return nil, fmt.Errorf("NodeType %s does not have field %s", p.Info.ReturnType.Name(), name)
}
ext_fields, exists := fields[ext_type]
if exists == false {
ext_fields = []string{}
}
if node_cached {
ext_cache, ext_cached := cache.Data[ext_type]
if ext_cached {
_, field_cached := ext_cache[name]
if field_cached {
continue
}
}
}
fields[ext_type] = append(ext_fields, name)
}
return fields, nil
} }
func ResolveNode(id NodeID, p graphql.ResolveParams) (NodeResult, error) { func ResolveNode(id NodeID, p graphql.ResolveParams) (NodeResult, error) {
@ -95,58 +67,81 @@ func ResolveNode(id NodeID, p graphql.ResolveParams) (NodeResult, error) {
return NodeResult{}, err return NodeResult{}, err
} }
fields, err := GetResolveFields(id, ctx, p) switch source := p.Source.(type) {
if err != nil { case *StatusSignal:
return NodeResult{}, err cached_node, cached := ctx.NodeCache[source.Source]
if cached {
for _, field_name := range(source.Fields) {
_, cached := cached_node.Data[field_name]
if cached {
delete(cached_node.Data, field_name)
}
}
ctx.NodeCache[source.Source] = cached_node
}
} }
ctx.Context.Log.Logf("gql", "Resolving fields %+v on node %s", fields, id) cache, node_cached := ctx.NodeCache[id]
fields := GetResolveFields(p)
var not_cached []string
if node_cached {
not_cached = []string{}
for _, field := range(fields) {
if node_cached {
_, field_cached := cache.Data[field]
if field_cached {
continue
}
}
signal := NewReadSignal(fields) not_cached = append(not_cached, field)
response_chan := ctx.Ext.GetResponseChannel(signal.ID()) }
// TODO: TIMEOUT DURATION } else {
err = ctx.Context.Send(ctx.Server, []SendMsg{{ not_cached = fields
Dest: id,
Signal: signal,
}})
if err != nil {
ctx.Ext.FreeResponseChannel(signal.ID())
return NodeResult{}, err
} }
response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID()) if (len(not_cached) == 0) && (node_cached == true) {
ctx.Ext.FreeResponseChannel(signal.ID()) ctx.Context.Log.Logf("gql", "No new fields to resolve for %s", id)
if err != nil { return cache, nil
return NodeResult{}, err } else {
} ctx.Context.Log.Logf("gql", "Resolving fields %+v on node %s", not_cached, id)
signal := NewReadSignal(not_cached)
response_chan := ctx.Ext.GetResponseChannel(signal.ID())
// TODO: TIMEOUT DURATION
err = ctx.Context.Send(ctx.Server, []Message{{
Node: id,
Signal: signal,
}})
if err != nil {
ctx.Ext.FreeResponseChannel(signal.ID())
return NodeResult{}, err
}
switch response := response.(type) { response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID())
case *ReadResultSignal: ctx.Ext.FreeResponseChannel(signal.ID())
cache, node_cached := ctx.NodeCache[id] if err != nil {
if node_cached == false { return NodeResult{}, err
cache = NodeResult{ }
NodeID: id,
NodeType: response.NodeType,
Data: response.Extensions,
}
} else {
for ext_type, ext_data := range(response.Extensions) {
cached_ext, ext_cached := cache.Data[ext_type]
if ext_cached {
for field_name, field := range(ext_data) {
cache.Data[ext_type][field_name] = field
}
} else {
cache.Data[ext_type] = ext_data
}
cache.Data[ext_type] = cached_ext switch response := response.(type) {
case *ReadResultSignal:
if node_cached == false {
cache = NodeResult{
NodeID: id,
NodeType: response.NodeType,
Data: response.Fields,
}
} else {
for field_name, field_value := range(response.Fields) {
cache.Data[field_name] = field_value
}
} }
}
ctx.NodeCache[id] = cache ctx.NodeCache[id] = cache
return ctx.NodeCache[id], nil return ctx.NodeCache[id], nil
default: default:
return NodeResult{}, fmt.Errorf("Bad read response: %+v", response) return NodeResult{}, fmt.Errorf("Bad read response: %+v", response)
}
} }
} }

@ -1,162 +1,49 @@
package graphvent package graphvent
/*import ( import (
"testing" "bytes"
"time" "crypto/tls"
"fmt" "encoding/json"
"encoding/json" "fmt"
"io" "io"
"net/http" "net"
"net" "net/http"
"crypto/tls" "reflect"
"crypto/rand" "testing"
"crypto/ed25519" "time"
"bytes"
"golang.org/x/net/websocket" "github.com/google/uuid"
"github.com/google/uuid" "golang.org/x/net/websocket"
) )
func TestGQLAuth(t *testing.T) { func TestGQLSubscribe(t *testing.T) {
ctx := logTestContext(t, []string{"test"}) ctx := logTestContext(t, []string{"test", "gql"})
listener_1 := NewListenerExt(10) n1, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil))
node_1, err := NewNode(ctx, nil, "Base", 10, nil, listener_1)
fatalErr(t, err) fatalErr(t, err)
listener_2 := NewListenerExt(10) listener_ext := NewListenerExt(10)
node_2, err := NewNode(ctx, nil, "Base", 10, nil, listener_2)
fatalErr(t, err)
auth_header, err := AuthB64(node_1.Key, node_2.Key.Public().(ed25519.PublicKey))
fatalErr(t, err)
auth, err := ParseAuthB64(auth_header, node_2.Key)
fatalErr(t, err)
err = ValidateAuthorization(Authorization{
AuthInfo: auth.AuthInfo,
Key: auth.Key.Public().(ed25519.PublicKey),
}, time.Second)
fatalErr(t, err)
ctx.Log.Logf("test", "AUTH: %+v", auth)
}
func TestGQLServer(t *testing.T) {
ctx := logTestContext(t, []string{"test", "gqlws", "gql"})
pub, gql_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
gql_id := KeyID(pub)
group_policy_1 := NewAllNodesPolicy(Tree{
SerializedType(SignalTypeFor[ReadSignal]()): Tree{
SerializedType(ExtTypeFor[GroupExt]()): Tree{
SerializedType(GetFieldTag("members")): Tree{},
},
},
SerializedType(SignalTypeFor[ReadResultSignal]()): nil,
SerializedType(SignalTypeFor[ErrorSignal]()): nil,
})
group_policy_2 := NewMemberOfPolicy(map[NodeID]map[string]Tree{
gql_id: {
"test_group": {
SerializedType(SignalTypeFor[LinkSignal]()): nil,
SerializedType(SignalTypeFor[LockSignal]()): nil,
SerializedType(SignalTypeFor[StatusSignal]()): nil,
SerializedType(SignalTypeFor[ReadSignal]()): nil,
},
},
})
user_policy_1 := NewAllNodesPolicy(Tree{
SerializedType(SignalTypeFor[ReadResultSignal]()): nil,
SerializedType(SignalTypeFor[ErrorSignal]()): nil,
})
user_policy_2 := NewMemberOfPolicy(map[NodeID]map[string]Tree{
gql_id: {
"test_group": {
SerializedType(SignalTypeFor[LinkSignal]()): nil,
SerializedType(SignalTypeFor[ReadSignal]()): nil,
SerializedType(SignalTypeFor[LockSignal]()): nil,
},
},
})
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
fatalErr(t, err) fatalErr(t, err)
listener_ext := NewListenerExt(10) gql, err := ctx.NewNode(nil, "LockableNode", NewLockableExt([]NodeID{n1.ID}), gql_ext, listener_ext)
n1, err := NewNode(ctx, nil, "Base", 10, []Policy{user_policy_2, user_policy_1}, NewLockableExt(nil))
fatalErr(t, err) fatalErr(t, err)
gql, err := NewNode(ctx, gql_key, "Base", 10, []Policy{group_policy_2, group_policy_1}, query := "subscription { Self { ID, Type ... on Lockable { LockableState } } }"
NewLockableExt([]NodeID{n1.ID}), gql_ext, NewGroupExt(map[string][]NodeID{"test_group": {n1.ID, gql_id}}), listener_ext)
fatalErr(t, err)
ctx.Log.Logf("test", "GQL: %s", gql.ID) ctx.Log.Logf("test", "GQL: %s", gql.ID)
ctx.Log.Logf("test", "NODE: %s", n1.ID) ctx.Log.Logf("test", "Node: %s", n1.ID)
ctx.Log.Logf("test", "Query: %s", query)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StatusSignal) bool {
return sig.Source == gql_id
})
fatalErr(t, err)
skipVerifyTransport := &http.Transport{ sub_1 := GQLPayload{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, Query: query,
} }
client := &http.Client{Transport: skipVerifyTransport}
port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port
url := fmt.Sprintf("http://localhost:%d/gql", port) url := fmt.Sprintf("http://localhost:%d/gql", port)
ws_url := fmt.Sprintf("ws://127.0.0.1:%d/gqlws", port) ws_url := fmt.Sprintf("ws://127.0.0.1:%d/gqlws", port)
req_1 := GQLPayload{
Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash } }",
Variables: map[string]interface{}{
"id": n1.ID.String(),
},
}
req_2 := GQLPayload{
Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash, ... on GQLServer { SubGroups { Name, Members { ID } } , Listen, Requirements { ID, TypeHash Owner { ID } } } } }",
Variables: map[string]interface{}{
"id": gql.ID.String(),
},
}
auth_header, err := AuthB64(n1.Key, gql.Key.Public().(ed25519.PublicKey))
fatalErr(t, err)
SendGQL := func(payload GQLPayload) []byte {
ser, err := json.MarshalIndent(&payload, "", " ")
fatalErr(t, err)
req_data := bytes.NewBuffer(ser)
req, err := http.NewRequest("GET", url, req_data)
fatalErr(t, err)
req.Header.Add("Authorization", auth_header)
resp, err := client.Do(req)
fatalErr(t, err)
body, err := io.ReadAll(resp.Body)
fatalErr(t, err)
resp.Body.Close()
return body
}
resp_1 := SendGQL(req_1)
ctx.Log.Logf("test", "RESP_1: %s", resp_1)
resp_2 := SendGQL(req_2)
ctx.Log.Logf("test", "RESP_2: %s", resp_2)
sub_1 := GQLPayload{
Query: "subscription { Self { ID, TypeHash, ... on Lockable { Requirements { ID }}}}",
}
SubGQL := func(payload GQLPayload) { SubGQL := func(payload GQLPayload) {
config, err := websocket.NewConfig(ws_url, url) config, err := websocket.NewConfig(ws_url, url)
fatalErr(t, err) fatalErr(t, err)
@ -174,11 +61,9 @@ func TestGQLServer(t *testing.T) {
init := struct{ init := struct{
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
Type string `json:"type"` Type string `json:"type"`
Payload payload_struct `json:"payload"`
}{ }{
uuid.New(), uuid.New(),
"connection_init", "connection_init",
payload_struct{ auth_header },
} }
ser, err := json.Marshal(&init) ser, err := json.Marshal(&init)
@ -211,75 +96,128 @@ func TestGQLServer(t *testing.T) {
n, err = ws.Read(resp) n, err = ws.Read(resp)
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "SUB: %s", resp[:n]) ctx.Log.Logf("test", "SUB1: %s", resp[:n])
msgs := Messages{} lock_id, err := LockLockable(ctx, gql)
test_changes := Changes{}
AddChange[GQLExt](test_changes, "state")
msgs = msgs.Add(ctx, gql.ID, gql, nil, NewStatusSignal(gql.ID, test_changes))
err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
response, _, err := WaitForResponse(listener_ext.Chan, 100*time.Millisecond, lock_id)
fatalErr(t, err)
switch response.(type) {
case *SuccessSignal:
ctx.Log.Logf("test", "Locked %s", gql.ID)
default:
t.Errorf("Unexpected lock response: %s", response)
}
n, err = ws.Read(resp) n, err = ws.Read(resp)
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "SUB: %s", resp[:n]) ctx.Log.Logf("test", "SUB2: %s", resp[:n])
n, err = ws.Read(resp)
fatalErr(t, err)
ctx.Log.Logf("test", "SUB3: %s", resp[:n])
// TODO: check that there are no more messages sent to ws within a timeout // TODO: check that there are no more messages sent to ws within a timeout
} }
SubGQL(sub_1) SubGQL(sub_1)
}
msgs := Messages{} func TestGQLQuery(t *testing.T) {
msgs = msgs.Add(ctx, gql.ID, gql, nil, NewStopSignal()) ctx := logTestContext(t, []string{"test", "lockable"})
err = ctx.Send(msgs)
n1_listener := NewListenerExt(10)
n1, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil), n1_listener)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StoppedSignal) bool {
return sig.Source == gql_id gql_listener := NewListenerExt(10)
}) gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
fatalErr(t, err) fatalErr(t, err)
}
func TestGQLDB(t *testing.T) { gql, err := ctx.NewNode(nil, "LockableNode", NewLockableExt([]NodeID{n1.ID}), gql_ext, gql_listener)
ctx := logTestContext(t, []string{"test", "db", "node"}) fatalErr(t, err)
ctx.Log.Logf("test", "GQL: %s", gql.ID)
ctx.Log.Logf("test", "NODE: %s", n1.ID)
skipVerifyTransport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client := &http.Client{Transport: skipVerifyTransport}
port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port
url := fmt.Sprintf("http://localhost:%d/gql", port)
req_1 := GQLPayload{
Query: "query Node($id:graphvent_NodeID) { Node(id:$id) { ID, Type, ... on Lockable { LockableState } } }",
Variables: map[string]interface{}{
"id": n1.ID.String(),
},
}
req_2 := GQLPayload{
Query: "query Self { Self { ID, Type, ... on Lockable { LockableState, Requirements { Key { ID ... on Lockable { LockableState } } } } } }",
}
SendGQL := func(payload GQLPayload) []byte {
ser, err := json.MarshalIndent(&payload, "", " ")
fatalErr(t, err)
req_data := bytes.NewBuffer(ser)
req, err := http.NewRequest("GET", url, req_data)
fatalErr(t, err)
resp, err := client.Do(req)
fatalErr(t, err)
body, err := io.ReadAll(resp.Body)
fatalErr(t, err)
resp.Body.Close()
return body
}
u1, err := NewNode(ctx, nil, "Base", 10, nil) resp_1 := SendGQL(req_1)
ctx.Log.Logf("test", "RESP_1: %s", resp_1)
resp_2 := SendGQL(req_2)
ctx.Log.Logf("test", "RESP_2: %s", resp_2)
lock_id, err := LockLockable(ctx, n1)
fatalErr(t, err)
response, _, err := WaitForResponse(n1_listener.Chan, 100*time.Millisecond, lock_id)
fatalErr(t, err) fatalErr(t, err)
switch response := response.(type) {
case *SuccessSignal:
default:
t.Fatalf("Wrong response: %s", reflect.TypeOf(response))
}
ctx.Log.Logf("test", "U1_ID: %s", u1.ID) resp_3 := SendGQL(req_1)
ctx.Log.Logf("test", "RESP_3: %s", resp_3)
resp_4 := SendGQL(req_2)
ctx.Log.Logf("test", "RESP_4: %s", resp_4)
}
func TestGQLDB(t *testing.T) {
ctx := logTestContext(t, []string{"test", "db", "node", "serialize"})
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
fatalErr(t, err) fatalErr(t, err)
listener_ext := NewListenerExt(10) listener_ext := NewListenerExt(10)
gql, err := NewNode(ctx, nil, "Base", 10, nil,
gql_ext, gql, err := ctx.NewNode(nil, "Node", gql_ext, listener_ext)
listener_ext,
NewGroupExt(nil))
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
msgs := Messages{} err = ctx.Stop()
msgs = msgs.Add(ctx, gql.ID, gql, nil, NewStopSignal())
err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StoppedSignal) bool {
return sig.Source == gql.ID
})
fatalErr(t, err) fatalErr(t, err)
// Clear all loaded nodes from the context so it loads them from the database gql_loaded, err := ctx.GetNode(gql.ID)
ctx.nodeMap = map[NodeID]*Node{}
gql_loaded, err := LoadNode(ctx, gql.ID)
fatalErr(t, err) fatalErr(t, err)
listener_ext, err = GetExt[ListenerExt](gql_loaded) listener_ext, err = GetExt[ListenerExt](gql_loaded)
fatalErr(t, err) fatalErr(t, err)
msgs = Messages{}
msgs = msgs.Add(ctx, gql_loaded.ID, gql_loaded, nil, NewStopSignal())
err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StoppedSignal) bool {
return sig.Source == gql_loaded.ID
})
fatalErr(t, err)
} }
*/

@ -9,27 +9,20 @@ import (
 func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt, error) {
   listener_extension := NewListenerExt(buffer)
-  listener, err := NewNode(ctx,
-                    nil,
-                    "LockableListener",
-                    10,
-                    nil,
-                    listener_extension,
-                    NewLockableExt(nil))
+  listener, err := ctx.NewNode(nil, "LockableNode", nil, listener_extension, NewLockableExt(nil))

   return listener, listener_extension, err
 }

 func logTestContext(t * testing.T, components []string) *Context {
-  db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true))
+  db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true).WithSyncWrites(true))
   if err != nil {
     t.Fatal(err)
   }

-  ctx, err := NewContext(db, NewConsoleLogger(components))
-  fatalErr(t, err)
-
-  err = RegisterNodeType(ctx, "LockableListener", []ExtType{ExtTypeFor[ListenerExt](), ExtTypeFor[LockableExt]()})
+  ctx, err := NewContext(&BadgerDB{
+    DB: db,
+  }, NewConsoleLogger(components))
   fatalErr(t, err)

   return ctx
@ -50,7 +43,7 @@ func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *No
   source_listener, err := GetExt[ListenerExt](source)
   fatalErr(t, err)

-  messages := []SendMsg{{destination.ID, signal}}
+  messages := []Message{{destination.ID, signal}}
   fatalErr(t, ctx.Send(source, messages))

   response, signals, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID())

@ -10,12 +10,35 @@ type ListenerExt struct {
   Chan chan Signal
 }

+type LoadedSignal struct {
+  SignalHeader
+}
+
+func NewLoadedSignal() *LoadedSignal {
+  return &LoadedSignal{
+    SignalHeader: NewSignalHeader(),
+  }
+}
+
+type UnloadedSignal struct {
+  SignalHeader
+}
+
+func NewUnloadedSignal() *UnloadedSignal {
+  return &UnloadedSignal{
+    SignalHeader: NewSignalHeader(),
+  }
+}
+
 func (ext *ListenerExt) Load(ctx *Context, node *Node) error {
   ext.Chan = make(chan Signal, ext.Buffer)
+  ext.Chan <- NewLoadedSignal()
   return nil
 }

 func (ext *ListenerExt) Unload(ctx *Context, node *Node) {
+  ext.Chan <- NewUnloadedSignal()
+  close(ext.Chan)
 }

 // Create a new listener extension with a given buffer size
@ -27,7 +50,7 @@ func NewListenerExt(buffer int) *ListenerExt {
 }

 // Send the signal to the channel, logging an overflow if it occurs
-func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) {
+func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]Message, Changes) {
   ctx.Log.Logf("listener", "%s - %+v", node.ID, reflect.TypeOf(signal))
   ctx.Log.Logf("listener_debug", "%s->%s - %+v", source, node.ID, signal)
   select {
@ -37,7 +60,7 @@ func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal
   }
   switch sig := signal.(type) {
   case *StatusSignal:
-    ctx.Log.Logf("listener_status", "%s - %+v", sig.Source, sig.Changes)
+    ctx.Log.Logf("listener_status", "%s - %+v", sig.Source, sig.Fields)
   }
   return nil, nil
 }
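
Because Load now pushes a LoadedSignal into the channel, a test can block until an extension has been loaded by reusing the WaitForSignal helper that appears in the test files above. A rough sketch, assuming a registered "Node" node type as in TestGQLDB and the helper signatures shown elsewhere in this diff:

// Sketch only: ctx is a *Context from logTestContext, and the listener's
// channel already holds the LoadedSignal queued by Load.
listener := NewListenerExt(10)
node, err := ctx.NewNode(nil, "Node", listener)
if err != nil {
	panic(err)
}

// Wait up to 10ms for the LoadedSignal emitted when the extension loaded.
_, err = WaitForSignal(listener.Chan, 10*time.Millisecond, func(sig *LoadedSignal) bool {
	return true
})
if err != nil {
	panic(err)
}
_ = node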

@ -1,8 +1,7 @@
package graphvent package graphvent
import ( import (
"github.com/google/uuid" "github.com/google/uuid"
"time"
) )
type ReqState byte type ReqState byte
@ -22,399 +21,389 @@ var ReqStateStrings = map[ReqState]string {
AbortingLock: "AbortingLock", AbortingLock: "AbortingLock",
} }
func (state ReqState) String() string {
str, mapped := ReqStateStrings[state]
if mapped == false {
return "UNKNOWN_REQSTATE"
} else {
return str
}
}
type LockableExt struct{ type LockableExt struct{
State ReqState `gv:"state"` State ReqState `gv:"state"`
ReqID *uuid.UUID `gv:"req_id"` ReqID *uuid.UUID `gv:"req_id"`
Owner *NodeID `gv:"owner" node:"Base"` Owner *NodeID `gv:"owner"`
PendingOwner *NodeID `gv:"pending_owner" node:"Base"` PendingOwner *NodeID `gv:"pending_owner"`
PendingID uuid.UUID `gv:"pending_id"`
Requirements map[NodeID]ReqState `gv:"requirements" node:"Lockable:"` Requirements map[NodeID]ReqState `gv:"requirements" node:"Lockable:"`
WaitInfos WaitMap `gv:"wait_infos" node:":Base"`
Locked map[NodeID]any
Unlocked map[NodeID]any
Waiting WaitMap `gv:"waiting_locks" node:":Lockable"`
} }
func NewLockableExt(requirements []NodeID) *LockableExt { func NewLockableExt(requirements []NodeID) *LockableExt {
var reqs map[NodeID]ReqState = nil var reqs map[NodeID]ReqState = nil
if requirements != nil { var unlocked map[NodeID]any = map[NodeID]any{}
if len(requirements) != 0 {
reqs = map[NodeID]ReqState{} reqs = map[NodeID]ReqState{}
for _, id := range(requirements) { for _, req := range(requirements) {
reqs[id] = Unlocked reqs[req] = Unlocked
unlocked[req] = nil
} }
} }
return &LockableExt{ return &LockableExt{
State: Unlocked, State: Unlocked,
Owner: nil, Owner: nil,
PendingOwner: nil, PendingOwner: nil,
Requirements: reqs, Requirements: reqs,
WaitInfos: WaitMap{}, Waiting: WaitMap{},
Locked: map[NodeID]any{},
Unlocked: unlocked,
} }
} }
func UnlockLockable(ctx *Context, node *Node) (uuid.UUID, error) { func UnlockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
signal := NewLockSignal("unlock") signal := NewUnlockSignal()
messages := []SendMsg{{node.ID, signal}} messages := []Message{{node.ID, signal}}
return signal.ID(), ctx.Send(node, messages) return signal.ID(), ctx.Send(node, messages)
} }
func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) { func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
signal := NewLockSignal("lock") signal := NewLockSignal()
messages := []SendMsg{{node.ID, signal}} messages := []Message{{node.ID, signal}}
return signal.ID(), ctx.Send(node, messages) return signal.ID(), ctx.Send(node, messages)
} }
func (ext *LockableExt) Load(ctx *Context, node *Node) error { func (ext *LockableExt) Load(ctx *Context, node *Node) error {
ext.Locked = map[NodeID]any{}
ext.Unlocked = map[NodeID]any{}
for id, state := range(ext.Requirements) {
if state == Unlocked {
ext.Unlocked[id] = nil
} else if state == Locked {
ext.Locked[id] = nil
}
}
return nil return nil
} }
func (ext *LockableExt) Unload(ctx *Context, node *Node) { func (ext *LockableExt) Unload(ctx *Context, node *Node) {
return
} }
func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) ([]SendMsg, Changes) { // Handle link signal by adding/removing the requested NodeID
var messages []SendMsg = nil // returns an error if the node is not unlocked
func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID, signal *LinkSignal) ([]Message, Changes) {
var messages []Message = nil
var changes Changes = nil var changes Changes = nil
info, info_found := node.ProcessResponse(ext.WaitInfos, signal) switch ext.State {
if info_found { case Unlocked:
state, found := ext.Requirements[info.Destination]
if found == true {
changes.Add("wait_infos")
ctx.Log.Logf("lockable", "got mapped response %+v for %+v in state %s while in %s", signal, info, ReqStateStrings[state], ReqStateStrings[ext.State])
switch ext.State {
case AbortingLock:
ext.Requirements[info.Destination] = Unlocked
all_unlocked := true
for _, state := range(ext.Requirements) {
if state != Unlocked {
all_unlocked = false
break
}
}
if all_unlocked == true {
changes.Add("state")
ext.State = Unlocked
}
case Locking:
changes.Add("state")
ext.Requirements[info.Destination] = Unlocked
unlocked := 0
for _, state := range(ext.Requirements) {
if state == Unlocked {
unlocked += 1
}
}
if unlocked == len(ext.Requirements) {
ctx.Log.Logf("lockable", "%s unlocked from error %s from %s", node.ID, signal.Error, source)
ext.State = Unlocked
} else {
ext.State = AbortingLock
for id, state := range(ext.Requirements) {
if state == Locked {
ext.Requirements[id] = Unlocking
lock_signal := NewLockSignal("unlock")
ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", id, lock_signal, 100*time.Millisecond)
messages = append(messages, SendMsg{id, lock_signal})
ctx.Log.Logf("lockable", "sent abort unlock to %s from %s", id, node.ID)
}
}
}
case Unlocking:
ext.Requirements[info.Destination] = Locked
all_returned := true
for _, state := range(ext.Requirements) {
if state == Unlocking {
all_returned = false
break
}
}
if all_returned == true {
ext.State = Locked
}
}
} else {
ctx.Log.Logf("lockable", "Got mapped error %s, but %s isn't a requirement", signal, info.Destination)
}
}
return messages, changes
}
func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID, signal *LinkSignal) ([]SendMsg, Changes) {
var messages []SendMsg = nil
var changes = Changes{}
if ext.State == Unlocked {
switch signal.Action { switch signal.Action {
case "add": case "add":
_, exists := ext.Requirements[signal.NodeID] _, exists := ext.Requirements[signal.NodeID]
if exists == true { if exists == true {
messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "already_requirement")}) messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_requirement")})
} else { } else {
if ext.Requirements == nil { if ext.Requirements == nil {
ext.Requirements = map[NodeID]ReqState{} ext.Requirements = map[NodeID]ReqState{}
} }
ext.Requirements[signal.NodeID] = Unlocked ext.Requirements[signal.NodeID] = Unlocked
changes.Add("requirements") changes = append(changes, "requirements")
messages = append(messages, SendMsg{source, NewSuccessSignal(signal.ID())}) messages = append(messages, Message{source, NewSuccessSignal(signal.ID())})
} }
case "remove": case "remove":
_, exists := ext.Requirements[signal.NodeID] _, exists := ext.Requirements[signal.NodeID]
if exists == false { if exists == false {
messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "can't link: not_requirement")}) messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")})
} else { } else {
delete(ext.Requirements, signal.NodeID) delete(ext.Requirements, signal.NodeID)
changes.Add("requirements") changes = append(changes, "requirements")
messages = append(messages, SendMsg{source, NewSuccessSignal(signal.ID())}) messages = append(messages, Message{source, NewSuccessSignal(signal.ID())})
} }
default: default:
messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "unknown_action")}) messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "unknown_action")})
} }
} else { default:
messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "not_unlocked")}) messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_unlocked: %s", ext.State)})
} }
return messages, changes return messages, changes
} }
func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source NodeID, signal *SuccessSignal) ([]SendMsg, Changes) { // Handle an UnlockSignal by either transitioning to Unlocked state,
var messages []SendMsg = nil // sending unlock signals to requirements, or returning an error signal
var changes = Changes{} func (ext *LockableExt) HandleUnlockSignal(ctx *Context, node *Node, source NodeID, signal *UnlockSignal) ([]Message, Changes) {
if source == node.ID { var messages []Message = nil
return messages, changes var changes Changes = nil
}
info, info_found := node.ProcessResponse(ext.WaitInfos, signal) switch ext.State {
if info_found == true { case Locked:
state, found := ext.Requirements[info.Destination] if source != *ext.Owner {
if found == false { messages = append(messages, Message{source, NewErrorSignal(signal.Id, "not_owner")})
ctx.Log.Logf("lockable", "Got success signal for requirement that is no longer in the map(%s), ignoring...", info.Destination)
} else { } else {
ctx.Log.Logf("lockable", "got mapped response %+v for %+v in state %s", signal, info, ReqStateStrings[state]) if len(ext.Requirements) == 0 {
switch state { changes = append(changes, "state", "owner", "pending_owner")
case Locking:
switch ext.State {
case Locking:
ext.Requirements[info.Destination] = Locked
locked := 0
for _, s := range(ext.Requirements) {
if s == Locked {
locked += 1
}
}
if locked == len(ext.Requirements) {
ctx.Log.Logf("lockable", "WHOLE LOCK: %s - %s - %+v", node.ID, ext.PendingID, ext.PendingOwner)
ext.State = Locked
ext.Owner = ext.PendingOwner
changes.Add("state", "owner", "requirements")
messages = append(messages, SendMsg{*ext.Owner, NewSuccessSignal(ext.PendingID)})
} else {
changes.Add("requirements")
ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, len(ext.Requirements))
}
case AbortingLock:
ext.Requirements[info.Destination] = Unlocking
lock_signal := NewLockSignal("unlock") ext.Owner = nil
ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", info.Destination, lock_signal, 100*time.Millisecond)
messages = append(messages, SendMsg{info.Destination, lock_signal})
ctx.Log.Logf("lockable", "sending abort_lock to %s for %s", info.Destination, node.ID) ext.PendingOwner = nil
}
case AbortingLock:
ctx.Log.Logf("lockable", "Got success signal in AbortingLock %s", node.ID)
fallthrough
case Unlocking:
ext.Requirements[source] = Unlocked
unlocked := 0 ext.State = Unlocked
for _, s := range(ext.Requirements) {
if s == Unlocked {
unlocked += 1
}
}
if unlocked == len(ext.Requirements) { messages = append(messages, Message{source, NewSuccessSignal(signal.Id)})
old_state := ext.State } else {
ext.State = Unlocked changes = append(changes, "state", "waiting", "requirements", "pending_owner")
ctx.Log.Logf("lockable", "WHOLE UNLOCK: %s - %s - %+v", node.ID, ext.PendingID, ext.PendingOwner)
if old_state == Unlocking { ext.PendingOwner = nil
previous_owner := *ext.Owner
ext.Owner = ext.PendingOwner ext.ReqID = &signal.Id
ext.ReqID = nil
changes.Add("state", "owner", "req_id") ext.State = Unlocking
messages = append(messages, SendMsg{previous_owner, NewSuccessSignal(ext.PendingID)}) for id := range(ext.Requirements) {
} else if old_state == AbortingLock { unlock_signal := NewUnlockSignal()
changes.Add("state", "pending_owner")
messages = append(messages, SendMsg{*ext.PendingOwner, NewErrorSignal(*ext.ReqID, "not_unlocked")}) ext.Waiting[unlock_signal.Id] = id
ext.PendingOwner = ext.Owner ext.Requirements[id] = Unlocking
}
} else { messages = append(messages, Message{id, unlock_signal})
changes.Add("state")
ctx.Log.Logf("lockable", "PARTIAL UNLOCK: %s - %d/%d", node.ID, unlocked, len(ext.Requirements))
} }
} }
} }
default:
messages = append(messages, Message{source, NewErrorSignal(signal.Id, "not_locked")})
} }
return messages, changes return messages, changes
} }
// Handle a LockSignal and update the extensions owner/requirement states // Handle a LockSignal by either transitioning to a locked state,
func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) ([]SendMsg, Changes) { // sending lock signals to requirements, or returning an error signal
var messages []SendMsg = nil func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) ([]Message, Changes) {
var changes = Changes{} var messages []Message = nil
var changes Changes = nil
switch ext.State {
case Unlocked:
if len(ext.Requirements) == 0 {
changes = append(changes, "state", "owner", "pending_owner")
ext.Owner = &source
ext.PendingOwner = &source
ext.State = Locked
messages = append(messages, Message{source, NewSuccessSignal(signal.Id)})
} else {
changes = append(changes, "state", "requirements", "waiting", "pending_owner")
ext.PendingOwner = &source
ext.ReqID = &signal.Id
ext.State = Locking
for id := range(ext.Requirements) {
lock_signal := NewLockSignal()
ext.Waiting[lock_signal.Id] = id
ext.Requirements[id] = Locking
messages = append(messages, Message{id, lock_signal})
}
}
default:
messages = append(messages, Message{source, NewErrorSignal(signal.Id, "not_unlocked: %s", ext.State)})
}
return messages, changes
}
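
The Locking path above fans a LockSignal out to every requirement and records each outstanding request in ext.Waiting, keyed by the signal's UUID, so a later SuccessSignal or ErrorSignal can be traced back to the requirement that answered. A minimal, self-contained sketch of that bookkeeping, using stand-in types (NodeID aliased to uuid.UUID and a toy ReqState) rather than the real extension:

package main

import (
	"fmt"

	"github.com/google/uuid"
)

type NodeID = uuid.UUID
type ReqState int

const (
	Unlocked ReqState = iota
	Locking
	Locked
)

func main() {
	requirements := map[NodeID]ReqState{uuid.New(): Unlocked, uuid.New(): Unlocked}
	waiting := map[uuid.UUID]NodeID{} // outstanding signal ID -> requirement that must answer

	// Fan out: one lock signal per requirement, remembered in the waiting map.
	for id := range requirements {
		sig_id := uuid.New()
		waiting[sig_id] = id
		requirements[id] = Locking
		fmt.Printf("sent lock %s to %s\n", sig_id, id)
	}

	// Each response is matched back to its requirement through the waiting map.
	for sig_id, id := range waiting {
		delete(waiting, sig_id)
		requirements[id] = Locked
	}
	fmt.Println("all requirements locked:", len(waiting) == 0)
}
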
// Handle an error signal by aborting the lock, or retrying the unlock
func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) ([]Message, Changes) {
var messages []Message = nil
var changes Changes = nil
id, waiting := ext.Waiting[signal.ReqID]
if waiting == true {
delete(ext.Waiting, signal.ReqID)
changes = append(changes, "waiting")
switch signal.State {
case "lock":
switch ext.State { switch ext.State {
case Unlocked: case Locking:
if len(ext.Requirements) == 0 { changes = append(changes, "state", "requirements")
ext.State = Locked
new_owner := source ext.Requirements[id] = Unlocked
ext.PendingOwner = &new_owner
ext.Owner = &new_owner
changes.Add("state", "pending_owner", "owner")
messages = append(messages, SendMsg{new_owner, NewSuccessSignal(signal.ID())})
} else {
ext.State = Locking
id := signal.ID()
ext.ReqID = &id
new_owner := source
ext.PendingOwner = &new_owner
ext.PendingID = signal.ID()
changes.Add("state", "req_id", "pending_owner", "pending_id")
for id, state := range(ext.Requirements) {
if state != Unlocked {
ctx.Log.Logf("lockable", "REQ_NOT_UNLOCKED_WHEN_LOCKING")
}
lock_signal := NewLockSignal("lock") unlocked := 0
ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("lock", id, lock_signal, 500*time.Millisecond) for req_id, req_state := range(ext.Requirements) {
ext.Requirements[id] = Locking // Unlock locked requirements, and count unlocked requirements
switch req_state {
case Locked:
unlock_signal := NewUnlockSignal()
messages = append(messages, SendMsg{id, lock_signal}) ext.Waiting[unlock_signal.Id] = req_id
ext.Requirements[req_id] = Unlocking
messages = append(messages, Message{req_id, unlock_signal})
case Unlocked:
unlocked += 1
} }
} }
default:
messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "not_unlocked")}) if unlocked == len(ext.Requirements) {
ctx.Log.Logf("lockable", "Tried to lock %s while %s", node.ID, ext.State) changes = append(changes, "owner", "state")
}
case "unlock":
if ext.State == Locked {
if len(ext.Requirements) == 0 {
ext.State = Unlocked ext.State = Unlocked
new_owner := source
ext.PendingOwner = nil
ext.Owner = nil ext.Owner = nil
changes.Add("state", "pending_owner", "owner") } else {
messages = append(messages, SendMsg{new_owner, NewSuccessSignal(signal.ID())}) changes = append(changes, "state")
} else if source == *ext.Owner { ext.State = AbortingLock
ext.State = Unlocking }
id := signal.ID()
ext.ReqID = &id
ext.PendingOwner = nil
ext.PendingID = signal.ID()
changes.Add("state", "pending_owner", "pending_id", "req_id")
for id, state := range(ext.Requirements) {
if state != Locked {
ctx.Log.Logf("lockable", "REQ_NOT_LOCKED_WHEN_UNLOCKING")
}
lock_signal := NewLockSignal("unlock") case Unlocking:
ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", id, lock_signal, 100*time.Millisecond) unlock_signal := NewUnlockSignal()
ext.Requirements[id] = Unlocking ext.Waiting[unlock_signal.Id] = id
messages = append(messages, Message{id, unlock_signal})
case AbortingLock:
req_state := ext.Requirements[id]
// Mark failed lock as Unlocked, or retry unlock
switch req_state {
case Locking:
ext.Requirements[id] = Unlocked
messages = append(messages, SendMsg{id, lock_signal}) // Check if all requirements unlocked now
unlocked := 0
for _, req_state := range(ext.Requirements) {
if req_state == Unlocked {
unlocked += 1
}
} }
if unlocked == len(ext.Requirements) {
changes = append(changes, "owner", "state")
ext.State = Unlocked
ext.Owner = nil
}
case Unlocking:
// Handle error for unlocking requirement while unlocking by retrying unlock
unlock_signal := NewUnlockSignal()
ext.Waiting[unlock_signal.Id] = id
messages = append(messages, Message{id, unlock_signal})
} }
} else {
messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "not_locked")})
} }
default:
ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", signal.State)
} }
return messages, changes return messages, changes
} }
func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source NodeID, signal *TimeoutSignal) ([]SendMsg, Changes) { // Handle a success signal by checking if all requirements have been locked/unlocked
var messages []SendMsg = nil func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source NodeID, signal *SuccessSignal) ([]Message, Changes) {
var changes = Changes{} var messages []Message = nil
var changes Changes = nil
wait_info, found := node.ProcessResponse(ext.WaitInfos, signal)
if found == true { id, waiting := ext.Waiting[signal.ReqID]
changes.Add("wait_infos") if waiting == true {
state, found := ext.Requirements[wait_info.Destination] delete(ext.Waiting, signal.ReqID)
if found == true { changes = append(changes, "waiting")
ctx.Log.Logf("lockable", "%s timed out %s while %s was %s", wait_info.Destination, ReqStateStrings[state], node.ID, ReqStateStrings[state])
switch ext.State { switch ext.State {
case AbortingLock: case Locking:
ext.Requirements[wait_info.Destination] = Unlocked ext.Requirements[id] = Locked
all_unlocked := true ext.Locked[id] = nil
for _, state := range(ext.Requirements) { delete(ext.Unlocked, id)
if state != Unlocked {
all_unlocked = false if len(ext.Locked) == len(ext.Requirements) {
break ctx.Log.Logf("lockable", "%s FULL_LOCK: %d", node.ID, len(ext.Locked))
} changes = append(changes, "state", "owner", "req_id")
} ext.State = Locked
if all_unlocked == true {
changes.Add("state") ext.Owner = ext.PendingOwner
ext.State = Unlocked
} messages = append(messages, Message{*ext.Owner, NewSuccessSignal(*ext.ReqID)})
ext.ReqID = nil
} else {
ctx.Log.Logf("lockable", "%s PARTIAL_LOCK: %d/%d", node.ID, len(ext.Locked), len(ext.Requirements))
}
case AbortingLock:
req_state := ext.Requirements[id]
switch req_state {
case Locking: case Locking:
ext.State = AbortingLock ext.Requirements[id] = Unlocking
ext.Requirements[wait_info.Destination] = Unlocked unlock_signal := NewUnlockSignal()
for id, state := range(ext.Requirements) { ext.Waiting[unlock_signal.Id] = id
if state == Locked { messages = append(messages, Message{id, unlock_signal})
ext.Requirements[id] = Unlocking
lock_signal := NewLockSignal("unlock")
ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", id, lock_signal, 100*time.Millisecond)
messages = append(messages, SendMsg{id, lock_signal})
ctx.Log.Logf("lockable", "sent abort unlock to %s from %s", id, node.ID)
}
}
case Unlocking: case Unlocking:
ext.Requirements[wait_info.Destination] = Locked ext.Requirements[id] = Unlocked
all_returned := true ext.Unlocked[id] = nil
for _, state := range(ext.Requirements) { delete(ext.Locked, id)
if state == Unlocking {
all_returned = false unlocked := 0
break for _, req_state := range(ext.Requirements) {
switch req_state {
case Unlocked:
unlocked += 1
} }
} }
if all_returned == true {
ext.State = Locked if unlocked == len(ext.Requirements) {
changes = append(changes, "state", "pending_owner", "req_id")
messages = append(messages, Message{*ext.PendingOwner, NewErrorSignal(*ext.ReqID, "not_unlocked: %s", ext.State)})
ext.State = Unlocked
ext.ReqID = nil
ext.PendingOwner = nil
} }
} }
} else {
ctx.Log.Logf("lockable", "%s timed out", wait_info.Destination)
case Unlocking:
ext.Requirements[id] = Unlocked
ext.Unlocked[id] = Unlocked
delete(ext.Locked, id)
if len(ext.Unlocked) == len(ext.Requirements) {
changes = append(changes, "state", "owner", "req_id")
messages = append(messages, Message{*ext.Owner, NewSuccessSignal(*ext.ReqID)})
ext.State = Unlocked
ext.ReqID = nil
ext.Owner = nil
}
} }
} }
return messages, changes return messages, changes
} }
// LockableExts process status signals by forwarding them to its owner func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]Message, Changes) {
// LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state var messages []Message = nil
func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) { var changes Changes = nil
var messages []SendMsg = nil
var changes = Changes{}
switch sig := signal.(type) { switch sig := signal.(type) {
case *StatusSignal: case *StatusSignal:
// Forward StatusSignals up to the owner (unless that would be a cycle)
if ext.Owner != nil { if ext.Owner != nil {
if *ext.Owner != node.ID { if *ext.Owner != node.ID {
messages = append(messages, SendMsg{*ext.Owner, signal}) messages = append(messages, Message{*ext.Owner, signal})
} }
} }
case *LinkSignal: case *LinkSignal:
messages, changes = ext.HandleLinkSignal(ctx, node, source, sig) messages, changes = ext.HandleLinkSignal(ctx, node, source, sig)
case *LockSignal: case *LockSignal:
messages, changes = ext.HandleLockSignal(ctx, node, source, sig) messages, changes = ext.HandleLockSignal(ctx, node, source, sig)
case *UnlockSignal:
messages, changes = ext.HandleUnlockSignal(ctx, node, source, sig)
case *ErrorSignal: case *ErrorSignal:
messages, changes = ext.HandleErrorSignal(ctx, node, source, sig) messages, changes = ext.HandleErrorSignal(ctx, node, source, sig)
case *SuccessSignal: case *SuccessSignal:
messages, changes = ext.HandleSuccessSignal(ctx, node, source, sig) messages, changes = ext.HandleSuccessSignal(ctx, node, source, sig)
case *TimeoutSignal:
messages, changes = ext.HandleTimeoutSignal(ctx, node, source, sig)
default:
} }
return messages, changes return messages, changes
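
Process forwards a StatusSignal one hop to ext.Owner, skipping the send when the node owns itself, so a status update climbs an ownership chain node by node until it reaches a self-owned root. A toy sketch of that hop-by-hop rule with a made-up owner map (not the real node graph):

package main

import "fmt"

func main() {
	// owner[x] is the node currently holding x's lock; a root owns itself.
	owner := map[string]string{"leaf": "mid", "mid": "root", "root": "root"}

	// Apply the same forwarding rule at every hop, as each Process call does.
	node := "leaf"
	for {
		next, has_owner := owner[node]
		if !has_owner || next == node {
			break // no owner, or forwarding would be a cycle
		}
		fmt.Printf("forward status from %s to %s\n", node, next)
		node = next
	}
}
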

@ -10,16 +10,16 @@ func TestLink(t *testing.T) {
l2_listener := NewListenerExt(10) l2_listener := NewListenerExt(10)
l2, err := NewNode(ctx, nil, "Lockable", 10, l2_listener, NewLockableExt(nil)) l2, err := ctx.NewNode(nil, "LockableNode", l2_listener, NewLockableExt(nil))
fatalErr(t, err) fatalErr(t, err)
l1_lockable := NewLockableExt(nil) l1_lockable := NewLockableExt(nil)
l1_listener := NewListenerExt(10) l1_listener := NewListenerExt(10)
l1, err := NewNode(ctx, nil, "Lockable", 10, l1_listener, l1_lockable) l1, err := ctx.NewNode(nil, "LockableNode", l1_listener, l1_lockable)
fatalErr(t, err) fatalErr(t, err)
link_signal := NewLinkSignal("add", l2.ID) link_signal := NewLinkSignal("add", l2.ID)
msgs := []SendMsg{{l1.ID, link_signal}} msgs := []Message{{l1.ID, link_signal}}
err = ctx.Send(l1, msgs) err = ctx.Send(l1, msgs)
fatalErr(t, err) fatalErr(t, err)
@ -34,7 +34,7 @@ func TestLink(t *testing.T) {
} }
unlink_signal := NewLinkSignal("remove", l2.ID) unlink_signal := NewLinkSignal("remove", l2.ID)
msgs = []SendMsg{{l1.ID, unlink_signal}} msgs = []Message{{l1.ID, unlink_signal}}
err = ctx.Send(l1, msgs) err = ctx.Send(l1, msgs)
fatalErr(t, err) fatalErr(t, err)
@ -42,24 +42,40 @@ func TestLink(t *testing.T) {
fatalErr(t, err) fatalErr(t, err)
} }
func Test10Lock(t *testing.T) {
testLockN(t, 10)
}
func Test100Lock(t *testing.T) {
testLockN(t, 100)
}
func Test1000Lock(t *testing.T) { func Test1000Lock(t *testing.T) {
testLockN(t, 1000)
}
func Test10000Lock(t *testing.T) {
testLockN(t, 10000)
}
func testLockN(t *testing.T, n int) {
ctx := logTestContext(t, []string{"test"}) ctx := logTestContext(t, []string{"test"})
NewLockable := func()(*Node) { NewLockable := func()(*Node) {
l, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt(nil)) l, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil))
fatalErr(t, err) fatalErr(t, err)
return l return l
} }
reqs := make([]NodeID, 1000) reqs := make([]NodeID, n)
for i := range(reqs) { for i := range(reqs) {
new_lockable := NewLockable() new_lockable := NewLockable()
reqs[i] = new_lockable.ID reqs[i] = new_lockable.ID
} }
ctx.Log.Logf("test", "CREATED_1000") ctx.Log.Logf("test", "CREATED_%d", n)
listener := NewListenerExt(5000) listener := NewListenerExt(50000)
node, err := NewNode(ctx, nil, "Lockable", 5000, listener, NewLockableExt(reqs)) node, err := ctx.NewNode(nil, "LockableNode", listener, NewLockableExt(reqs))
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "CREATED_LISTENER") ctx.Log.Logf("test", "CREATED_LISTENER")
@ -75,15 +91,15 @@ func Test1000Lock(t *testing.T) {
t.Fatalf("Unexpected response to lock - %s", resp) t.Fatalf("Unexpected response to lock - %s", resp)
} }
ctx.Log.Logf("test", "LOCKED_1000") ctx.Log.Logf("test", "LOCKED_%d", n)
} }
func TestLock(t *testing.T) { func TestLock(t *testing.T) {
ctx := logTestContext(t, []string{"test", "lockable"}) ctx := logTestContext(t, []string{"test", "lockable"})
NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) { NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) {
listener := NewListenerExt(1000) listener := NewListenerExt(10000)
l, err := NewNode(ctx, nil, "Lockable", 10, listener, NewLockableExt(reqs)) l, err := ctx.NewNode(nil, "LockableNode", listener, NewLockableExt(reqs))
fatalErr(t, err) fatalErr(t, err)
return l, listener return l, listener
} }
@ -102,25 +118,31 @@ func TestLock(t *testing.T) {
ctx.Log.Logf("test", "l4: %s", l4.ID) ctx.Log.Logf("test", "l4: %s", l4.ID)
ctx.Log.Logf("test", "l5: %s", l5.ID) ctx.Log.Logf("test", "l5: %s", l5.ID)
ctx.Log.Logf("test", "locking l0")
id_1, err := LockLockable(ctx, l0) id_1, err := LockLockable(ctx, l0)
ctx.Log.Logf("test", "ID_1: %s", id_1)
fatalErr(t, err) fatalErr(t, err)
_, _, err = WaitForResponse(l0_listener.Chan, time.Millisecond*10, id_1) response, _, err := WaitForResponse(l0_listener.Chan, time.Millisecond*10, id_1)
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "l0 lock: %+v", response)
ctx.Log.Logf("test", "locking l1")
id_2, err := LockLockable(ctx, l1) id_2, err := LockLockable(ctx, l1)
fatalErr(t, err) fatalErr(t, err)
_, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*100, id_2) response, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10000, id_2)
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "l1 lock: %+v", response)
ctx.Log.Logf("test", "unlocking l0")
id_3, err := UnlockLockable(ctx, l0) id_3, err := UnlockLockable(ctx, l0)
fatalErr(t, err) fatalErr(t, err)
_, _, err = WaitForResponse(l0_listener.Chan, time.Millisecond*10, id_3) response, _, err = WaitForResponse(l0_listener.Chan, time.Millisecond*10, id_3)
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "l0 unlock: %+v", response)
ctx.Log.Logf("test", "locking l1")
id_4, err := LockLockable(ctx, l1) id_4, err := LockLockable(ctx, l1)
fatalErr(t, err) fatalErr(t, err)
response, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, id_4)
_, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, id_4)
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "l1 lock: %+v", response)
} }

@ -50,7 +50,7 @@ func (logger * ConsoleLogger) SetComponents(components []string) error {
return false return false
} }
for c, _ := range(logger.loggers) { for c := range(logger.loggers) {
if component_enabled(c) == false { if component_enabled(c) == false {
delete(logger.loggers, c) delete(logger.loggers, c)
} }

@ -1,11 +1,68 @@
package graphvent package graphvent
type SendMsg struct { type Message struct {
Dest NodeID Node NodeID
Signal Signal Signal Signal
} }
type RecvMsg struct { type MessageQueue struct {
Source NodeID out chan<- Message
Signal Signal in <-chan Message
buffer []Message
write_cursor int
read_cursor int
}
func (queue *MessageQueue) ProcessIncoming(message Message) {
if (queue.write_cursor + 1) == queue.read_cursor || ((queue.write_cursor + 1) == len(queue.buffer) && queue.read_cursor == 0) {
new_buffer := make([]Message, len(queue.buffer) * 2)
copy(new_buffer, queue.buffer[queue.read_cursor:])
first_chunk := len(queue.buffer) - queue.read_cursor
copy(new_buffer[first_chunk:], queue.buffer[0:queue.write_cursor])
queue.write_cursor = len(queue.buffer) - 1
queue.read_cursor = 0
queue.buffer = new_buffer
}
queue.buffer[queue.write_cursor] = message
queue.write_cursor += 1
if queue.write_cursor >= len(queue.buffer) {
queue.write_cursor = 0
}
}
func NewMessageQueue(initial int) (chan<- Message, <-chan Message) {
in := make(chan Message, 0)
out := make(chan Message, 0)
queue := MessageQueue{
out: out,
in: in,
buffer: make([]Message, initial),
write_cursor: 0,
read_cursor: 0,
}
go func(queue *MessageQueue) {
for true {
if queue.write_cursor != queue.read_cursor {
select {
case incoming := <-queue.in:
queue.ProcessIncoming(incoming)
case queue.out <- queue.buffer[queue.read_cursor]:
queue.read_cursor += 1
if queue.read_cursor >= len(queue.buffer) {
queue.read_cursor = 0
}
}
} else {
message := <-queue.in
queue.ProcessIncoming(message)
}
}
}(&queue)
return in, out
} }
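
A short usage sketch of the queue above: the goroutine started by NewMessageQueue grows its ring buffer on demand, so writers are decoupled from readers instead of blocking on a fixed channel capacity. The helper name and the burst size are illustrative; the types are the ones defined in this file:

package graphvent

import "github.com/google/uuid"

// drainBurst is a hypothetical helper showing the in/out channel pair returned
// by NewMessageQueue: push a burst of messages, then read them back in arrival order.
func drainBurst() []Message {
	in, out := NewMessageQueue(4) // initial ring size, doubled whenever it fills

	go func() {
		for i := 0; i < 100; i++ {
			in <- Message{Node: NodeID(uuid.New()), Signal: nil}
		}
	}()

	received := make([]Message, 0, 100)
	for len(received) < 100 {
		received = append(received, <-out)
	}
	return received
}
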

@ -0,0 +1,35 @@
package graphvent
import (
"encoding/binary"
"testing"
)
func sendBatch(start, end uint64, in chan<- Message) {
for i := start; i <= end; i++ {
var id NodeID
binary.BigEndian.PutUint64(id[:], i)
in <- Message{id, nil}
}
}
func TestMessageQueue(t *testing.T) {
in, out := NewMessageQueue(10)
for i := uint64(0); i < 1000; i++ {
go sendBatch(1000*i, (1000*(i+1))-1, in)
}
seen := map[NodeID]any{}
for i := uint64(0); i < 1000*1000; i++ {
read := <-out
_, already_seen := seen[read.Node]
if already_seen {
t.Fatalf("Signal %d had duplicate NodeID %s", i, read.Node)
} else {
seen[read.Node] = nil
}
}
t.Logf("Processed 1M signals through queue")
}

@ -1,17 +1,15 @@
package graphvent package graphvent
import ( import (
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/sha512"
"crypto/sha512" "fmt"
"encoding/binary" "reflect"
"fmt" "sync/atomic"
"reflect" "time"
"sync/atomic"
"time" _ "github.com/dgraph-io/badger/v3"
"github.com/google/uuid"
_ "github.com/dgraph-io/badger/v3"
"github.com/google/uuid"
) )
var ( var (
@ -25,6 +23,10 @@ type NodeID uuid.UUID
func (id NodeID) MarshalBinary() ([]byte, error) { func (id NodeID) MarshalBinary() ([]byte, error) {
return (uuid.UUID)(id).MarshalBinary() return (uuid.UUID)(id).MarshalBinary()
} }
func (id *NodeID) UnmarshalBinary(data []byte) error {
return (*uuid.UUID)(id).UnmarshalBinary(data)
}
func (id NodeID) String() string { func (id NodeID) String() string {
return (uuid.UUID)(id).String() return (uuid.UUID)(id).String()
} }
@ -67,25 +69,27 @@ func (q QueuedSignal) String() string {
return fmt.Sprintf("%+v@%s", reflect.TypeOf(q.Signal), q.Time) return fmt.Sprintf("%+v@%s", reflect.TypeOf(q.Signal), q.Time)
} }
// Default message channel size for nodes type WaitMap map[uuid.UUID]NodeID
const NODE_INITIAL_QUEUE_SIZE = 2
// Nodes represent a group of extensions that can be collectively addressed // Nodes represent a group of extensions that can be collectively addressed
type Node struct { type Node struct {
Key ed25519.PrivateKey `gv:"key"` Key ed25519.PrivateKey `gv:"key"`
ID NodeID ID NodeID
Type NodeType `gv:"type"` Type NodeType `gv:"type"`
// TODO: move each extension to it's own db key, and extend changes to notify which extension was changed
Extensions map[ExtType]Extension Extensions map[ExtType]Extension
// Channel for this node to receive messages from the Context // Channel for this node to receive messages from the Context
MsgChan chan RecvMsg SendChan chan<- Message
// Size of MsgChan RecvChan <-chan Message
BufferSize uint32 `gv:"buffer_size"`
// Channel for this node to process delayed signals // Channel for this node to process delayed signals
TimeoutChan <-chan time.Time TimeoutChan <-chan time.Time
Active atomic.Bool Active atomic.Bool
// TODO: enhance WriteNode to write SignalQueue to a different key, and use writeSignalQueue to decide whether or not to update it
writeSignalQueue bool writeSignalQueue bool
SignalQueue []QueuedSignal SignalQueue []QueuedSignal
NextSignal *QueuedSignal NextSignal *QueuedSignal
@ -97,59 +101,11 @@ func (node *Node) PostDeserialize(ctx *Context) error {
public := node.Key.Public().(ed25519.PublicKey) public := node.Key.Public().(ed25519.PublicKey)
node.ID = KeyID(public) node.ID = KeyID(public)
node.MsgChan = make(chan RecvMsg, node.BufferSize) node.SendChan, node.RecvChan = NewMessageQueue(NODE_INITIAL_QUEUE_SIZE)
return nil return nil
} }
type WaitReason string
type WaitInfo struct {
Destination NodeID `gv:"destination" node:"Base"`
Timeout uuid.UUID `gv:"timeout"`
Reason WaitReason `gv:"reason"`
}
type WaitMap map[uuid.UUID]WaitInfo
// Removes a signal from the wait_map and dequeue the associated timeout signal
// Returns the data, and whether or not the ID was found in the wait_map
func (node *Node) ProcessResponse(wait_map WaitMap, response ResponseSignal) (WaitInfo, bool) {
wait_info, is_processed := wait_map[response.ResponseID()]
if is_processed == true {
delete(wait_map, response.ResponseID())
if response.ID() != wait_info.Timeout {
node.DequeueSignal(wait_info.Timeout)
}
return wait_info, true
}
return WaitInfo{}, false
}
func (node *Node) NewTimeout(reason WaitReason, dest NodeID, timeout time.Duration) (WaitInfo, uuid.UUID) {
id := uuid.New()
timeout_signal := NewTimeoutSignal(id)
node.QueueSignal(time.Now().Add(timeout), timeout_signal)
return WaitInfo{
Destination: dest,
Timeout: timeout_signal.Id,
Reason: reason,
}, id
}
// Creates a timeout signal for signal, queues it for the node at the timeout, and returns the WaitInfo
func (node *Node) QueueTimeout(reason WaitReason, dest NodeID, signal Signal, timeout time.Duration) WaitInfo {
timeout_signal := NewTimeoutSignal(signal.ID())
node.QueueSignal(time.Now().Add(timeout), timeout_signal)
return WaitInfo{
Destination: dest,
Timeout: timeout_signal.Id,
Reason: reason,
}
}
func (node *Node) QueueSignal(time time.Time, signal Signal) { func (node *Node) QueueSignal(time time.Time, signal Signal) {
node.SignalQueue = append(node.SignalQueue, QueuedSignal{signal, time}) node.SignalQueue = append(node.SignalQueue, QueuedSignal{signal, time})
node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue) node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue)
@ -187,18 +143,23 @@ func SoonestSignal(signals []QueuedSignal) (*QueuedSignal, <-chan time.Time) {
} }
if soonest_signal != nil { if soonest_signal != nil {
return soonest_signal, time.After(time.Until(soonest_signal.Time)) if time.Now().Compare(soonest_time) == -1 {
return soonest_signal, time.After(time.Until(soonest_signal.Time))
} else {
c := make(chan time.Time, 1)
c <- soonest_time
return soonest_signal, c
}
} else { } else {
return nil, nil return nil, nil
} }
} }
func runNode(ctx *Context, node *Node) { func runNode(ctx *Context, node *Node, status chan string, control chan string) {
ctx.Log.Logf("node", "RUN_START: %s", node.ID) ctx.Log.Logf("node", "RUN_START: %s", node.ID)
err := nodeLoop(ctx, node) err := nodeLoop(ctx, node, status, control)
if err != nil { if err != nil {
ctx.Log.Logf("node", "%s runNode err %s", node.ID, err) ctx.Log.Logf("node", "%s runNode err %s", node.ID, err)
panic(err)
} }
ctx.Log.Logf("node", "RUN_STOP: %s", node.ID) ctx.Log.Logf("node", "RUN_STOP: %s", node.ID)
} }
@ -214,43 +175,74 @@ func (err StringError) MarshalBinary() ([]byte, error) {
return []byte(string(err)), nil return []byte(string(err)), nil
} }
func (node *Node) ReadFields(ctx *Context, reqs map[ExtType][]string)map[ExtType]map[string]any { func (node *Node) ReadFields(ctx *Context, fields []string)map[string]any {
ctx.Log.Logf("read_field", "Reading %+v on %+v", reqs, node.ID) ctx.Log.Logf("read_field", "Reading %+v on %+v", fields, node.ID)
exts := map[ExtType]map[string]any{} values := map[string]any{}
for ext_type, field_reqs := range(reqs) {
ext_info, ext_known := ctx.Extensions[ext_type] node_info := ctx.NodeTypes[node.Type]
if ext_known {
fields := map[string]any{} for _, field_name := range(fields) {
for _, req := range(field_reqs) { field_info, mapped := node_info.Fields[field_name]
ext, exists := node.Extensions[ext_type] if mapped {
if exists == false { ext := node.Extensions[field_info.Extension]
fields[req] = fmt.Errorf("%+v does not have %+v extension", node.ID, ext_type) values[field_name] = reflect.ValueOf(ext).Elem().FieldByIndex(field_info.Index).Interface()
} else { } else {
fields[req] = reflect.ValueOf(ext).Elem().FieldByIndex(ext_info.Fields[req].Index).Interface() values[field_name] = fmt.Errorf("NodeType %s has no field %s", node.Type, field_name)
}
}
exts[ext_type] = fields
} }
} }
return exts
return values
} }
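
ReadFields now resolves each requested name through ctx.NodeTypes[node.Type].Fields, whose entries say which extension owns the field and give the reflect index path into that extension's struct. A self-contained sketch of that lookup with simplified stand-ins for the registry (the field and extension names here are made up):

package main

import (
	"fmt"
	"reflect"
)

// Stand-ins for the per-NodeType field table consulted by ReadFields.
type FieldInfo struct {
	Extension string // which extension struct holds the field
	Index     []int  // reflect index path into that struct
}

type ListenerStub struct {
	Buffer []string
}

func main() {
	extensions := map[string]any{"Listener": &ListenerStub{Buffer: []string{"a", "b"}}}
	fields := map[string]FieldInfo{"buffer": {Extension: "Listener", Index: []int{0}}}

	values := map[string]any{}
	for _, field_name := range []string{"buffer", "missing"} {
		field_info, mapped := fields[field_name]
		if mapped {
			ext := extensions[field_info.Extension]
			values[field_name] = reflect.ValueOf(ext).Elem().FieldByIndex(field_info.Index).Interface()
		} else {
			values[field_name] = fmt.Errorf("no field %s", field_name)
		}
	}
	fmt.Println(values)
}
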
// Main Loop for nodes // Main Loop for nodes
func nodeLoop(ctx *Context, node *Node) error { func nodeLoop(ctx *Context, node *Node, status chan string, control chan string) error {
started := node.Active.CompareAndSwap(false, true) is_started := node.Active.CompareAndSwap(false, true)
if started == false { if is_started == false {
return fmt.Errorf("%s is already started, will not start again", node.ID) return fmt.Errorf("%s is already started, will not start again", node.ID)
} else {
ctx.Log.Logf("node", "Set %s active", node.ID)
}
ctx.Log.Logf("node_ext", "Loading extensions for %s", node.ID)
for _, extension := range(node.Extensions) {
ctx.Log.Logf("node_ext", "Loading extension %s for %s", reflect.TypeOf(extension), node.ID)
err := extension.Load(ctx, node)
if err != nil {
ctx.Log.Logf("node_ext", "Failed to load extension %s on node %s", reflect.TypeOf(extension), node.ID)
node.Active.Store(false)
return err
} else {
ctx.Log.Logf("node_ext", "Loaded extension %s on node %s", reflect.TypeOf(extension), node.ID)
}
} }
run := true ctx.Log.Logf("node_ext", "Loaded extensions for %s", node.ID)
for run == true {
status <- "active"
running := true
for running {
var signal Signal var signal Signal
var source NodeID var source NodeID
select {
case msg := <- node.MsgChan:
signal = msg.Signal
source = msg.Source
select {
case command := <-control:
switch command {
case "stop":
running = false
case "pause":
status <- "paused"
command := <- control
switch command {
case "resume":
status <- "resumed"
case "stop":
running = false
}
default:
ctx.Log.Logf("node", "Unknown control command %s", command)
}
case <-node.TimeoutChan: case <-node.TimeoutChan:
signal = node.NextSignal.Signal signal = node.NextSignal.Signal
source = node.ID source = node.ID
@ -279,15 +271,17 @@ func nodeLoop(ctx *Context, node *Node) error {
} else { } else {
ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL: %s@%s", node.ID, signal, t, node.NextSignal, node.NextSignal.Time) ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL: %s@%s", node.ID, signal, t, node.NextSignal, node.NextSignal.Time)
} }
} case msg := <- node.RecvChan:
signal = msg.Signal
source = msg.Node
ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) }
switch sig := signal.(type) { switch sig := signal.(type) {
case *ReadSignal: case *ReadSignal:
result := node.ReadFields(ctx, sig.Extensions) result := node.ReadFields(ctx, sig.Fields)
msgs := []SendMsg{} msgs := []Message{}
msgs = append(msgs, SendMsg{source, NewReadResultSignal(sig.ID(), node.ID, node.Type, result)}) msgs = append(msgs, Message{source, NewReadResultSignal(sig.ID(), node.ID, node.Type, result)})
ctx.Send(node, msgs) ctx.Send(node, msgs)
default: default:
@ -303,31 +297,45 @@ func nodeLoop(ctx *Context, node *Node) error {
if stopped == false { if stopped == false {
panic("BAD_STATE: stopping already stopped node") panic("BAD_STATE: stopping already stopped node")
} }
for _, extension := range(node.Extensions) {
extension.Unload(ctx, node)
}
status <- "stopped"
return nil return nil
} }
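
nodeLoop now reports its state over status ("active", "paused", "resumed", "stopped") and takes commands over control ("pause", "resume", "stop"). A sketch of the supervising side of that handshake; the helper is hypothetical and assumes the two channels are the ones passed to runNode:

package graphvent

// superviseNode is a hypothetical driver for the status/control protocol used
// by nodeLoop above: pause the loop, resume it, then shut it down cleanly.
func superviseNode(status chan string, control chan string) {
	if state := <-status; state != "active" { // nodeLoop sends "active" once started
		return
	}

	control <- "pause"  // loop answers "paused" and waits for the next command
	<-status
	control <- "resume" // loop answers "resumed" and keeps processing signals
	<-status
	control <- "stop"   // loop exits, unloads extensions, then sends "stopped"
	<-status
}
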
func (node *Node) Unload(ctx *Context) error { func (node *Node) QueueChanges(ctx *Context, changes map[ExtType]Changes) error {
if node.Active.Load() { node_info, exists := ctx.NodeTypes[node.Type]
for _, extension := range(node.Extensions) { if exists == false {
extension.Unload(ctx, node) return fmt.Errorf("Node type not in context, can't map changes to field names")
} else {
fields := []string{}
for ext_type, ext_changes := range(changes) {
ext_map, ext_mapped := node_info.ReverseFields[ext_type]
if ext_mapped {
for _, ext_tag := range(ext_changes) {
field_name, tag_mapped := ext_map[ext_tag]
if tag_mapped {
fields = append(fields, field_name)
}
}
}
}
ctx.Log.Logf("changes", "Changes to queue from %+v: %+v", node_info.ReverseFields, fields)
if len(fields) > 0 {
node.QueueSignal(time.Time{}, NewStatusSignal(node.ID, fields))
} }
return nil return nil
} else {
return fmt.Errorf("Node not active")
} }
} }
func (node *Node) QueueChanges(ctx *Context, changes map[ExtType]Changes) error {
node.QueueSignal(time.Now(), NewStatusSignal(node.ID, changes))
return nil
}
func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error { func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error {
ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, signal) messages := []Message{}
messages := []SendMsg{}
changes := map[ExtType]Changes{} changes := map[ExtType]Changes{}
for ext_type, ext := range(node.Extensions) { for ext_type, ext := range(node.Extensions) {
ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type)
ext_messages, ext_changes := ext.Process(ctx, node, source, signal) ext_messages, ext_changes := ext.Process(ctx, node, source, signal)
if len(ext_messages) != 0 { if len(ext_messages) != 0 {
messages = append(messages, ext_messages...) messages = append(messages, ext_messages...)
@ -336,7 +344,6 @@ func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error {
changes[ext_type] = ext_changes changes[ext_type] = ext_changes
} }
} }
ctx.Log.Logf("changes", "Changes for %s after %+v - %+v", node.ID, reflect.TypeOf(signal), changes)
if len(messages) != 0 { if len(messages) != 0 {
send_err := ctx.Send(node, messages) send_err := ctx.Send(node, messages)
@ -346,11 +353,7 @@ func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error {
} }
if len(changes) != 0 { if len(changes) != 0 {
write_err := WriteNodeChanges(ctx, node, changes) ctx.Log.Logf("changes", "Changes to %s from %+v: %+v", node.ID, signal, changes)
if write_err != nil {
return write_err
}
status_err := node.QueueChanges(ctx, changes) status_err := node.QueueChanges(ctx, changes)
if status_err != nil { if status_err != nil {
return status_err return status_err
@ -396,85 +399,3 @@ func KeyID(pub ed25519.PublicKey) NodeID {
id := uuid.NewHash(sha512.New(), ZeroUUID, pub, 3) id := uuid.NewHash(sha512.New(), ZeroUUID, pub, 3)
return NodeID(id) return NodeID(id)
} }
// Create a new node in memory and start its event loop
func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size uint32, extensions ...Extension) (*Node, error) {
node_type, known_type := ctx.NodeTypes[type_name]
if known_type == false {
return nil, fmt.Errorf("%s is not a known node type", type_name)
}
var err error
var public ed25519.PublicKey
if key == nil {
public, key, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
} else {
public = key.Public().(ed25519.PublicKey)
}
id := KeyID(public)
_, exists := ctx.Node(id)
if exists == true {
return nil, fmt.Errorf("Attempted to create an existing node")
}
ext_map := map[ExtType]Extension{}
for _, ext := range(extensions) {
ext_type, exists := ctx.ExtensionTypes[reflect.TypeOf(ext).Elem()]
if exists == false {
return nil, fmt.Errorf(fmt.Sprintf("%+v is not a known Extension", reflect.TypeOf(ext)))
}
_, exists = ext_map[ext_type.ExtType]
if exists == true {
return nil, fmt.Errorf("Cannot add the same extension to a node twice")
}
ext_map[ext_type.ExtType] = ext
}
for _, required_ext := range(node_type.Extensions) {
_, exists := ext_map[required_ext]
if exists == false {
return nil, fmt.Errorf(fmt.Sprintf("%+v requires %+v", node_type, required_ext))
}
}
node := &Node{
Key: key,
ID: id,
Type: node_type.NodeType,
Extensions: ext_map,
MsgChan: make(chan RecvMsg, buffer_size),
BufferSize: buffer_size,
SignalQueue: []QueuedSignal{},
writeSignalQueue: false,
}
err = WriteNodeInit(ctx, node)
if err != nil {
return nil, err
}
// Load each extension before starting the main loop
for _, extension := range(node.Extensions) {
err := extension.Load(ctx, node)
if err != nil {
return nil, err
}
}
ctx.AddNode(id, node)
go runNode(ctx, node)
return node, nil
}
var extension_suffix = []byte{0xEE, 0xFF, 0xEE, 0xFF}
var signal_queue_suffix = []byte{0xAB, 0xBA, 0xAB, 0xBA}
func ExtTypeSuffix(ext_type ExtType) []byte {
ret := make([]byte, 12)
copy(ret[0:4], extension_suffix)
binary.BigEndian.PutUint64(ret[4:], uint64(ext_type))
return ret
}

@ -5,29 +5,19 @@ import (
"time" "time"
"crypto/rand" "crypto/rand"
"crypto/ed25519" "crypto/ed25519"
"slices"
) )
func TestNodeDB(t *testing.T) { func TestNodeDB(t *testing.T) {
ctx := logTestContext(t, []string{"node", "db"}) ctx := logTestContext(t, []string{"test", "node", "db"})
node_listener := NewListenerExt(10) node_listener := NewListenerExt(10)
node, err := NewNode(ctx, nil, "Base", 10, NewLockableExt(nil), node_listener) node, err := ctx.NewNode(nil, "Node", NewLockableExt(nil), node_listener)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(node_listener.Chan, 10*time.Millisecond, func(sig *StatusSignal) bool { err = ctx.Stop()
gql_changes, has_gql := sig.Changes[ExtTypeFor[GQLExt]()]
if has_gql == true {
return slices.Contains(gql_changes, "state") && sig.Source == node.ID
}
return false
})
err = ctx.Unload(node.ID)
fatalErr(t, err) fatalErr(t, err)
ctx.nodeMap = map[NodeID]*Node{} _, err = ctx.GetNode(node.ID)
_, err = ctx.getNode(node.ID)
fatalErr(t, err) fatalErr(t, err)
} }
@ -46,16 +36,14 @@ func TestNodeRead(t *testing.T) {
ctx.Log.Logf("test", "N2: %s", n2_id) ctx.Log.Logf("test", "N2: %s", n2_id)
n2_listener := NewListenerExt(10) n2_listener := NewListenerExt(10)
n2, err := NewNode(ctx, n2_key, "Base", 10, n2_listener) n2, err := ctx.NewNode(n2_key, "Node", n2_listener)
fatalErr(t, err) fatalErr(t, err)
n1, err := NewNode(ctx, n1_key, "Base", 10, NewListenerExt(10)) n1, err := ctx.NewNode(n1_key, "Node", NewListenerExt(10))
fatalErr(t, err) fatalErr(t, err)
read_sig := NewReadSignal(map[ExtType][]string{ read_sig := NewReadSignal([]string{"buffer"})
ExtTypeFor[ListenerExt](): {"buffer"}, msgs := []Message{{n1.ID, read_sig}}
})
msgs := []SendMsg{{n1.ID, read_sig}}
err = ctx.Send(n2, msgs) err = ctx.Send(n2, msgs)
fatalErr(t, err) fatalErr(t, err)

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"math" "math"
"slices"
) )
type SerializedType uint64 type SerializedType uint64
@ -39,14 +38,8 @@ func (t FieldTag) String() string {
return fmt.Sprintf("0x%x", uint64(t)) return fmt.Sprintf("0x%x", uint64(t))
} }
func NodeTypeFor(extensions []ExtType) NodeType { func NodeTypeFor(name string) NodeType {
digest := []byte("GRAPHVENT_NODE - ") digest := []byte("GRAPHVENT_NODE - " + name)
slices.Sort(extensions)
for _, ext := range(extensions) {
digest = binary.BigEndian.AppendUint64(digest, uint64(ext))
}
hash := sha512.Sum512(digest) hash := sha512.Sum512(digest)
return NodeType(binary.BigEndian.Uint64(hash[0:8])) return NodeType(binary.BigEndian.Uint64(hash[0:8]))
@ -66,6 +59,10 @@ func ExtTypeFor[E any, T interface { *E; Extension}]() ExtType {
return ExtType(SerializedTypeFor[E]()) return ExtType(SerializedTypeFor[E]())
} }
func ExtTypeOf(t reflect.Type) ExtType {
return ExtType(SerializeType(t.Elem()))
}
func SignalTypeFor[S Signal]() SignalType { func SignalTypeFor[S Signal]() SignalType {
return SignalType(SerializedTypeFor[S]()) return SignalType(SerializedTypeFor[S]())
} }
@ -80,49 +77,59 @@ func GetFieldTag(tag string) FieldTag {
return FieldTag(Hash("GRAPHVENT_FIELD_TAG", tag)) return FieldTag(Hash("GRAPHVENT_FIELD_TAG", tag))
} }
func TypeStack(ctx *Context, t reflect.Type) ([]byte, error) { func TypeStack(ctx *Context, t reflect.Type, data []byte) (int, error) {
info, registered := ctx.TypeTypes[t] info, registered := ctx.Types[t]
if registered { if registered {
return binary.BigEndian.AppendUint64(nil, uint64(info.Serialized)), nil binary.BigEndian.PutUint64(data, uint64(info.Serialized))
return 8, nil
} else { } else {
switch t.Kind() { switch t.Kind() {
case reflect.Map: case reflect.Map:
key_stack, err := TypeStack(ctx, t.Key()) binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Map)))
key_written, err := TypeStack(ctx, t.Key(), data[8:])
if err != nil { if err != nil {
return nil, err return 0, err
} }
elem_stack, err := TypeStack(ctx, t.Elem()) elem_written, err := TypeStack(ctx, t.Elem(), data[8 + key_written:])
if err != nil { if err != nil {
return nil, err return 0, err
} }
return append(binary.BigEndian.AppendUint64(nil, uint64(SerializeType(reflect.Map))), append(key_stack, elem_stack...)...), nil return 8 + key_written + elem_written, nil
case reflect.Pointer: case reflect.Pointer:
elem_stack, err := TypeStack(ctx, t.Elem()) binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Pointer)))
elem_written, err := TypeStack(ctx, t.Elem(), data[8:])
if err != nil { if err != nil {
return nil, err return 0, err
} }
return append(binary.BigEndian.AppendUint64(nil, uint64(SerializeType(reflect.Pointer))), elem_stack...), nil return 8 + elem_written, nil
case reflect.Slice: case reflect.Slice:
elem_stack, err := TypeStack(ctx, t.Elem()) binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Slice)))
elem_written, err := TypeStack(ctx, t.Elem(), data[8:])
if err != nil { if err != nil {
return nil, err return 0, err
} }
return append(binary.BigEndian.AppendUint64(nil, uint64(SerializeType(reflect.Slice))), elem_stack...), nil return 8 + elem_written, nil
case reflect.Array: case reflect.Array:
elem_stack, err := TypeStack(ctx, t.Elem()) binary.BigEndian.PutUint64(data, uint64(SerializeType(reflect.Array)))
binary.BigEndian.PutUint64(data[8:], uint64(t.Len()))
elem_written, err := TypeStack(ctx, t.Elem(), data[16:])
if err != nil { if err != nil {
return nil, err return 0, err
} }
stack := binary.BigEndian.AppendUint64(nil, uint64(SerializeType(reflect.Array))) return 16 + elem_written, nil
stack = binary.BigEndian.AppendUint64(stack, uint64(t.Len()))
return append(stack, elem_stack...), nil
default: default:
return nil, fmt.Errorf("Hit %s, which is not a registered type", t.String()) return 0, fmt.Errorf("Hit %s, which is not a registered type", t.String())
} }
} }
} }
@ -131,7 +138,7 @@ func UnwrapStack(ctx *Context, stack []byte) (reflect.Type, []byte, error) {
first_bytes, left := split(stack, 8) first_bytes, left := split(stack, 8)
first := SerializedType(binary.BigEndian.Uint64(first_bytes)) first := SerializedType(binary.BigEndian.Uint64(first_bytes))
info, registered := ctx.TypeMap[first] info, registered := ctx.TypesReverse[first]
if registered { if registered {
return info.Reflect, left, nil return info.Reflect, left, nil
} else { } else {
@ -176,18 +183,18 @@ func UnwrapStack(ctx *Context, stack []byte) (reflect.Type, []byte, error) {
} }
} }
func Serialize[T any](ctx *Context, value T) ([]byte, error) { func Serialize[T any](ctx *Context, value T, data []byte) (int, error) {
return serializeValue(ctx, reflect.ValueOf(&value).Elem()) return SerializeValue(ctx, reflect.ValueOf(&value).Elem(), data)
} }
func Deserialize[T any](ctx *Context, data []byte) (T, error) { func Deserialize[T any](ctx *Context, data []byte) (T, error) {
reflect_type := reflect.TypeFor[T]() reflect_type := reflect.TypeFor[T]()
var zero T var zero T
value, left, err := deserializeValue(ctx, data, reflect_type) value, left, err := DeserializeValue(ctx, data, reflect_type)
if err != nil { if err != nil {
return zero, err return zero, err
} else if len(left) != 0 { } else if len(left) != 0 {
return zero, fmt.Errorf("%d bytes left after deserializing %+v", len(left), value) return zero, fmt.Errorf("%d/%d bytes left after deserializing %+v", len(left), len(data), value)
} else if value.Type() != reflect_type { } else if value.Type() != reflect_type {
return zero, fmt.Errorf("Deserialized type %s does not match %s", value.Type(), reflect_type) return zero, fmt.Errorf("Deserialized type %s does not match %s", value.Type(), reflect_type)
} }
@ -195,156 +202,318 @@ func Deserialize[T any](ctx *Context, data []byte) (T, error) {
return value.Interface().(T), nil return value.Interface().(T), nil
} }
func serializeValue(ctx *Context, value reflect.Value) ([]byte, error) { func SerializedSize(ctx *Context, value reflect.Value) (int, error) {
var serialize SerializeFn = nil var sizefn SerializedSizeFn = nil
info, registered := ctx.TypeTypes[value.Type()] info, registered := ctx.Types[value.Type()]
if registered { if registered {
serialize = info.Serialize sizefn = info.SerializedSize
} }
if serialize == nil { if sizefn == nil {
switch value.Type().Kind() { switch value.Type().Kind() {
case reflect.Bool: case reflect.Bool:
if value.Bool() { return 1, nil
return []byte{0xFF}, nil
} else {
return []byte{0x00}, nil
}
case reflect.Int8: case reflect.Int8:
return []byte{byte(value.Int())}, nil return 1, nil
case reflect.Int16: case reflect.Int16:
return binary.BigEndian.AppendUint16(nil, uint16(value.Int())), nil return 2, nil
case reflect.Int32: case reflect.Int32:
return binary.BigEndian.AppendUint32(nil, uint32(value.Int())), nil return 4, nil
case reflect.Int64: case reflect.Int64:
fallthrough fallthrough
case reflect.Int: case reflect.Int:
return binary.BigEndian.AppendUint64(nil, uint64(value.Int())), nil return 8, nil
case reflect.Uint8: case reflect.Uint8:
return []byte{byte(value.Uint())}, nil return 1, nil
case reflect.Uint16: case reflect.Uint16:
return binary.BigEndian.AppendUint16(nil, uint16(value.Uint())), nil return 2, nil
case reflect.Uint32: case reflect.Uint32:
return binary.BigEndian.AppendUint32(nil, uint32(value.Uint())), nil return 4, nil
case reflect.Uint64: case reflect.Uint64:
fallthrough fallthrough
case reflect.Uint: case reflect.Uint:
return binary.BigEndian.AppendUint64(nil, value.Uint()), nil return 8, nil
case reflect.Float32: case reflect.Float32:
return binary.BigEndian.AppendUint32(nil, math.Float32bits(float32(value.Float()))), nil return 4, nil
case reflect.Float64: case reflect.Float64:
return binary.BigEndian.AppendUint64(nil, math.Float64bits(value.Float())), nil return 8, nil
case reflect.String: case reflect.String:
len_bytes := make([]byte, 8) return 8 + value.Len(), nil
binary.BigEndian.PutUint64(len_bytes, uint64(value.Len()))
return append(len_bytes, []byte(value.String())...), nil
case reflect.Pointer: case reflect.Pointer:
if value.IsNil() { if value.IsNil() {
return []byte{0x00}, nil return 1, nil
} else { } else {
elem, err := serializeValue(ctx, value.Elem()) elem_len, err := SerializedSize(ctx, value.Elem())
if err != nil { if err != nil {
return nil, err return 0, err
} else {
return 1 + elem_len, nil
} }
return append([]byte{0x01}, elem...), nil
} }
case reflect.Slice: case reflect.Slice:
if value.IsNil() { if value.IsNil() {
return []byte{0x00}, nil return 1, nil
} else { } else {
len_bytes := make([]byte, 8) elem_total := 0
binary.BigEndian.PutUint64(len_bytes, uint64(value.Len()))
data := []byte{}
for i := 0; i < value.Len(); i++ { for i := 0; i < value.Len(); i++ {
elem, err := serializeValue(ctx, value.Index(i)) elem_len, err := SerializedSize(ctx, value.Index(i))
if err != nil { if err != nil {
return nil, err return 0, err
} }
elem_total += elem_len
data = append(data, elem...)
} }
return 9 + elem_total, nil
return append(len_bytes, data...), nil
} }
case reflect.Array: case reflect.Array:
data := []byte{} total := 0
for i := 0; i < value.Len(); i++ { for i := 0; i < value.Len(); i++ {
elem, err := serializeValue(ctx, value.Index(i)) elem_len, err := SerializedSize(ctx, value.Index(i))
if err != nil { if err != nil {
return nil, err return 0, err
} }
total += elem_len
data = append(data, elem...)
} }
return data, nil return total, nil
case reflect.Map: case reflect.Map:
len_bytes := make([]byte, 8) if value.IsNil() {
binary.BigEndian.PutUint64(len_bytes, uint64(value.Len())) return 1, nil
} else {
key := reflect.New(value.Type().Key()).Elem()
val := reflect.New(value.Type().Elem()).Elem()
iter := value.MapRange()
total := 0
for iter.Next() {
key.SetIterKey(iter)
k, err := SerializedSize(ctx, key)
if err != nil {
return 0, err
}
total += k
val.SetIterValue(iter)
v, err := SerializedSize(ctx, val)
if err != nil {
return 0, err
}
total += v
}
return 9 + total, nil
}
case reflect.Struct:
if registered == false {
return 0, fmt.Errorf("Can't serialize unregistered struct %s", value.Type())
} else {
field_total := 0
for _, field_info := range(info.Fields) {
field_size, err := SerializedSize(ctx, value.FieldByIndex(field_info.Index))
if err != nil {
return 0, err
}
field_total += 8
field_total += field_size
}
return 8 + field_total, nil
}
case reflect.Interface:
// TODO get size of TypeStack instead of just using 128
elem_size, err := SerializedSize(ctx, value.Elem())
if err != nil {
return 0, err
}
return 128 + elem_size, nil
default:
return 0, fmt.Errorf("Don't know how to serialize %s", value.Type())
}
} else {
return sizefn(ctx, value)
}
}
func SerializeValue(ctx *Context, value reflect.Value, data []byte) (int, error) {
var serialize SerializeFn = nil
info, registered := ctx.Types[value.Type()]
if registered {
serialize = info.Serialize
}
if serialize == nil {
switch value.Type().Kind() {
case reflect.Bool:
if value.Bool() {
data[0] = 0xFF
} else {
data[0] = 0x00
}
return 1, nil
case reflect.Int8:
data[0] = byte(value.Int())
return 1, nil
case reflect.Int16:
binary.BigEndian.PutUint16(data, uint16(value.Int()))
return 2, nil
case reflect.Int32:
binary.BigEndian.PutUint32(data, uint32(value.Int()))
return 4, nil
case reflect.Int64:
fallthrough
case reflect.Int:
binary.BigEndian.PutUint64(data, uint64(value.Int()))
return 8, nil
case reflect.Uint8:
data[0] = byte(value.Uint())
return 1, nil
case reflect.Uint16:
binary.BigEndian.PutUint16(data, uint16(value.Uint()))
return 2, nil
case reflect.Uint32:
binary.BigEndian.PutUint32(data, uint32(value.Uint()))
return 4, nil
case reflect.Uint64:
fallthrough
case reflect.Uint:
binary.BigEndian.PutUint64(data, value.Uint())
return 8, nil
data := []byte{} case reflect.Float32:
iter := value.MapRange() binary.BigEndian.PutUint32(data, math.Float32bits(float32(value.Float())))
for iter.Next() { return 4, nil
k, err := serializeValue(ctx, iter.Key()) case reflect.Float64:
binary.BigEndian.PutUint64(data, math.Float64bits(value.Float()))
return 8, nil
case reflect.String:
binary.BigEndian.PutUint64(data, uint64(value.Len()))
copy(data[8:], []byte(value.String()))
return 8 + value.Len(), nil
case reflect.Pointer:
if value.IsNil() {
data[0] = 0x00
return 1, nil
} else {
data[0] = 0x01
written, err := SerializeValue(ctx, value.Elem(), data[1:])
if err != nil { if err != nil {
return nil, err return 0, err
} }
return 1 + written, nil
}
data = append(data, k...) case reflect.Slice:
if value.IsNil() {
data[0] = 0x00
return 8, nil
} else {
data[0] = 0x01
binary.BigEndian.PutUint64(data[1:], uint64(value.Len()))
total_written := 0
for i := 0; i < value.Len(); i++ {
written, err := SerializeValue(ctx, value.Index(i), data[9+total_written:])
if err != nil {
return 0, err
}
total_written += written
}
return 9 + total_written, nil
}
v, err := serializeValue(ctx, iter.Value()) case reflect.Array:
total_written := 0
for i := 0; i < value.Len(); i++ {
written, err := SerializeValue(ctx, value.Index(i), data[total_written:])
if err != nil { if err != nil {
return nil, err return 0, err
} }
total_written += written
}
return total_written, nil
case reflect.Map:
if value.IsNil() {
data[0] = 0x00
return 1, nil
} else {
data[0] = 0x01
binary.BigEndian.PutUint64(data[1:], uint64(value.Len()))
key := reflect.New(value.Type().Key()).Elem()
val := reflect.New(value.Type().Elem()).Elem()
iter := value.MapRange()
total_written := 0
for iter.Next() {
key.SetIterKey(iter)
val.SetIterValue(iter)
k, err := SerializeValue(ctx, key, data[9+total_written:])
if err != nil {
return 0, err
}
total_written += k
data = append(data, v...) v, err := SerializeValue(ctx, val, data[9+total_written:])
if err != nil {
return 0, err
}
total_written += v
}
return 9 + total_written, nil
} }
return append(len_bytes, data...), nil
case reflect.Struct: case reflect.Struct:
if registered == false { if registered == false {
return nil, fmt.Errorf("Cannot serialize unregistered struct %s", value.Type()) return 0, fmt.Errorf("Cannot serialize unregistered struct %s", value.Type())
} else { } else {
data := binary.BigEndian.AppendUint64(nil, uint64(len(info.Fields))) binary.BigEndian.PutUint64(data, uint64(len(info.Fields)))
total_written := 0
for field_tag, field_info := range(info.Fields) { for field_tag, field_info := range(info.Fields) {
data = append(data, binary.BigEndian.AppendUint64(nil, uint64(field_tag))...) binary.BigEndian.PutUint64(data[8+total_written:], uint64(field_tag))
field_bytes, err := serializeValue(ctx, value.FieldByIndex(field_info.Index)) total_written += 8
written, err := SerializeValue(ctx, value.FieldByIndex(field_info.Index), data[8+total_written:])
if err != nil { if err != nil {
return nil, err return 0, err
} }
total_written += written
data = append(data, field_bytes...)
} }
return data, nil return 8 + total_written, nil
} }
case reflect.Interface: case reflect.Interface:
data, err := TypeStack(ctx, value.Elem().Type()) type_written, err := TypeStack(ctx, value.Elem().Type(), data)
val_data, err := serializeValue(ctx, value.Elem()) elem_written, err := SerializeValue(ctx, value.Elem(), data[type_written:])
if err != nil { if err != nil {
return nil, err return 0, err
} }
data = append(data, val_data...) return type_written + elem_written, nil
return data, nil
default: default:
return nil, fmt.Errorf("Don't know how to serialize %s", value.Type()) return 0, fmt.Errorf("Don't know how to serialize %s", value.Type())
} }
} else { } else {
return serialize(ctx, value) return serialize(ctx, value, data)
} }
} }
@@ -352,10 +521,10 @@ func split(data []byte, n int) ([]byte, []byte) {
 	return data[:n], data[n:]
 }
-func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, []byte, error) {
+func DeserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, []byte, error) {
 	var deserialize DeserializeFn = nil
-	info, registered := ctx.TypeTypes[t]
+	info, registered := ctx.Types[t]
 	if registered {
 		deserialize = info.Deserialize
 	}
@@ -439,7 +608,7 @@ func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value,
 			value.SetZero()
 			return value, after_flags, nil
 		} else {
-			elem_value, after_elem, err := deserializeValue(ctx, after_flags, t.Elem())
+			elem_value, after_elem, err := DeserializeValue(ctx, after_flags, t.Elem())
 			if err != nil {
 				return reflect.Value{}, nil, err
 			}
@@ -448,19 +617,25 @@ func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value,
 		}
 	case reflect.Slice:
-		len_bytes, left := split(data, 8)
-		length := int(binary.BigEndian.Uint64(len_bytes))
-		value := reflect.MakeSlice(t, length, length)
-		for i := 0; i < length; i++ {
-			var elem_value reflect.Value
-			var err error
-			elem_value, left, err = deserializeValue(ctx, left, t.Elem())
-			if err != nil {
-				return reflect.Value{}, nil, err
-			}
-			value.Index(i).Set(elem_value)
-		}
-		return value, left, nil
+		nil_byte := data[0]
+		data = data[1:]
+		if nil_byte == 0x00 {
+			return reflect.New(t).Elem(), data, nil
+		} else {
+			len_bytes, left := split(data, 8)
+			length := int(binary.BigEndian.Uint64(len_bytes))
+			value := reflect.MakeSlice(t, length, length)
+			for i := 0; i < length; i++ {
+				var elem_value reflect.Value
+				var err error
+				elem_value, left, err = DeserializeValue(ctx, left, t.Elem())
+				if err != nil {
+					return reflect.Value{}, nil, err
+				}
+				value.Index(i).Set(elem_value)
+			}
+			return value, left, nil
+		}
 	case reflect.Array:
 		value := reflect.New(t).Elem()
@@ -468,7 +643,7 @@ func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value,
 		for i := 0; i < t.Len(); i++ {
 			var elem_value reflect.Value
 			var err error
-			elem_value, left, err = deserializeValue(ctx, left, t.Elem())
+			elem_value, left, err = DeserializeValue(ctx, left, t.Elem())
 			if err != nil {
 				return reflect.Value{}, nil, err
 			}
@@ -477,33 +652,38 @@ func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value,
 		return value, left, nil
 	case reflect.Map:
-		len_bytes, left := split(data, 8)
-		length := int(binary.BigEndian.Uint64(len_bytes))
-		value := reflect.MakeMapWithSize(t, length)
-		for i := 0; i < length; i++ {
-			var key_value reflect.Value
-			var val_value reflect.Value
-			var err error
-			key_value, left, err = deserializeValue(ctx, left, t.Key())
-			if err != nil {
-				return reflect.Value{}, nil, err
-			}
-			val_value, left, err = deserializeValue(ctx, left, t.Elem())
-			if err != nil {
-				return reflect.Value{}, nil, err
-			}
-			value.SetMapIndex(key_value, val_value)
-		}
-		return value, left, nil
+		flags, after_flags := split(data, 1)
+		if flags[0] == 0x00 {
+			return reflect.New(t).Elem(), after_flags, nil
+		} else {
+			len_bytes, left := split(after_flags, 8)
+			length := int(binary.BigEndian.Uint64(len_bytes))
+			value := reflect.MakeMapWithSize(t, length)
+			for i := 0; i < length; i++ {
+				var key_value reflect.Value
+				var val_value reflect.Value
+				var err error
+				key_value, left, err = DeserializeValue(ctx, left, t.Key())
+				if err != nil {
+					return reflect.Value{}, nil, err
+				}
+				val_value, left, err = DeserializeValue(ctx, left, t.Elem())
+				if err != nil {
+					return reflect.Value{}, nil, err
+				}
+				value.SetMapIndex(key_value, val_value)
+			}
+			return value, left, nil
+		}
 	case reflect.Struct:
-		info, mapped := ctx.TypeTypes[t]
+		info, mapped := ctx.Types[t]
 		if mapped {
 			value := reflect.New(t).Elem()
@@ -520,7 +700,7 @@ func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value,
 				if mapped {
 					var field_val reflect.Value
 					var err error
-					field_val, left, err = deserializeValue(ctx, left, field_info.Type)
+					field_val, left, err = DeserializeValue(ctx, left, field_info.Type)
 					if err != nil {
 						return reflect.Value{}, nil, err
 					}
@@ -544,7 +724,7 @@ func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value,
 			return reflect.Value{}, nil, err
 		}
-		elem_val, left, err := deserializeValue(ctx, rest, elem_type)
+		elem_val, left, err := DeserializeValue(ctx, rest, elem_type)
 		if err != nil {
 			return reflect.Value{}, nil, err
 		}
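Editor's note: DeserializeValue is now exported, and the slice and map cases read a leading flag byte (0x00 decodes to a nil value; anything else is assumed to mean "non-nil, 8-byte big-endian length follows"), mirroring what the serializer writes. A minimal sketch of decoding through the reflect-based entry point, assuming `serialized` was produced by the matching SerializeValue:

	// Sketch, not part of the diff: decode a []int and keep any trailing bytes.
	value, rest, err := DeserializeValue(ctx, serialized, reflect.TypeFor[[]int]())
	if err != nil {
		panic(err)
	}
	decoded := value.Interface().([]int) // typed view of the reflect.Value
	_ = rest                             // bytes left over after this value
	_ = decoded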

@@ -7,10 +7,13 @@ import (
 )
 func testTypeStack[T any](t *testing.T, ctx *Context) {
+	buffer := [1024]byte{}
 	reflect_type := reflect.TypeFor[T]()
-	stack, err := TypeStack(ctx, reflect_type)
+	written, err := TypeStack(ctx, reflect_type, buffer[:])
 	fatalErr(t, err)
+	stack := buffer[:written]
 	ctx.Log.Logf("test", "TypeStack[%s]: %+v", reflect_type, stack)
 	unwrapped_type, rest, err := UnwrapStack(ctx, stack)
@@ -41,9 +44,12 @@ func TestSerializeTypes(t *testing.T) {
 }
 func testSerializeCompare[T comparable](t *testing.T, ctx *Context, value T) {
-	serialized, err := Serialize(ctx, value)
+	buffer := [1024]byte{}
+	written, err := Serialize(ctx, value, buffer[:])
 	fatalErr(t, err)
+	serialized := buffer[:written]
 	ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
 	deserialized, err := Deserialize[T](ctx, serialized)
@@ -57,9 +63,12 @@ func testSerializeCompare[T comparable](t *testing.T, ctx *Context, value T) {
 }
 func testSerializeList[L []T, T comparable](t *testing.T, ctx *Context, value L) {
-	serialized, err := Serialize(ctx, value)
+	buffer := [1024]byte{}
+	written, err := Serialize(ctx, value, buffer[:])
 	fatalErr(t, err)
+	serialized := buffer[:written]
 	ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[L](), value, serialized)
 	deserialized, err := Deserialize[L](ctx, serialized)
@@ -75,9 +84,13 @@ func testSerializeList[L []T, T comparable](t *testing.T, ctx *Context, value L)
 }
 func testSerializePointer[P interface {*T}, T comparable](t *testing.T, ctx *Context, value P) {
-	serialized, err := Serialize(ctx, value)
+	buffer := [1024]byte{}
+	written, err := Serialize(ctx, value, buffer[:])
 	fatalErr(t, err)
+	serialized := buffer[:written]
 	ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[P](), value, serialized)
 	deserialized, err := Deserialize[P](ctx, serialized)
@@ -97,9 +110,12 @@ func testSerializePointer[P interface {*T}, T comparable](t *testing.T, ctx *Con
 }
 func testSerialize[T any](t *testing.T, ctx *Context, value T) {
-	serialized, err := Serialize(ctx, value)
+	buffer := [1024]byte{}
+	written, err := Serialize(ctx, value, buffer[:])
 	fatalErr(t, err)
+	serialized := buffer[:written]
 	ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
 	deserialized, err := Deserialize[T](ctx, serialized)
@@ -144,7 +160,17 @@ func TestSerializeValues(t *testing.T) {
 	testSerializeCompare[*int](t, ctx, nil)
 	testSerializeCompare(t, ctx, "string")
-	node, err := NewNode(ctx, nil, "Base", 100)
+	testSerialize(t, ctx, map[string]string{
+		"Test": "Test",
+		"key": "String",
+		"": "",
+	})
+	testSerialize[map[string]string](t, ctx, nil)
+	testSerialize(t, ctx, NewListenerExt(10))
+	node, err := ctx.NewNode(nil, "Node")
 	fatalErr(t, err)
 	testSerialize(t, ctx, node)
 }
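Editor's note: each test helper now stack-allocates a fixed scratch buffer rather than letting Serialize allocate. A sketch of the same round-trip written inline, using only the helpers visible in this diff (the buffer size and map contents are illustrative):

	// Sketch, not part of the diff: serialize into a scratch buffer,
	// deserialize the written prefix, and compare one entry.
	buffer := [1024]byte{}
	value := map[string]string{"Test": "Test"}
	written, err := Serialize(ctx, value, buffer[:])
	fatalErr(t, err)
	deserialized, err := Deserialize[map[string]string](ctx, buffer[:written])
	fatalErr(t, err)
	if deserialized["Test"] != value["Test"] {
		t.Fatal("deserialized map does not match original")
	}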

@@ -21,13 +21,6 @@ func (signal TimeoutSignal) String() string {
 	return fmt.Sprintf("TimeoutSignal(%s)", &signal.ResponseHeader)
 }
-type SignalDirection int
-const (
-	Up SignalDirection = iota
-	Down
-	Direct
-)
 type SignalHeader struct {
 	Id uuid.UUID `gv:"id"`
 }
@@ -37,7 +30,7 @@ func (signal SignalHeader) ID() uuid.UUID {
 }
 func (header SignalHeader) String() string {
-	return fmt.Sprintf("SignalHeader(%s)", header.Id)
+	return fmt.Sprintf("%s", header.Id)
 }
 type ResponseSignal interface {
@@ -55,7 +48,7 @@ func (header ResponseHeader) ResponseID() uuid.UUID {
 }
 func (header ResponseHeader) String() string {
-	return fmt.Sprintf("ResponseHeader(%s, %s)", header.Id, header.ReqID)
+	return fmt.Sprintf("%s for %s", header.Id, header.ReqID)
 }
 type Signal interface {
@@ -171,16 +164,16 @@ func NewACLTimeoutSignal(req_id uuid.UUID) *ACLTimeoutSignal {
 type StatusSignal struct {
 	SignalHeader
 	Source NodeID `gv:"source"`
-	Changes map[ExtType]Changes `gv:"changes"`
+	Fields []string `gv:"fields"`
 }
 func (signal StatusSignal) String() string {
-	return fmt.Sprintf("StatusSignal(%s, %+v)", signal.SignalHeader, signal.Changes)
+	return fmt.Sprintf("StatusSignal(%s: %+v)", signal.Source, signal.Fields)
 }
-func NewStatusSignal(source NodeID, changes map[ExtType]Changes) *StatusSignal {
+func NewStatusSignal(source NodeID, fields []string) *StatusSignal {
 	return &StatusSignal{
 		NewSignalHeader(),
 		source,
-		changes,
+		fields,
 	}
 }
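Editor's note: StatusSignal now reports a flat list of changed field names instead of a per-extension Changes map. A sketch of announcing a change under the new shape (the field name and the source variable are placeholders, not taken from the diff):

	// Sketch, not part of the diff: "requirements" is an illustrative field
	// name, and source is assumed to be a valid NodeID.
	status := NewStatusSignal(source, []string{"requirements"})
	// Receivers inspect status.Source and status.Fields directly, without
	// needing to know which extension produced the change.
	_ = status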
@@ -205,32 +198,44 @@ func NewLinkSignal(action string, id NodeID) Signal {
 type LockSignal struct {
 	SignalHeader
-	State string
 }
 func (signal LockSignal) String() string {
-	return fmt.Sprintf("LockSignal(%s, %s)", signal.SignalHeader, signal.State)
+	return fmt.Sprintf("LockSignal(%s)", signal.SignalHeader)
 }
-func NewLockSignal(state string) *LockSignal {
+func NewLockSignal() *LockSignal {
 	return &LockSignal{
 		NewSignalHeader(),
-		state,
 	}
 }
+type UnlockSignal struct {
+	SignalHeader
+}
+func (signal UnlockSignal) String() string {
+	return fmt.Sprintf("UnlockSignal(%s)", signal.SignalHeader)
+}
+func NewUnlockSignal() *UnlockSignal {
+	return &UnlockSignal{
+		NewSignalHeader(),
+	}
+}
 type ReadSignal struct {
 	SignalHeader
-	Extensions map[ExtType][]string `json:"extensions"`
+	Fields []string `json:"extensions"`
 }
 func (signal ReadSignal) String() string {
-	return fmt.Sprintf("ReadSignal(%s, %+v)", signal.SignalHeader, signal.Extensions)
+	return fmt.Sprintf("ReadSignal(%s, %+v)", signal.SignalHeader, signal.Fields)
 }
-func NewReadSignal(exts map[ExtType][]string) *ReadSignal {
+func NewReadSignal(fields []string) *ReadSignal {
 	return &ReadSignal{
 		NewSignalHeader(),
-		exts,
+		fields,
 	}
 }
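Editor's note: locking is now split across two signal types — LockSignal no longer carries a State string, and releasing a lock uses the new UnlockSignal — while ReadSignal takes a flat list of field names. A sketch of constructing all three (the field name is illustrative):

	// Sketch, not part of the diff.
	lock := NewLockSignal()                         // request the lock
	unlock := NewUnlockSignal()                     // release it with a distinct signal type
	read := NewReadSignal([]string{"requirements"}) // ask for fields by name
	_, _, _ = lock, unlock, read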
@@ -238,19 +243,19 @@ type ReadResultSignal struct {
 	ResponseHeader
 	NodeID NodeID
 	NodeType NodeType
-	Extensions map[ExtType]map[string]any
+	Fields map[string]any
 }
 func (signal ReadResultSignal) String() string {
-	return fmt.Sprintf("ReadResultSignal(%s, %s, %+v)", signal.ResponseHeader, signal.NodeID, signal.Extensions)
+	return fmt.Sprintf("ReadResultSignal(%s, %s, %+v)", signal.ResponseHeader, signal.NodeID, signal.Fields)
 }
-func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]any) *ReadResultSignal {
+func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, fields map[string]any) *ReadResultSignal {
 	return &ReadResultSignal{
 		NewResponseHeader(req_id),
 		node_id,
 		node_type,
-		exts,
+		fields,
 	}
 }
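Editor's note: ReadResultSignal flattens the old per-extension map into a single Fields map keyed by field name. A sketch of answering the ReadSignal from the previous example (node_id, node_type, and the field value are placeholders):

	// Sketch, not part of the diff: the reply ties back to the request
	// through NewResponseHeader(read.Id).
	result := NewReadResultSignal(read.Id, node_id, node_type, map[string]any{
		"requirements": []NodeID{}, // illustrative field name and value
	})
	_ = result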