From b2d84b24530bb5a8228d7e157d863161337d386c Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sat, 30 Mar 2024 14:42:06 -0700 Subject: [PATCH] Moved db from badger to an interface --- context.go | 8 ++--- db.go | 90 ++++++++++++++++++++++++++++++++---------------- gql_test.go | 2 +- graph_test.go | 4 ++- lockable.go | 18 ++++------ lockable_test.go | 18 ++++++++-- log.go | 2 +- node.go | 7 ++-- serialize.go | 2 +- 9 files changed, 94 insertions(+), 57 deletions(-) diff --git a/context.go b/context.go index 86767c1..f4b6768 100644 --- a/context.go +++ b/context.go @@ -17,8 +17,6 @@ import ( "github.com/google/uuid" "github.com/graphql-go/graphql" "github.com/graphql-go/graphql/language/ast" - - badger "github.com/dgraph-io/badger/v3" ) var ( @@ -82,7 +80,7 @@ type InterfaceInfo struct { type Context struct { // DB is the database connection used to load and write nodes - DB * badger.DB + DB Database // Logging interface Log Logger @@ -870,7 +868,7 @@ func (ctx *Context) Stop() { } func (ctx *Context) Load(id NodeID) (*Node, error) { - node, err := LoadNode(ctx, id) + node, err := ctx.DB.LoadNode(ctx, id) if err != nil { return nil, err } @@ -973,7 +971,7 @@ func (ctx *Context)GQLResolve(t reflect.Type, node_type string) (func(interface{ } // Create a new Context with the base library content added -func NewContext(db * badger.DB, log Logger) (*Context, error) { +func NewContext(db Database, log Logger) (*Context, error) { uuid.EnableRandPool() ctx := &Context{ diff --git a/db.go b/db.go index 910e76d..c089d18 100644 --- a/db.go +++ b/db.go @@ -3,100 +3,127 @@ package graphvent import ( "encoding/binary" "fmt" - "reflect" + "reflect" + "sync" badger "github.com/dgraph-io/badger/v3" ) -const NODE_BUFFER_SIZE = 1000000 +type Database interface { + WriteNodeInit(*Context, *Node) error + WriteNodeChanges(*Context, *Node, map[ExtType]Changes) error + LoadNode(*Context, NodeID) (*Node, error) +} + +const WRITE_BUFFER_SIZE = 1000000 +type BadgerDB struct { + *badger.DB + sync.Mutex + buffer [WRITE_BUFFER_SIZE]byte +} -func WriteNodeInit(ctx *Context, node *Node) error { +func (db *BadgerDB) WriteNodeInit(ctx *Context, node *Node) error { if node == nil { return fmt.Errorf("Cannot serialize nil *Node") } - buffer := [NODE_BUFFER_SIZE]byte{} + return db.Update(func(tx *badger.Txn) error { + db.Lock() + defer db.Unlock() - return ctx.DB.Update(func(tx *badger.Txn) error { // Get the base key bytes id_ser, err := node.ID.MarshalBinary() if err != nil { return err } + cur := 0 + // Write Node value - written, err := Serialize(ctx, node, buffer[:]) + written, err := Serialize(ctx, node, db.buffer[cur:]) if err != nil { return err } - err = tx.Set(id_ser, buffer[:written]) + + err = tx.Set(id_ser, db.buffer[cur:cur+written]) if err != nil { return err } + + cur += written // Write empty signal queue sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...) - written, err = Serialize(ctx, node.SignalQueue, buffer[:]) + written, err = Serialize(ctx, node.SignalQueue, db.buffer[cur:]) if err != nil { return err } - err = tx.Set(sigqueue_id, buffer[:written]) + + err = tx.Set(sigqueue_id, db.buffer[cur:cur+written]) if err != nil { return err } + cur += written + // Write node extension list ext_list := []ExtType{} for ext_type := range(node.Extensions) { ext_list = append(ext_list, ext_type) } - written, err = Serialize(ctx, ext_list, buffer[:]) + written, err = Serialize(ctx, ext_list, db.buffer[cur:]) if err != nil { return err } ext_list_id := append(id_ser, []byte(" - EXTLIST")...) - err = tx.Set(ext_list_id, buffer[:written]) + err = tx.Set(ext_list_id, db.buffer[cur:cur+written]) if err != nil { return err } + cur += written // For each extension: for ext_type, ext := range(node.Extensions) { // Write each extension's current value ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type)) - written, err := SerializeValue(ctx, reflect.ValueOf(ext).Elem(), buffer[:]) + written, err := SerializeValue(ctx, reflect.ValueOf(ext).Elem(), db.buffer[cur:]) if err != nil { return err } - err = tx.Set(ext_id, buffer[:written]) + err = tx.Set(ext_id, db.buffer[cur:cur+written]) + if err != nil { + return err + } + cur += written } return nil }) } -func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error { - buffer := [NODE_BUFFER_SIZE]byte{} +func (db *BadgerDB) WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error { + return db.Update(func(tx *badger.Txn) error { + db.Lock() + defer db.Unlock() - return ctx.DB.Update(func(tx *badger.Txn) error { // Get the base key bytes - id_ser, err := node.ID.MarshalBinary() - if err != nil { - return fmt.Errorf("Marshal ID error: %+w", err) - } + id_bytes := ([16]byte)(node.ID) + + cur := 0 // Write the signal queue if it needs to be written if node.writeSignalQueue { node.writeSignalQueue = false - sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...) - written, err := Serialize(ctx, node.SignalQueue, buffer[:]) + sigqueue_id := append(id_bytes[:], []byte(" - SIGQUEUE")...) + written, err := Serialize(ctx, node.SignalQueue, db.buffer[cur:]) if err != nil { return fmt.Errorf("SignalQueue Serialize Error: %+v, %w", node.SignalQueue, err) } - err = tx.Set(sigqueue_id, buffer[:written]) + err = tx.Set(sigqueue_id, db.buffer[cur:cur+written]) if err != nil { return fmt.Errorf("SignalQueue set error: %+v, %w", node.SignalQueue, err) } + cur += written } // For each ext in changes @@ -106,24 +133,26 @@ func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) err if exists == false { return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID) } - ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type)) - written, err := SerializeValue(ctx, reflect.ValueOf(ext).Elem(), buffer[:]) + ext_id := binary.BigEndian.AppendUint64(id_bytes[:], uint64(ext_type)) + written, err := SerializeValue(ctx, reflect.ValueOf(ext).Elem(), db.buffer[cur:]) if err != nil { return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err) } - err = tx.Set(ext_id, buffer[:written]) + err = tx.Set(ext_id, db.buffer[cur:cur+written]) if err != nil { return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err) } + cur += written } return nil }) } -func LoadNode(ctx *Context, id NodeID) (*Node, error) { +func (db *BadgerDB) LoadNode(ctx *Context, id NodeID) (*Node, error) { var node *Node = nil - err := ctx.DB.View(func(tx *badger.Txn) error { + + err := db.View(func(tx *badger.Txn) error { // Get the base key bytes id_ser, err := id.MarshalBinary() if err != nil { @@ -137,12 +166,13 @@ func LoadNode(ctx *Context, id NodeID) (*Node, error) { } err = node_item.Value(func(val []byte) error { + ctx.Log.Logf("db", "DESERIALIZE_NODE(%d bytes): %+v", len(val), val) node, err = Deserialize[*Node](ctx, val) return err }) if err != nil { - return nil + return fmt.Errorf("Failed to deserialize Node %s - %w", id, err) } // Get the signal queue @@ -211,6 +241,8 @@ func LoadNode(ctx *Context, id NodeID) (*Node, error) { if err != nil { return nil, err + } else if node == nil { + return nil, fmt.Errorf("Tried to return nil *Node from BadgerDB.LoadNode without error") } return node, nil diff --git a/gql_test.go b/gql_test.go index b9f1de7..5e52da4 100644 --- a/gql_test.go +++ b/gql_test.go @@ -202,7 +202,7 @@ func TestGQLQuery(t *testing.T) { } func TestGQLDB(t *testing.T) { - ctx := logTestContext(t, []string{"test", "db", "node"}) + ctx := logTestContext(t, []string{"test", "db", "node", "serialize"}) gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) fatalErr(t, err) diff --git a/graph_test.go b/graph_test.go index 5e8b6a1..f0e70e8 100644 --- a/graph_test.go +++ b/graph_test.go @@ -26,7 +26,9 @@ func logTestContext(t * testing.T, components []string) *Context { t.Fatal(err) } - ctx, err := NewContext(db, NewConsoleLogger(components)) + ctx, err := NewContext(&BadgerDB{ + DB: db, + }, NewConsoleLogger(components)) fatalErr(t, err) return ctx diff --git a/lockable.go b/lockable.go index 7f59f77..80c12d3 100644 --- a/lockable.go +++ b/lockable.go @@ -143,8 +143,7 @@ func (ext *LockableExt) HandleUnlockSignal(ctx *Context, node *Node, source Node ext.PendingOwner = nil - ext.ReqID = new(uuid.UUID) - *ext.ReqID = signal.Id + ext.ReqID = &signal.Id ext.State = Unlocking for id := range(ext.Requirements) { @@ -175,22 +174,18 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID if len(ext.Requirements) == 0 { changes = append(changes, "state", "owner", "pending_owner") - ext.Owner = new(NodeID) - *ext.Owner = source + ext.Owner = &source - ext.PendingOwner = new(NodeID) - *ext.PendingOwner = source + ext.PendingOwner = &source ext.State = Locked messages = append(messages, SendMsg{source, NewSuccessSignal(signal.Id)}) } else { changes = append(changes, "state", "requirements", "waiting", "pending_owner") - ext.PendingOwner = new(NodeID) - *ext.PendingOwner = source + ext.PendingOwner = &source - ext.ReqID = new(uuid.UUID) - *ext.ReqID = signal.Id + ext.ReqID = &signal.Id ext.State = Locking for id := range(ext.Requirements) { @@ -313,8 +308,7 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod changes = append(changes, "state", "owner", "req_id") ext.State = Locked - ext.Owner = new(NodeID) - *ext.Owner = *ext.PendingOwner + ext.Owner = ext.PendingOwner messages = append(messages, SendMsg{*ext.Owner, NewSuccessSignal(*ext.ReqID)}) ext.ReqID = nil diff --git a/lockable_test.go b/lockable_test.go index 32d0c8e..e495246 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -42,7 +42,19 @@ func TestLink(t *testing.T) { fatalErr(t, err) } +func Test10Lock(t *testing.T) { + testLockN(t, 10) +} + +func Test1000Lock(t *testing.T) { + testLockN(t, 1000) +} + func Test10000Lock(t *testing.T) { + testLockN(t, 10000) +} + +func testLockN(t *testing.T, n int) { ctx := logTestContext(t, []string{"test"}) NewLockable := func()(*Node) { @@ -51,12 +63,12 @@ func Test10000Lock(t *testing.T) { return l } - reqs := make([]NodeID, 10000) + reqs := make([]NodeID, n) for i := range(reqs) { new_lockable := NewLockable() reqs[i] = new_lockable.ID } - ctx.Log.Logf("test", "CREATED_10000") + ctx.Log.Logf("test", "CREATED_%d", n) listener := NewListenerExt(50000) node, err := NewNode(ctx, nil, "LockableNode", 500000, listener, NewLockableExt(reqs)) @@ -75,7 +87,7 @@ func Test10000Lock(t *testing.T) { t.Fatalf("Unexpected response to lock - %s", resp) } - ctx.Log.Logf("test", "LOCKED_10000") + ctx.Log.Logf("test", "LOCKED_%d", n) } func TestLock(t *testing.T) { diff --git a/log.go b/log.go index 034a760..60d09e3 100644 --- a/log.go +++ b/log.go @@ -50,7 +50,7 @@ func (logger * ConsoleLogger) SetComponents(components []string) error { return false } - for c, _ := range(logger.loggers) { + for c := range(logger.loggers) { if component_enabled(c) == false { delete(logger.loggers, c) } diff --git a/node.go b/node.go index adfa72f..2227ff5 100644 --- a/node.go +++ b/node.go @@ -78,7 +78,7 @@ type Node struct { Key ed25519.PrivateKey `gv:"key"` ID NodeID Type NodeType `gv:"type"` - // TODO: move each extension to it's own db key, and extend changes to notify which extension was changed + Extensions map[ExtType]Extension // Channel for this node to receive messages from the Context @@ -90,7 +90,6 @@ type Node struct { Active atomic.Bool - // TODO: enhance WriteNode to write SignalQueue to a different key, and use writeSignalQueue to decide whether or not to update it writeSignalQueue bool SignalQueue []QueuedSignal NextSignal *QueuedSignal @@ -344,7 +343,7 @@ func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error { } if (len(changes) != 0) || node.writeSignalQueue { - write_err := WriteNodeChanges(ctx, node, changes) + write_err := ctx.DB.WriteNodeChanges(ctx, node, changes) if write_err != nil { return write_err } @@ -456,7 +455,7 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size writeSignalQueue: false, } - err = WriteNodeInit(ctx, node) + err = ctx.DB.WriteNodeInit(ctx, node) if err != nil { return nil, err } diff --git a/serialize.go b/serialize.go index ece705d..4973261 100644 --- a/serialize.go +++ b/serialize.go @@ -194,7 +194,7 @@ func Deserialize[T any](ctx *Context, data []byte) (T, error) { if err != nil { return zero, err } else if len(left) != 0 { - return zero, fmt.Errorf("%d bytes left after deserializing %+v", len(left), value) + return zero, fmt.Errorf("%d/%d bytes left after deserializing %+v", len(left), len(data), value) } else if value.Type() != reflect_type { return zero, fmt.Errorf("Deserialized type %s does not match %s", value.Type(), reflect_type) }