diff --git a/gql_test.go b/gql_test.go index d63d13e..b5f5ac4 100644 --- a/gql_test.go +++ b/gql_test.go @@ -244,7 +244,7 @@ func TestGQLServer(t *testing.T) { } func TestGQLDB(t *testing.T) { - ctx := logTestContext(t, []string{"test", "node"}) + ctx := logTestContext(t, []string{"test", "db", "node"}) TestUserNodeType := NewNodeType("TEST_USER") err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) @@ -277,6 +277,7 @@ func TestGQLDB(t *testing.T) { ctx.nodeMap = map[NodeID]*Node{} gql_loaded, err := LoadNode(ctx, gql.ID) fatalErr(t, err) + listener_ext, err = GetExt[*ListenerExt](gql_loaded, ListenerExtType) fatalErr(t, err) msgs = Messages{} diff --git a/node.go b/node.go index 0291cd4..206505a 100644 --- a/node.go +++ b/node.go @@ -1,16 +1,17 @@ package graphvent import ( - "time" - "errors" - "reflect" - "github.com/google/uuid" - badger "github.com/dgraph-io/badger/v3" - "fmt" - "sync/atomic" "crypto/ed25519" - "crypto/sha512" "crypto/rand" + "crypto/sha512" + "encoding/binary" + "fmt" + "reflect" + "sync/atomic" + "time" + + badger "github.com/dgraph-io/badger/v3" + "github.com/google/uuid" ) var ( @@ -100,7 +101,7 @@ type Node struct { ID NodeID Type NodeType `gv:"type"` // TODO: move each extension to it's own db key, and extend changes to notify which extension was changed - Extensions map[ExtType]Extension `gv:"extensions"` + Extensions map[ExtType]Extension Policies []Policy `gv:"policies"` PendingACLs map[uuid.UUID]PendingACL `gv:"pending_acls"` @@ -117,20 +118,18 @@ type Node struct { // TODO: enhance WriteNode to write SignalQueue to a different key, and use writeSignalQueue to decide whether or not to update it writeSignalQueue bool - SignalQueue []QueuedSignal `gv:"signal_queue"` + SignalQueue []QueuedSignal NextSignal *QueuedSignal } func (node *Node) PostDeserialize(ctx *Context) error { + node.Extensions = map[ExtType]Extension{} + public := node.Key.Public().(ed25519.PublicKey) node.ID = KeyID(public) node.MsgChan = make(chan *Message, node.BufferSize) - node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue) - ctx.Log.Logf("node", "signal_queue: %+v", node.SignalQueue) - ctx.Log.Logf("node", "next_signal: %+v - %+v", node.NextSignal, node.TimeoutChan) - return nil } @@ -294,10 +293,6 @@ func nodeLoop(ctx *Context, node *Node) error { // Perform startup actions node.Process(ctx, ZeroID, NewStartSignal()) - err := WriteNode(ctx, node) - if err != nil { - panic(err) - } run := true for run == true { var signal Signal @@ -418,11 +413,6 @@ func nodeLoop(ctx *Context, node *Node) error { } else { ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL: %s@%s", node.ID, signal, t, node.NextSignal, node.NextSignal.Time) } - - err = WriteNode(ctx, node) - if err != nil { - ctx.Log.Logf("node", "Node Write Error: %s", err) - } } ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) @@ -637,6 +627,7 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si return nil, fmt.Errorf("Node type %+v not registered in Context", node_type) } + changes := Changes{} ext_map := map[ExtType]Extension{} for _, ext := range(extensions) { ext_type, exists := ctx.ExtensionTypes[reflect.TypeOf(ext)] @@ -648,6 +639,7 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si return nil, fmt.Errorf("Cannot add the same extension to a node twice") } ext_map[ext_type] = ext + changes.Add(ext_type, "init") } for _, required_ext := range(def.Extensions) { @@ -671,14 +663,21 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si BufferSize: buffer_size, SignalQueue: []QueuedSignal{}, } - ctx.AddNode(id, node) - err = node.Process(ctx, ZeroID, NewCreateSignal()) + err = WriteNodeExtList(ctx, node) + if err != nil { + return nil, err + } + + node.writeSignalQueue = true + err = WriteNodeChanges(ctx, node, changes) if err != nil { return nil, err } - err = WriteNode(ctx, node) + ctx.AddNode(id, node) + + err = node.Process(ctx, ZeroID, NewCreateSignal()) if err != nil { return nil, err } @@ -688,93 +687,215 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si return node, nil } -func WriteNodeChanges(ctx *Context, node *Node, changes Changes) error { - // TODO: optimize to not re-serialize unchanged extensions/fields(might need to cache the serialized values) - return WriteNode(ctx, node) +var extension_suffix = []byte{0xEE, 0xFF, 0xEE, 0xFF} +var signal_queue_suffix = []byte{0xAB, 0xBA, 0xAB, 0xBA} +func ExtTypeSuffix(ext_type ExtType) []byte { + ret := make([]byte, 12) + copy(ret[0:4], extension_suffix) + binary.BigEndian.PutUint64(ret[4:], uint64(ext_type)) + return ret } -// Write a node to the database -func WriteNode(ctx *Context, node *Node) error { - ctx.Log.Logf("db", "DB_WRITE: %s", node.ID) +func WriteNodeExtList(ctx *Context, node *Node) error { + ext_list := make([]ExtType, len(node.Extensions)) + i := 0 + for ext_type := range(node.Extensions) { + ext_list[i] = ext_type + i += 1 + } - node_serialized, err := SerializeAny(ctx, node) + id_bytes, err := node.ID.MarshalBinary() if err != nil { return err } - chunks, err := node_serialized.Chunks() + + ext_list_serialized, err := SerializeAny(ctx, ext_list) if err != nil { return err } - ctx.Log.Logf("db_data", "DB_DATA: %+v", chunks.Slice()) + return ctx.DB.Update(func(txn *badger.Txn) error { + return txn.Set(append(id_bytes, extension_suffix...), ext_list_serialized.Data) + }) +} - id_bytes, err := node.ID.MarshalBinary() +func WriteNodeChanges(ctx *Context, node *Node, changes Changes) error { + ctx.Log.Logf("db", "Writing changes for %s - %+v", node.ID, changes) + + ext_serialized := map[ExtType]SerializedValue{} + for ext_type := range(changes) { + ext, ext_exists := node.Extensions[ext_type] + if ext_exists == false { + ctx.Log.Logf("db", "extension 0x%x does not exist for %s", ext_type, node.ID) + } else { + serialized_ext, err := SerializeAny(ctx, ext) + if err != nil { + return err + } + ext_serialized[ext_type] = serialized_ext + ctx.Log.Logf("db", "extension 0x%x - %+v - %+v", ext_type, serialized_ext.TypeStack, serialized_ext.Data) + } + } + + var sq_serialized *SerializedValue = nil + if node.writeSignalQueue == true { + node.writeSignalQueue = false + ser, err := SerializeAny(ctx, node.SignalQueue) + if err != nil { + return err + } + sq_serialized = &ser + } + + node_serialized, err := SerializeAny(ctx, node) if err != nil { return err } - ctx.Log.Logf("db", "DB_WRITE_ID: %+v", id_bytes) + id_bytes, err := node.ID.MarshalBinary() return ctx.DB.Update(func(txn *badger.Txn) error { - return txn.Set(id_bytes, chunks.Slice()) + err := txn.Set(id_bytes, node_serialized.Data) + if err != nil { + return err + } + if sq_serialized != nil { + err := txn.Set(append(id_bytes, signal_queue_suffix...), sq_serialized.Data) + if err != nil { + return err + } + } + for ext_type, data := range(ext_serialized) { + err := txn.Set(append(id_bytes, ExtTypeSuffix(ext_type)...), data.Data) + if err != nil { + return err + } + } + return nil }) } -func LoadNode(ctx * Context, id NodeID) (*Node, error) { +func LoadNode(ctx *Context, id NodeID) (*Node, error) { ctx.Log.Logf("db", "LOADING_NODE: %s", id) - var bytes []byte + var node_bytes []byte = nil + var sq_bytes []byte = nil + var ext_bytes = map[ExtType][]byte{} + err := ctx.DB.View(func(txn *badger.Txn) error { id_bytes, err := id.MarshalBinary() if err != nil { return err } - ctx.Log.Logf("db", "DB_READ_ID: %+v", id_bytes) - item, err := txn.Get(id_bytes) + + node_item, err := txn.Get(id_bytes) + if err != nil { + ctx.Log.Logf("db", "node key not found") + return err + } + + node_bytes, err = node_item.ValueCopy(nil) if err != nil { return err } - return item.Value(func(val []byte) error { - bytes = append([]byte{}, val...) - return nil - }) + sq_item, err := txn.Get(append(id_bytes, signal_queue_suffix...)) + if err != nil { + ctx.Log.Logf("db", "sq key not found") + return err + } + sq_bytes, err = sq_item.ValueCopy(nil) + if err != nil { + return err + } + + ext_list_item, err := txn.Get(append(id_bytes, extension_suffix...)) + if err != nil { + ctx.Log.Logf("db", "ext_list key not found") + return err + } + + ext_list_bytes, err := ext_list_item.ValueCopy(nil) + if err != nil { + return err + } + + ext_list_value, remaining, err := DeserializeValue(ctx, reflect.TypeOf([]ExtType{}), ext_list_bytes) + if err != nil { + return err + } else if len(remaining) > 0 { + return fmt.Errorf("Data remaining after ext_list deserialize %d", len(remaining)) + } + ext_list, ok := ext_list_value.Interface().([]ExtType) + if ok == false { + return fmt.Errorf("deserialize returned wrong type %s", ext_list_value.Type()) + } + + for _, ext_type := range(ext_list) { + ext_item, err := txn.Get(append(id_bytes, ExtTypeSuffix(ext_type)...)) + if err != nil { + ctx.Log.Logf("db", "ext %s key not found", ext_type) + return err + } + + ext_bytes[ext_type], err = ext_item.ValueCopy(nil) + if err != nil { + return err + } + } + return nil }) - if errors.Is(err, badger.ErrKeyNotFound) { - return nil, NodeNotFoundError - }else if err != nil { + if err != nil { return nil, err } - value, remaining, err := ParseSerializedValue(bytes) + node_value, remaining, err := DeserializeValue(ctx, reflect.TypeOf((*Node)(nil)), node_bytes) if err != nil { return nil, err } else if len(remaining) != 0 { - return nil, fmt.Errorf("%d bytes left after parsing node from DB", len(remaining)) + return nil, fmt.Errorf("data left after deserializing node %d", len(remaining)) } - node_type, remaining_types, err := DeserializeType(ctx, value.TypeStack) - if err != nil { - return nil, err - } else if len(remaining_types) != 0 { - return nil, fmt.Errorf("%d entries left in typestack after deserializing *Node", len(remaining_types)) + + node, node_ok := node_value.Interface().(*Node) + if node_ok == false { + return nil, fmt.Errorf("node wrong type %s", node_value.Type()) } - node_val, remaining_data, err := DeserializeValue(ctx, node_type, value.Data) + signal_queue_value, remaining, err := DeserializeValue(ctx, reflect.TypeOf([]QueuedSignal{}), sq_bytes) if err != nil { return nil, err - } else if len(remaining_data) != 0 { - return nil, fmt.Errorf("%d bytes left after desrializing *Node", len(remaining_data)) + } else if len(remaining) != 0 { + return nil, fmt.Errorf("data left after deserializing signal_queue %d", len(remaining)) } - node, ok := node_val.Interface().(*Node) - if ok == false { - return nil, fmt.Errorf("Deserialized %+v when expecting *Node", node_val.Type()) + signal_queue, sq_ok := signal_queue_value.Interface().([]QueuedSignal) + if sq_ok == false { + return nil, fmt.Errorf("signal queue wrong type %s", signal_queue_value.Type()) } - for ext_type, ext := range(node.Extensions){ - ctx.Log.Logf("serialize", "Deserialized extension: %+v - %+v", ext_type, ext) + for ext_type, data := range(ext_bytes) { + ext_info, exists := ctx.Extensions[ext_type] + if exists == false { + return nil, fmt.Errorf("0x%0x is not a known extension type", ext_type) + } + + ext_value, remaining, err := DeserializeValue(ctx, ext_info.Type, data) + if err != nil { + return nil, err + } else if len(remaining) > 0 { + return nil, fmt.Errorf("data left after deserializing ext(0x%x) %d", ext_type, len(remaining)) + } + ext, ext_ok := ext_value.Interface().(Extension) + if ext_ok == false { + return nil, fmt.Errorf("extension wrong type %s", ext_value.Type()) + } + + node.Extensions[ext_type] = ext } + node.SignalQueue = signal_queue + node.NextSignal, node.TimeoutChan = SoonestSignal(signal_queue) + ctx.AddNode(id, node) - ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id) + ctx.Log.Logf("db", "loaded %s", id) go runNode(ctx, node) return node, nil diff --git a/node_test.go b/node_test.go index 69087d5..3fa61c8 100644 --- a/node_test.go +++ b/node_test.go @@ -9,7 +9,7 @@ import ( ) func TestNodeDB(t *testing.T) { - ctx := logTestContext(t, []string{"signal", "node", "db", "listener"}) + ctx := logTestContext(t, []string{"signal", "serialize", "node", "db", "listener"}) node_type := NewNodeType("test") err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) fatalErr(t, err)