master
noah metz 2024-03-31 18:58:27 -07:00
parent c29981da20
commit 8cb97d2350
10 changed files with 224 additions and 252 deletions

@ -2,7 +2,9 @@ package graphvent
import ( import (
"crypto/ecdh" "crypto/ecdh"
"crypto/ed25519"
"encoding" "encoding"
"crypto/rand"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@ -105,8 +107,7 @@ type Context struct {
// Map between database node type hashes and the registered info // Map between database node type hashes and the registered info
NodeTypes map[NodeType]NodeInfo NodeTypes map[NodeType]NodeInfo
// Routing map to all the nodes local to this context nodesLock sync.Mutex
nodesLock sync.RWMutex
nodes map[NodeID]ContextNode nodes map[NodeID]ContextNode
running atomic.Bool running atomic.Bool
@ -838,103 +839,140 @@ func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, f
return nil return nil
} }
func (ctx *Context) AddNode(id NodeID, node *Node, status chan string, command chan string) { func (ctx *Context) NewNode(key ed25519.PrivateKey, type_name string, extensions ...Extension) (*Node, error) {
ctx.nodesLock.Lock() ctx.nodesLock.Lock()
ctx.nodes[id] = ContextNode{ defer ctx.nodesLock.Unlock()
Node: node,
Status: status,
Command: command,
}
ctx.nodesLock.Unlock()
}
func (ctx *Context) Node(id NodeID) (*Node, bool) { node_type := NodeTypeFor(type_name)
ctx.nodesLock.RLock() node_info, known_type := ctx.NodeTypes[node_type]
node, exists := ctx.nodes[id] if known_type == false {
ctx.nodesLock.RUnlock() return nil, fmt.Errorf("%s is not a known node type", type_name)
return node.Node, exists }
}
func (ctx *Context) Delete(id NodeID) error { var err error
err := ctx.Unload(id) var public ed25519.PublicKey
if key == nil {
public, key, err = ed25519.GenerateKey(rand.Reader)
if err != nil { if err != nil {
return err return nil, err
}
} else {
public = key.Public().(ed25519.PublicKey)
}
id := KeyID(public)
_, err = ctx.getNode(id)
if err == nil {
return nil, fmt.Errorf("Attempted to create an existing node")
} else if errors.Is(err, NodeNotFoundError) == false {
return nil, fmt.Errorf("Error checking if node exists: %+w", err)
} }
// TODO: also delete any associated data from DB
return nil
}
func (ctx *Context) Unload(id NodeID) error { ext_map := map[ExtType]Extension{}
ctx.nodesLock.Lock() for _, ext := range(extensions) {
defer ctx.nodesLock.Unlock() if ext == nil {
node, exists := ctx.nodes[id] return nil, fmt.Errorf("Cannot create node with nil extension")
if exists == false {
return fmt.Errorf("%s is not a node in ctx", id)
} }
node.Command <- "stop" ext_type, exists := ctx.Extensions[ExtTypeOf(reflect.TypeOf(ext))]
returned := <- node.Status if exists == false {
return nil, fmt.Errorf("%+v(%+v) is not a known Extension", reflect.TypeOf(ext), ExtTypeOf(reflect.TypeOf(ext)))
}
_, exists = ext_map[ext_type.ExtType]
if exists == true {
return nil, fmt.Errorf("Cannot add the same extension to a node twice")
}
ext_map[ext_type.ExtType] = ext
}
if returned != "stopped" { for _, required_ext := range(node_info.RequiredExtensions) {
return fmt.Errorf(returned) _, exists := ext_map[required_ext]
if exists == false {
return nil, fmt.Errorf(fmt.Sprintf("%+v requires %+v", node_type, required_ext))
}
} }
err := node.Node.Unload(ctx) node := &Node{
delete(ctx.nodes, id) Key: key,
return err ID: id,
} Type: node_type,
Extensions: ext_map,
SignalQueue: []QueuedSignal{},
writeSignalQueue: false,
}
func (ctx *Context) Stop() { node.SendChan, node.RecvChan = NewMessageQueue(NODE_INITIAL_QUEUE_SIZE)
ctx.nodesLock.Lock()
for id, node := range(ctx.nodes) {
node.Command <- "stop"
returned := <- node.Status
if returned != "stopped" { err = ctx.DB.WriteNodeInit(ctx, node)
ctx.Log.Logf("node", "Node returned %s when commanded to stop", returned)
} else {
err := node.Node.Unload(ctx)
if err != nil { if err != nil {
ctx.Log.Logf("node", "Error unloading %s: %s", id, err) return nil, err
}
delete(ctx.nodes, id)
}
} }
ctx.nodesLock.Unlock()
}
func (ctx *Context) Load(id NodeID) (*Node, error) { err = ctx.addNode(id, node)
node, err := ctx.DB.LoadNode(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return node, nil
}
func (ctx *Context) addNode(id NodeID, node *Node) error {
status := make(chan string, 0) status := make(chan string, 0)
command := make(chan string, 0) command := make(chan string, 0)
go runNode(ctx, node, status, command) go runNode(ctx, node, status, command)
returned := <- status returned := <- status
if returned != "active" { if returned != "active" {
return nil, fmt.Errorf(returned) return fmt.Errorf(returned)
}
ctx.nodes[id] = ContextNode{
Node: node,
Status: status,
Command: command,
} }
return nil
}
ctx.AddNode(id, node, status, command) func (ctx *Context) Stop() error {
ctx.nodesLock.Lock()
defer ctx.nodesLock.Unlock()
return node, nil for _, node := range(ctx.nodes) {
node.Command <- "stop"
returned := <- node.Status
if returned != "stopped" {
return fmt.Errorf("Node returned %s when commanded to stop", returned)
}
}
ctx.nodes = map[NodeID]ContextNode{}
return nil
}
func (ctx *Context) GetNode(id NodeID) (*Node, error) {
ctx.nodesLock.Lock()
defer ctx.nodesLock.Unlock()
return ctx.getNode(id)
} }
// Get a node from the context, or load from the database if not loaded
func (ctx *Context) getNode(id NodeID) (*Node, error) { func (ctx *Context) getNode(id NodeID) (*Node, error) {
target, exists := ctx.Node(id) target, exists := ctx.nodes[id]
if exists == false { if exists == false {
var err error node, err := ctx.DB.LoadNode(ctx, id)
target, err = ctx.Load(id) if err != nil {
return nil, err
}
err = ctx.addNode(id, node)
if err != nil { if err != nil {
return nil, fmt.Errorf("Failed to load node %s: %w", id, err) return nil, err
} }
return node, nil
} else {
return target.Node, nil
} }
return target, nil
} }
// Route Messages to dest. Currently only local context routing is supported // Route Messages to dest. Currently only local context routing is supported

89
db.go

@ -84,18 +84,34 @@ func (db *BadgerDB) WriteNodeInit(ctx *Context, node *Node) error {
// For each extension: // For each extension:
for ext_type, ext := range(node.Extensions) { for ext_type, ext := range(node.Extensions) {
// Write each extension's current value ext_info, exists := ctx.Extensions[ext_type]
if exists == false {
return fmt.Errorf("Cannot serialize node with unknown extension %s", reflect.TypeOf(ext))
}
ext_value := reflect.ValueOf(ext).Elem()
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type)) ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
written, err := SerializeValue(ctx, reflect.ValueOf(ext).Elem(), db.buffer[cur:])
// Write each field to a seperate key
for field_tag, field_info := range(ext_info.Fields) {
field_value := ext_value.FieldByIndex(field_info.Index)
field_id := make([]byte, len(ext_id) + 8)
tmp := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(field_tag))))
copy(field_id, tmp)
written, err := SerializeValue(ctx, field_value, db.buffer[cur:])
if err != nil { if err != nil {
return err return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err)
} }
err = tx.Set(ext_id, db.buffer[cur:cur+written])
err = tx.Set(field_id, db.buffer[cur:cur+written])
if err != nil { if err != nil {
return err return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err)
} }
cur += written cur += written
} }
}
return nil return nil
}) })
} }
@ -127,24 +143,44 @@ func (db *BadgerDB) WriteNodeChanges(ctx *Context, node *Node, changes map[ExtTy
} }
// For each ext in changes // For each ext in changes
for ext_type := range(changes) { for ext_type, changes := range(changes) {
// Write each ext ext_info, exists := ctx.Extensions[ext_type]
if exists == false {
return fmt.Errorf("%s is not an extension in ctx", ext_type)
}
ext, exists := node.Extensions[ext_type] ext, exists := node.Extensions[ext_type]
if exists == false { if exists == false {
return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID) return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID)
} }
ext_id := binary.BigEndian.AppendUint64(id_bytes[:], uint64(ext_type)) ext_id := binary.BigEndian.AppendUint64(id_bytes[:], uint64(ext_type))
written, err := SerializeValue(ctx, reflect.ValueOf(ext).Elem(), db.buffer[cur:]) ext_value := reflect.ValueOf(ext)
// Write each field
for _, tag := range(changes) {
field_info, exists := ext_info.Fields[tag]
if exists == false {
return fmt.Errorf("Cannot serialize field %s of extension %s, does not exist", tag, ext_type)
}
field_value := ext_value.FieldByIndex(field_info.Index)
field_id := make([]byte, len(ext_id) + 8)
tmp := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(tag))))
copy(field_id, tmp)
written, err := SerializeValue(ctx, field_value, db.buffer[cur:])
if err != nil { if err != nil {
return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err) return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err)
} }
err = tx.Set(ext_id, db.buffer[cur:cur+written]) err = tx.Set(field_id, db.buffer[cur:cur+written])
if err != nil { if err != nil {
return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err) return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err)
} }
cur += written cur += written
} }
}
return nil return nil
}) })
} }
@ -156,13 +192,13 @@ func (db *BadgerDB) LoadNode(ctx *Context, id NodeID) (*Node, error) {
// Get the base key bytes // Get the base key bytes
id_ser, err := id.MarshalBinary() id_ser, err := id.MarshalBinary()
if err != nil { if err != nil {
return err return fmt.Errorf("Failed to serialize node_id: %w", err)
} }
// Get the node value // Get the node value
node_item, err := tx.Get(id_ser) node_item, err := tx.Get(id_ser)
if err != nil { if err != nil {
return err return fmt.Errorf("Failed to get node_item: %w", NodeNotFoundError)
} }
err = node_item.Value(func(val []byte) error { err = node_item.Value(func(val []byte) error {
@ -179,14 +215,14 @@ func (db *BadgerDB) LoadNode(ctx *Context, id NodeID) (*Node, error) {
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...) sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
sigqueue_item, err := tx.Get(sigqueue_id) sigqueue_item, err := tx.Get(sigqueue_id)
if err != nil { if err != nil {
return err return fmt.Errorf("Failed to get sigqueue_id: %w", err)
} }
err = sigqueue_item.Value(func(val []byte) error { err = sigqueue_item.Value(func(val []byte) error {
node.SignalQueue, err = Deserialize[[]QueuedSignal](ctx, val) node.SignalQueue, err = Deserialize[[]QueuedSignal](ctx, val)
return err return err
}) })
if err != nil { if err != nil {
return err return fmt.Errorf("Failed to deserialize []QueuedSignal for %s: %w", id, err)
} }
// Get the extension list // Get the extension list
@ -205,35 +241,34 @@ func (db *BadgerDB) LoadNode(ctx *Context, id NodeID) (*Node, error) {
// Get the extensions // Get the extensions
for _, ext_type := range(ext_list) { for _, ext_type := range(ext_list) {
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type)) ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
ext_item, err := tx.Get(ext_id)
if err != nil {
return err
}
ext_info, exists := ctx.Extensions[ext_type] ext_info, exists := ctx.Extensions[ext_type]
if exists == false { if exists == false {
return fmt.Errorf("Extension %s not in context", ext_type) return fmt.Errorf("Extension %s not in context", ext_type)
} }
var ext Extension ext := reflect.New(ext_info.Type)
var ok bool for field_tag, field_info := range(ext_info.Fields) {
err = ext_item.Value(func(val []byte) error { field_id := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(field_tag))))
value, _, err := DeserializeValue(ctx, val, ext_info.Type) field_item, err := tx.Get(field_id)
if err != nil {
return fmt.Errorf("Failed to find key for %s:%s(%x) - %w", ext_type, field_tag, field_id, err)
}
err = field_item.Value(func(val []byte) error {
value, _, err := DeserializeValue(ctx, val, field_info.Type)
if err != nil { if err != nil {
return err return err
} }
ext, ok = value.Addr().Interface().(Extension) ext.Elem().FieldByIndex(field_info.Index).Set(value)
if ok == false {
return fmt.Errorf("Parsed value %+v is not extension", value.Type())
}
return nil return nil
}) })
if err != nil { if err != nil {
return err return err
} }
node.Extensions[ext_type] = ext }
node.Extensions[ext_type] = ext.Interface().(Extension)
} }
return nil return nil

@ -542,6 +542,8 @@ type GQLExt struct {
func (ext *GQLExt) Load(ctx *Context, node *Node) error { func (ext *GQLExt) Load(ctx *Context, node *Node) error {
ctx.Log.Logf("gql", "Loading GQL server extension on %s", node.ID) ctx.Log.Logf("gql", "Loading GQL server extension on %s", node.ID)
ext.resolver_response = map[uuid.UUID]chan Signal{}
ext.subscriptions = []SubscriptionInfo{}
return ext.StartGQLServer(ctx, node) return ext.StartGQLServer(ctx, node)
} }
@ -555,13 +557,6 @@ func (ext *GQLExt) Unload(ctx *Context, node *Node) {
} }
} }
func (ext *GQLExt) PostDeserialize(*Context) error {
ext.resolver_response = map[uuid.UUID]chan Signal{}
ext.subscriptions = []SubscriptionInfo{}
return nil
}
func (ext *GQLExt) AddSubscription(id uuid.UUID, ctx *ResolveContext) (chan interface{}, error) { func (ext *GQLExt) AddSubscription(id uuid.UUID, ctx *ResolveContext) (chan interface{}, error) {
ext.subscriptions_lock.Lock() ext.subscriptions_lock.Lock()
defer ext.subscriptions_lock.Unlock() defer ext.subscriptions_lock.Unlock()

@ -19,7 +19,7 @@ import (
func TestGQLSubscribe(t *testing.T) { func TestGQLSubscribe(t *testing.T) {
ctx := logTestContext(t, []string{"test", "gql"}) ctx := logTestContext(t, []string{"test", "gql"})
n1, err := NewNode(ctx, nil, "LockableNode", NewLockableExt(nil)) n1, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil))
fatalErr(t, err) fatalErr(t, err)
listener_ext := NewListenerExt(10) listener_ext := NewListenerExt(10)
@ -27,7 +27,7 @@ func TestGQLSubscribe(t *testing.T) {
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
fatalErr(t, err) fatalErr(t, err)
gql, err := NewNode(ctx, nil, "LockableNode", NewLockableExt([]NodeID{n1.ID}), gql_ext, listener_ext) gql, err := ctx.NewNode(nil, "LockableNode", NewLockableExt([]NodeID{n1.ID}), gql_ext, listener_ext)
fatalErr(t, err) fatalErr(t, err)
query := "subscription { Self { ID, Type ... on Lockable { LockableState } } }" query := "subscription { Self { ID, Type ... on Lockable { LockableState } } }"
@ -129,14 +129,14 @@ func TestGQLQuery(t *testing.T) {
ctx := logTestContext(t, []string{"test", "lockable"}) ctx := logTestContext(t, []string{"test", "lockable"})
n1_listener := NewListenerExt(10) n1_listener := NewListenerExt(10)
n1, err := NewNode(ctx, nil, "LockableNode", NewLockableExt(nil), n1_listener) n1, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil), n1_listener)
fatalErr(t, err) fatalErr(t, err)
gql_listener := NewListenerExt(10) gql_listener := NewListenerExt(10)
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
fatalErr(t, err) fatalErr(t, err)
gql, err := NewNode(ctx, nil, "LockableNode", NewLockableExt([]NodeID{n1.ID}), gql_ext, gql_listener) gql, err := ctx.NewNode(nil, "LockableNode", NewLockableExt([]NodeID{n1.ID}), gql_ext, gql_listener)
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "GQL: %s", gql.ID) ctx.Log.Logf("test", "GQL: %s", gql.ID)
@ -208,14 +208,14 @@ func TestGQLDB(t *testing.T) {
fatalErr(t, err) fatalErr(t, err)
listener_ext := NewListenerExt(10) listener_ext := NewListenerExt(10)
gql, err := NewNode(ctx, nil, "Node", gql_ext, listener_ext) gql, err := ctx.NewNode(nil, "Node", gql_ext, listener_ext)
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
err = ctx.Unload(gql.ID) err = ctx.Stop()
fatalErr(t, err) fatalErr(t, err)
gql_loaded, err := ctx.Load(gql.ID) gql_loaded, err := ctx.GetNode(gql.ID)
fatalErr(t, err) fatalErr(t, err)
listener_ext, err = GetExt[ListenerExt](gql_loaded) listener_ext, err = GetExt[ListenerExt](gql_loaded)

@ -9,18 +9,13 @@ import (
func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt, error) { func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt, error) {
listener_extension := NewListenerExt(buffer) listener_extension := NewListenerExt(buffer)
listener, err := NewNode(ctx, listener, err := ctx.NewNode(nil, "LockableNode", nil, listener_extension, NewLockableExt(nil))
nil,
"LockableNode",
nil,
listener_extension,
NewLockableExt(nil))
return listener, listener_extension, err return listener, listener_extension, err
} }
func logTestContext(t * testing.T, components []string) *Context { func logTestContext(t * testing.T, components []string) *Context {
db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true)) db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true).WithSyncWrites(true))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

@ -44,20 +44,6 @@ type LockableExt struct{
Waiting WaitMap `gv:"waiting_locks" node:":Lockable"` Waiting WaitMap `gv:"waiting_locks" node:":Lockable"`
} }
func (ext *LockableExt) PostDeserialize(ctx *Context) error {
ext.Locked = map[NodeID]any{}
ext.Unlocked = map[NodeID]any{}
for id, state := range(ext.Requirements) {
if state == Unlocked {
ext.Unlocked[id] = nil
} else if state == Locked {
ext.Locked[id] = nil
}
}
return nil
}
func NewLockableExt(requirements []NodeID) *LockableExt { func NewLockableExt(requirements []NodeID) *LockableExt {
var reqs map[NodeID]ReqState = nil var reqs map[NodeID]ReqState = nil
var unlocked map[NodeID]any = map[NodeID]any{} var unlocked map[NodeID]any = map[NodeID]any{}
@ -95,6 +81,16 @@ func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
} }
func (ext *LockableExt) Load(ctx *Context, node *Node) error { func (ext *LockableExt) Load(ctx *Context, node *Node) error {
ext.Locked = map[NodeID]any{}
ext.Unlocked = map[NodeID]any{}
for id, state := range(ext.Requirements) {
if state == Unlocked {
ext.Unlocked[id] = nil
} else if state == Locked {
ext.Locked[id] = nil
}
}
return nil return nil
} }

@ -10,12 +10,12 @@ func TestLink(t *testing.T) {
l2_listener := NewListenerExt(10) l2_listener := NewListenerExt(10)
l2, err := NewNode(ctx, nil, "LockableNode", l2_listener, NewLockableExt(nil)) l2, err := ctx.NewNode(nil, "LockableNode", l2_listener, NewLockableExt(nil))
fatalErr(t, err) fatalErr(t, err)
l1_lockable := NewLockableExt(nil) l1_lockable := NewLockableExt(nil)
l1_listener := NewListenerExt(10) l1_listener := NewListenerExt(10)
l1, err := NewNode(ctx, nil, "LockableNode", l1_listener, l1_lockable) l1, err := ctx.NewNode(nil, "LockableNode", l1_listener, l1_lockable)
fatalErr(t, err) fatalErr(t, err)
link_signal := NewLinkSignal("add", l2.ID) link_signal := NewLinkSignal("add", l2.ID)
@ -62,7 +62,7 @@ func testLockN(t *testing.T, n int) {
ctx := logTestContext(t, []string{"test"}) ctx := logTestContext(t, []string{"test"})
NewLockable := func()(*Node) { NewLockable := func()(*Node) {
l, err := NewNode(ctx, nil, "LockableNode", NewLockableExt(nil)) l, err := ctx.NewNode(nil, "LockableNode", NewLockableExt(nil))
fatalErr(t, err) fatalErr(t, err)
return l return l
} }
@ -75,7 +75,7 @@ func testLockN(t *testing.T, n int) {
ctx.Log.Logf("test", "CREATED_%d", n) ctx.Log.Logf("test", "CREATED_%d", n)
listener := NewListenerExt(50000) listener := NewListenerExt(50000)
node, err := NewNode(ctx, nil, "LockableNode", listener, NewLockableExt(reqs)) node, err := ctx.NewNode(nil, "LockableNode", listener, NewLockableExt(reqs))
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "CREATED_LISTENER") ctx.Log.Logf("test", "CREATED_LISTENER")
@ -99,7 +99,7 @@ func TestLock(t *testing.T) {
NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) { NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) {
listener := NewListenerExt(10000) listener := NewListenerExt(10000)
l, err := NewNode(ctx, nil, "LockableNode", listener, NewLockableExt(reqs)) l, err := ctx.NewNode(nil, "LockableNode", listener, NewLockableExt(reqs))
fatalErr(t, err) fatalErr(t, err)
return l, listener return l, listener
} }

@ -2,7 +2,6 @@ package graphvent
import ( import (
"crypto/ed25519" "crypto/ed25519"
"crypto/rand"
"crypto/sha512" "crypto/sha512"
"fmt" "fmt"
"reflect" "reflect"
@ -299,15 +298,12 @@ func nodeLoop(ctx *Context, node *Node, status chan string, control chan string)
panic("BAD_STATE: stopping already stopped node") panic("BAD_STATE: stopping already stopped node")
} }
status <- "stopped"
return nil
}
func (node *Node) Unload(ctx *Context) error {
for _, extension := range(node.Extensions) { for _, extension := range(node.Extensions) {
extension.Unload(ctx, node) extension.Unload(ctx, node)
} }
status <- "stopped"
return nil return nil
} }
@ -403,81 +399,3 @@ func KeyID(pub ed25519.PublicKey) NodeID {
id := uuid.NewHash(sha512.New(), ZeroUUID, pub, 3) id := uuid.NewHash(sha512.New(), ZeroUUID, pub, 3)
return NodeID(id) return NodeID(id)
} }
// Create a new node in memory and start it's event loop
func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, extensions ...Extension) (*Node, error) {
node_type := NodeTypeFor(type_name)
node_info, known_type := ctx.NodeTypes[node_type]
if known_type == false {
return nil, fmt.Errorf("%s is not a known node type", type_name)
}
var err error
var public ed25519.PublicKey
if key == nil {
public, key, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
} else {
public = key.Public().(ed25519.PublicKey)
}
id := KeyID(public)
_, exists := ctx.Node(id)
if exists == true {
return nil, fmt.Errorf("Attempted to create an existing node")
}
ext_map := map[ExtType]Extension{}
for _, ext := range(extensions) {
if ext == nil {
return nil, fmt.Errorf("Cannot create node with nil extension")
}
ext_type, exists := ctx.Extensions[ExtTypeOf(reflect.TypeOf(ext))]
if exists == false {
return nil, fmt.Errorf("%+v(%+v) is not a known Extension", reflect.TypeOf(ext), ExtTypeOf(reflect.TypeOf(ext)))
}
_, exists = ext_map[ext_type.ExtType]
if exists == true {
return nil, fmt.Errorf("Cannot add the same extension to a node twice")
}
ext_map[ext_type.ExtType] = ext
}
for _, required_ext := range(node_info.RequiredExtensions) {
_, exists := ext_map[required_ext]
if exists == false {
return nil, fmt.Errorf(fmt.Sprintf("%+v requires %+v", node_type, required_ext))
}
}
node := &Node{
Key: key,
ID: id,
Type: node_type,
Extensions: ext_map,
SignalQueue: []QueuedSignal{},
writeSignalQueue: false,
}
node.SendChan, node.RecvChan = NewMessageQueue(NODE_INITIAL_QUEUE_SIZE)
err = ctx.DB.WriteNodeInit(ctx, node)
if err != nil {
return nil, err
}
status := make(chan string, 0)
command := make(chan string, 0)
go runNode(ctx, node, status, command)
returned := <- status
if returned != "active" {
return nil, fmt.Errorf(returned)
}
ctx.AddNode(id, node, status, command)
return node, nil
}

@ -5,24 +5,19 @@ import (
"time" "time"
"crypto/rand" "crypto/rand"
"crypto/ed25519" "crypto/ed25519"
"slices"
) )
func TestNodeDB(t *testing.T) { func TestNodeDB(t *testing.T) {
ctx := logTestContext(t, []string{"test", "node", "db"}) ctx := logTestContext(t, []string{"test", "node", "db"})
node_listener := NewListenerExt(10) node_listener := NewListenerExt(10)
node, err := NewNode(ctx, nil, "Node", NewLockableExt(nil), node_listener) node, err := ctx.NewNode(nil, "Node", NewLockableExt(nil), node_listener)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(node_listener.Chan, 10*time.Millisecond, func(sig *StatusSignal) bool { err = ctx.Stop()
return slices.Contains(sig.Fields, "state") && sig.Source == node.ID
})
err = ctx.Unload(node.ID)
fatalErr(t, err) fatalErr(t, err)
_, err = ctx.getNode(node.ID) _, err = ctx.GetNode(node.ID)
fatalErr(t, err) fatalErr(t, err)
} }
@ -41,10 +36,10 @@ func TestNodeRead(t *testing.T) {
ctx.Log.Logf("test", "N2: %s", n2_id) ctx.Log.Logf("test", "N2: %s", n2_id)
n2_listener := NewListenerExt(10) n2_listener := NewListenerExt(10)
n2, err := NewNode(ctx, n2_key, "Node", n2_listener) n2, err := ctx.NewNode(n2_key, "Node", n2_listener)
fatalErr(t, err) fatalErr(t, err)
n1, err := NewNode(ctx, n1_key, "Node", NewListenerExt(10)) n1, err := ctx.NewNode(n1_key, "Node", NewListenerExt(10))
fatalErr(t, err) fatalErr(t, err)
read_sig := NewReadSignal([]string{"buffer"}) read_sig := NewReadSignal([]string{"buffer"})

@ -170,7 +170,7 @@ func TestSerializeValues(t *testing.T) {
testSerialize(t, ctx, NewListenerExt(10)) testSerialize(t, ctx, NewListenerExt(10))
node, err := NewNode(ctx, nil, "Node") node, err := ctx.NewNode(nil, "Node")
fatalErr(t, err) fatalErr(t, err)
testSerialize(t, ctx, node) testSerialize(t, ctx, node)
} }