Moved db from badger to an interface

master
noah metz 2024-03-30 14:42:06 -07:00
parent 66d5e3f260
commit b2d84b2453
9 changed files with 94 additions and 57 deletions

@ -17,8 +17,6 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/graphql-go/graphql" "github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/ast" "github.com/graphql-go/graphql/language/ast"
badger "github.com/dgraph-io/badger/v3"
) )
var ( var (
@ -82,7 +80,7 @@ type InterfaceInfo struct {
type Context struct { type Context struct {
// DB is the database connection used to load and write nodes // DB is the database connection used to load and write nodes
DB * badger.DB DB Database
// Logging interface // Logging interface
Log Logger Log Logger
@ -870,7 +868,7 @@ func (ctx *Context) Stop() {
} }
func (ctx *Context) Load(id NodeID) (*Node, error) { func (ctx *Context) Load(id NodeID) (*Node, error) {
node, err := LoadNode(ctx, id) node, err := ctx.DB.LoadNode(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -973,7 +971,7 @@ func (ctx *Context)GQLResolve(t reflect.Type, node_type string) (func(interface{
} }
// Create a new Context with the base library content added // Create a new Context with the base library content added
func NewContext(db * badger.DB, log Logger) (*Context, error) { func NewContext(db Database, log Logger) (*Context, error) {
uuid.EnableRandPool() uuid.EnableRandPool()
ctx := &Context{ ctx := &Context{

90
db.go

@ -3,100 +3,127 @@ package graphvent
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"reflect" "reflect"
"sync"
badger "github.com/dgraph-io/badger/v3" badger "github.com/dgraph-io/badger/v3"
) )
const NODE_BUFFER_SIZE = 1000000 type Database interface {
WriteNodeInit(*Context, *Node) error
WriteNodeChanges(*Context, *Node, map[ExtType]Changes) error
LoadNode(*Context, NodeID) (*Node, error)
}
const WRITE_BUFFER_SIZE = 1000000
type BadgerDB struct {
*badger.DB
sync.Mutex
buffer [WRITE_BUFFER_SIZE]byte
}
func WriteNodeInit(ctx *Context, node *Node) error { func (db *BadgerDB) WriteNodeInit(ctx *Context, node *Node) error {
if node == nil { if node == nil {
return fmt.Errorf("Cannot serialize nil *Node") return fmt.Errorf("Cannot serialize nil *Node")
} }
buffer := [NODE_BUFFER_SIZE]byte{} return db.Update(func(tx *badger.Txn) error {
db.Lock()
defer db.Unlock()
return ctx.DB.Update(func(tx *badger.Txn) error {
// Get the base key bytes // Get the base key bytes
id_ser, err := node.ID.MarshalBinary() id_ser, err := node.ID.MarshalBinary()
if err != nil { if err != nil {
return err return err
} }
cur := 0
// Write Node value // Write Node value
written, err := Serialize(ctx, node, buffer[:]) written, err := Serialize(ctx, node, db.buffer[cur:])
if err != nil { if err != nil {
return err return err
} }
err = tx.Set(id_ser, buffer[:written])
err = tx.Set(id_ser, db.buffer[cur:cur+written])
if err != nil { if err != nil {
return err return err
} }
cur += written
// Write empty signal queue // Write empty signal queue
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...) sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
written, err = Serialize(ctx, node.SignalQueue, buffer[:]) written, err = Serialize(ctx, node.SignalQueue, db.buffer[cur:])
if err != nil { if err != nil {
return err return err
} }
err = tx.Set(sigqueue_id, buffer[:written])
err = tx.Set(sigqueue_id, db.buffer[cur:cur+written])
if err != nil { if err != nil {
return err return err
} }
cur += written
// Write node extension list // Write node extension list
ext_list := []ExtType{} ext_list := []ExtType{}
for ext_type := range(node.Extensions) { for ext_type := range(node.Extensions) {
ext_list = append(ext_list, ext_type) ext_list = append(ext_list, ext_type)
} }
written, err = Serialize(ctx, ext_list, buffer[:]) written, err = Serialize(ctx, ext_list, db.buffer[cur:])
if err != nil { if err != nil {
return err return err
} }
ext_list_id := append(id_ser, []byte(" - EXTLIST")...) ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
err = tx.Set(ext_list_id, buffer[:written]) err = tx.Set(ext_list_id, db.buffer[cur:cur+written])
if err != nil { if err != nil {
return err return err
} }
cur += written
// For each extension: // For each extension:
for ext_type, ext := range(node.Extensions) { for ext_type, ext := range(node.Extensions) {
// Write each extension's current value // Write each extension's current value
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type)) ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
written, err := SerializeValue(ctx, reflect.ValueOf(ext).Elem(), buffer[:]) written, err := SerializeValue(ctx, reflect.ValueOf(ext).Elem(), db.buffer[cur:])
if err != nil { if err != nil {
return err return err
} }
err = tx.Set(ext_id, buffer[:written]) err = tx.Set(ext_id, db.buffer[cur:cur+written])
if err != nil {
return err
}
cur += written
} }
return nil return nil
}) })
} }
func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error { func (db *BadgerDB) WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error {
buffer := [NODE_BUFFER_SIZE]byte{} return db.Update(func(tx *badger.Txn) error {
db.Lock()
defer db.Unlock()
return ctx.DB.Update(func(tx *badger.Txn) error {
// Get the base key bytes // Get the base key bytes
id_ser, err := node.ID.MarshalBinary() id_bytes := ([16]byte)(node.ID)
if err != nil {
return fmt.Errorf("Marshal ID error: %+w", err) cur := 0
}
// Write the signal queue if it needs to be written // Write the signal queue if it needs to be written
if node.writeSignalQueue { if node.writeSignalQueue {
node.writeSignalQueue = false node.writeSignalQueue = false
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...) sigqueue_id := append(id_bytes[:], []byte(" - SIGQUEUE")...)
written, err := Serialize(ctx, node.SignalQueue, buffer[:]) written, err := Serialize(ctx, node.SignalQueue, db.buffer[cur:])
if err != nil { if err != nil {
return fmt.Errorf("SignalQueue Serialize Error: %+v, %w", node.SignalQueue, err) return fmt.Errorf("SignalQueue Serialize Error: %+v, %w", node.SignalQueue, err)
} }
err = tx.Set(sigqueue_id, buffer[:written]) err = tx.Set(sigqueue_id, db.buffer[cur:cur+written])
if err != nil { if err != nil {
return fmt.Errorf("SignalQueue set error: %+v, %w", node.SignalQueue, err) return fmt.Errorf("SignalQueue set error: %+v, %w", node.SignalQueue, err)
} }
cur += written
} }
// For each ext in changes // For each ext in changes
@ -106,24 +133,26 @@ func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) err
if exists == false { if exists == false {
return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID) return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID)
} }
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type)) ext_id := binary.BigEndian.AppendUint64(id_bytes[:], uint64(ext_type))
written, err := SerializeValue(ctx, reflect.ValueOf(ext).Elem(), buffer[:]) written, err := SerializeValue(ctx, reflect.ValueOf(ext).Elem(), db.buffer[cur:])
if err != nil { if err != nil {
return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err) return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err)
} }
err = tx.Set(ext_id, buffer[:written]) err = tx.Set(ext_id, db.buffer[cur:cur+written])
if err != nil { if err != nil {
return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err) return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err)
} }
cur += written
} }
return nil return nil
}) })
} }
func LoadNode(ctx *Context, id NodeID) (*Node, error) { func (db *BadgerDB) LoadNode(ctx *Context, id NodeID) (*Node, error) {
var node *Node = nil var node *Node = nil
err := ctx.DB.View(func(tx *badger.Txn) error {
err := db.View(func(tx *badger.Txn) error {
// Get the base key bytes // Get the base key bytes
id_ser, err := id.MarshalBinary() id_ser, err := id.MarshalBinary()
if err != nil { if err != nil {
@ -137,12 +166,13 @@ func LoadNode(ctx *Context, id NodeID) (*Node, error) {
} }
err = node_item.Value(func(val []byte) error { err = node_item.Value(func(val []byte) error {
ctx.Log.Logf("db", "DESERIALIZE_NODE(%d bytes): %+v", len(val), val)
node, err = Deserialize[*Node](ctx, val) node, err = Deserialize[*Node](ctx, val)
return err return err
}) })
if err != nil { if err != nil {
return nil return fmt.Errorf("Failed to deserialize Node %s - %w", id, err)
} }
// Get the signal queue // Get the signal queue
@ -211,6 +241,8 @@ func LoadNode(ctx *Context, id NodeID) (*Node, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} else if node == nil {
return nil, fmt.Errorf("Tried to return nil *Node from BadgerDB.LoadNode without error")
} }
return node, nil return node, nil

@ -202,7 +202,7 @@ func TestGQLQuery(t *testing.T) {
} }
func TestGQLDB(t *testing.T) { func TestGQLDB(t *testing.T) {
ctx := logTestContext(t, []string{"test", "db", "node"}) ctx := logTestContext(t, []string{"test", "db", "node", "serialize"})
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
fatalErr(t, err) fatalErr(t, err)

@ -26,7 +26,9 @@ func logTestContext(t * testing.T, components []string) *Context {
t.Fatal(err) t.Fatal(err)
} }
ctx, err := NewContext(db, NewConsoleLogger(components)) ctx, err := NewContext(&BadgerDB{
DB: db,
}, NewConsoleLogger(components))
fatalErr(t, err) fatalErr(t, err)
return ctx return ctx

@ -143,8 +143,7 @@ func (ext *LockableExt) HandleUnlockSignal(ctx *Context, node *Node, source Node
ext.PendingOwner = nil ext.PendingOwner = nil
ext.ReqID = new(uuid.UUID) ext.ReqID = &signal.Id
*ext.ReqID = signal.Id
ext.State = Unlocking ext.State = Unlocking
for id := range(ext.Requirements) { for id := range(ext.Requirements) {
@ -175,22 +174,18 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID
if len(ext.Requirements) == 0 { if len(ext.Requirements) == 0 {
changes = append(changes, "state", "owner", "pending_owner") changes = append(changes, "state", "owner", "pending_owner")
ext.Owner = new(NodeID) ext.Owner = &source
*ext.Owner = source
ext.PendingOwner = new(NodeID) ext.PendingOwner = &source
*ext.PendingOwner = source
ext.State = Locked ext.State = Locked
messages = append(messages, SendMsg{source, NewSuccessSignal(signal.Id)}) messages = append(messages, SendMsg{source, NewSuccessSignal(signal.Id)})
} else { } else {
changes = append(changes, "state", "requirements", "waiting", "pending_owner") changes = append(changes, "state", "requirements", "waiting", "pending_owner")
ext.PendingOwner = new(NodeID) ext.PendingOwner = &source
*ext.PendingOwner = source
ext.ReqID = new(uuid.UUID) ext.ReqID = &signal.Id
*ext.ReqID = signal.Id
ext.State = Locking ext.State = Locking
for id := range(ext.Requirements) { for id := range(ext.Requirements) {
@ -313,8 +308,7 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod
changes = append(changes, "state", "owner", "req_id") changes = append(changes, "state", "owner", "req_id")
ext.State = Locked ext.State = Locked
ext.Owner = new(NodeID) ext.Owner = ext.PendingOwner
*ext.Owner = *ext.PendingOwner
messages = append(messages, SendMsg{*ext.Owner, NewSuccessSignal(*ext.ReqID)}) messages = append(messages, SendMsg{*ext.Owner, NewSuccessSignal(*ext.ReqID)})
ext.ReqID = nil ext.ReqID = nil

@ -42,7 +42,19 @@ func TestLink(t *testing.T) {
fatalErr(t, err) fatalErr(t, err)
} }
func Test10Lock(t *testing.T) {
testLockN(t, 10)
}
func Test1000Lock(t *testing.T) {
testLockN(t, 1000)
}
func Test10000Lock(t *testing.T) { func Test10000Lock(t *testing.T) {
testLockN(t, 10000)
}
func testLockN(t *testing.T, n int) {
ctx := logTestContext(t, []string{"test"}) ctx := logTestContext(t, []string{"test"})
NewLockable := func()(*Node) { NewLockable := func()(*Node) {
@ -51,12 +63,12 @@ func Test10000Lock(t *testing.T) {
return l return l
} }
reqs := make([]NodeID, 10000) reqs := make([]NodeID, n)
for i := range(reqs) { for i := range(reqs) {
new_lockable := NewLockable() new_lockable := NewLockable()
reqs[i] = new_lockable.ID reqs[i] = new_lockable.ID
} }
ctx.Log.Logf("test", "CREATED_10000") ctx.Log.Logf("test", "CREATED_%d", n)
listener := NewListenerExt(50000) listener := NewListenerExt(50000)
node, err := NewNode(ctx, nil, "LockableNode", 500000, listener, NewLockableExt(reqs)) node, err := NewNode(ctx, nil, "LockableNode", 500000, listener, NewLockableExt(reqs))
@ -75,7 +87,7 @@ func Test10000Lock(t *testing.T) {
t.Fatalf("Unexpected response to lock - %s", resp) t.Fatalf("Unexpected response to lock - %s", resp)
} }
ctx.Log.Logf("test", "LOCKED_10000") ctx.Log.Logf("test", "LOCKED_%d", n)
} }
func TestLock(t *testing.T) { func TestLock(t *testing.T) {

@ -50,7 +50,7 @@ func (logger * ConsoleLogger) SetComponents(components []string) error {
return false return false
} }
for c, _ := range(logger.loggers) { for c := range(logger.loggers) {
if component_enabled(c) == false { if component_enabled(c) == false {
delete(logger.loggers, c) delete(logger.loggers, c)
} }

@ -78,7 +78,7 @@ type Node struct {
Key ed25519.PrivateKey `gv:"key"` Key ed25519.PrivateKey `gv:"key"`
ID NodeID ID NodeID
Type NodeType `gv:"type"` Type NodeType `gv:"type"`
// TODO: move each extension to it's own db key, and extend changes to notify which extension was changed
Extensions map[ExtType]Extension Extensions map[ExtType]Extension
// Channel for this node to receive messages from the Context // Channel for this node to receive messages from the Context
@ -90,7 +90,6 @@ type Node struct {
Active atomic.Bool Active atomic.Bool
// TODO: enhance WriteNode to write SignalQueue to a different key, and use writeSignalQueue to decide whether or not to update it
writeSignalQueue bool writeSignalQueue bool
SignalQueue []QueuedSignal SignalQueue []QueuedSignal
NextSignal *QueuedSignal NextSignal *QueuedSignal
@ -344,7 +343,7 @@ func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error {
} }
if (len(changes) != 0) || node.writeSignalQueue { if (len(changes) != 0) || node.writeSignalQueue {
write_err := WriteNodeChanges(ctx, node, changes) write_err := ctx.DB.WriteNodeChanges(ctx, node, changes)
if write_err != nil { if write_err != nil {
return write_err return write_err
} }
@ -456,7 +455,7 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size
writeSignalQueue: false, writeSignalQueue: false,
} }
err = WriteNodeInit(ctx, node) err = ctx.DB.WriteNodeInit(ctx, node)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -194,7 +194,7 @@ func Deserialize[T any](ctx *Context, data []byte) (T, error) {
if err != nil { if err != nil {
return zero, err return zero, err
} else if len(left) != 0 { } else if len(left) != 0 {
return zero, fmt.Errorf("%d bytes left after deserializing %+v", len(left), value) return zero, fmt.Errorf("%d/%d bytes left after deserializing %+v", len(left), len(data), value)
} else if value.Type() != reflect_type { } else if value.Type() != reflect_type {
return zero, fmt.Errorf("Deserialized type %s does not match %s", value.Type(), reflect_type) return zero, fmt.Errorf("Deserialized type %s does not match %s", value.Type(), reflect_type)
} }