graphvent/db.go

285 lines
7.9 KiB
Go

package graphvent
import (
"encoding/binary"
"fmt"
"reflect"
"sync"
badger "github.com/dgraph-io/badger/v3"
)
type Database interface {
WriteNodeInit(*Context, *Node) error
WriteNodeChanges(*Context, *Node, map[ExtType]Changes) error
LoadNode(*Context, NodeID) (*Node, error)
}
const WRITE_BUFFER_SIZE = 1000000
type BadgerDB struct {
*badger.DB
sync.Mutex
buffer [WRITE_BUFFER_SIZE]byte
}
func (db *BadgerDB) WriteNodeInit(ctx *Context, node *Node) error {
if node == nil {
return fmt.Errorf("Cannot serialize nil *Node")
}
return db.Update(func(tx *badger.Txn) error {
db.Lock()
defer db.Unlock()
// Get the base key bytes
id_ser, err := node.ID.MarshalBinary()
if err != nil {
return err
}
cur := 0
// Write Node value
written, err := Serialize(ctx, node, db.buffer[cur:])
if err != nil {
return err
}
err = tx.Set(id_ser, db.buffer[cur:cur+written])
if err != nil {
return err
}
cur += written
// Write empty signal queue
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
written, err = Serialize(ctx, node.SignalQueue, db.buffer[cur:])
if err != nil {
return err
}
err = tx.Set(sigqueue_id, db.buffer[cur:cur+written])
if err != nil {
return err
}
cur += written
// Write node extension list
ext_list := []ExtType{}
for ext_type := range(node.Extensions) {
ext_list = append(ext_list, ext_type)
}
written, err = Serialize(ctx, ext_list, db.buffer[cur:])
if err != nil {
return err
}
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
err = tx.Set(ext_list_id, db.buffer[cur:cur+written])
if err != nil {
return err
}
cur += written
// For each extension:
for ext_type, ext := range(node.Extensions) {
ext_info, exists := ctx.Extensions[ext_type]
if exists == false {
return fmt.Errorf("Cannot serialize node with unknown extension %s", reflect.TypeOf(ext))
}
ext_value := reflect.ValueOf(ext).Elem()
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
// Write each field to a seperate key
for field_tag, field_info := range(ext_info.Fields) {
field_value := ext_value.FieldByIndex(field_info.Index)
field_id := make([]byte, len(ext_id) + 8)
tmp := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(field_tag))))
copy(field_id, tmp)
written, err := SerializeValue(ctx, field_value, db.buffer[cur:])
if err != nil {
return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err)
}
err = tx.Set(field_id, db.buffer[cur:cur+written])
if err != nil {
return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err)
}
cur += written
}
}
return nil
})
}
func (db *BadgerDB) WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error {
return db.Update(func(tx *badger.Txn) error {
db.Lock()
defer db.Unlock()
// Get the base key bytes
id_bytes := ([16]byte)(node.ID)
cur := 0
// Write the signal queue if it needs to be written
if node.writeSignalQueue {
node.writeSignalQueue = false
sigqueue_id := append(id_bytes[:], []byte(" - SIGQUEUE")...)
written, err := Serialize(ctx, node.SignalQueue, db.buffer[cur:])
if err != nil {
return fmt.Errorf("SignalQueue Serialize Error: %+v, %w", node.SignalQueue, err)
}
err = tx.Set(sigqueue_id, db.buffer[cur:cur+written])
if err != nil {
return fmt.Errorf("SignalQueue set error: %+v, %w", node.SignalQueue, err)
}
cur += written
}
// For each ext in changes
for ext_type, changes := range(changes) {
ext_info, exists := ctx.Extensions[ext_type]
if exists == false {
return fmt.Errorf("%s is not an extension in ctx", ext_type)
}
ext, exists := node.Extensions[ext_type]
if exists == false {
return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID)
}
ext_id := binary.BigEndian.AppendUint64(id_bytes[:], uint64(ext_type))
ext_value := reflect.ValueOf(ext)
// Write each field
for _, tag := range(changes) {
field_info, exists := ext_info.Fields[tag]
if exists == false {
return fmt.Errorf("Cannot serialize field %s of extension %s, does not exist", tag, ext_type)
}
field_value := ext_value.FieldByIndex(field_info.Index)
field_id := make([]byte, len(ext_id) + 8)
tmp := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(tag))))
copy(field_id, tmp)
written, err := SerializeValue(ctx, field_value, db.buffer[cur:])
if err != nil {
return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err)
}
err = tx.Set(field_id, db.buffer[cur:cur+written])
if err != nil {
return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err)
}
cur += written
}
}
return nil
})
}
func (db *BadgerDB) LoadNode(ctx *Context, id NodeID) (*Node, error) {
var node *Node = nil
err := db.View(func(tx *badger.Txn) error {
// Get the base key bytes
id_ser, err := id.MarshalBinary()
if err != nil {
return fmt.Errorf("Failed to serialize node_id: %w", err)
}
// Get the node value
node_item, err := tx.Get(id_ser)
if err != nil {
return fmt.Errorf("Failed to get node_item: %w", NodeNotFoundError)
}
err = node_item.Value(func(val []byte) error {
ctx.Log.Logf("db", "DESERIALIZE_NODE(%d bytes): %+v", len(val), val)
node, err = Deserialize[*Node](ctx, val)
return err
})
if err != nil {
return fmt.Errorf("Failed to deserialize Node %s - %w", id, err)
}
// Get the signal queue
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
sigqueue_item, err := tx.Get(sigqueue_id)
if err != nil {
return fmt.Errorf("Failed to get sigqueue_id: %w", err)
}
err = sigqueue_item.Value(func(val []byte) error {
node.SignalQueue, err = Deserialize[[]QueuedSignal](ctx, val)
return err
})
if err != nil {
return fmt.Errorf("Failed to deserialize []QueuedSignal for %s: %w", id, err)
}
// Get the extension list
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
ext_list_item, err := tx.Get(ext_list_id)
if err != nil {
return err
}
var ext_list []ExtType
ext_list_item.Value(func(val []byte) error {
ext_list, err = Deserialize[[]ExtType](ctx, val)
return err
})
// Get the extensions
for _, ext_type := range(ext_list) {
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
ext_info, exists := ctx.Extensions[ext_type]
if exists == false {
return fmt.Errorf("Extension %s not in context", ext_type)
}
ext := reflect.New(ext_info.Type)
for field_tag, field_info := range(ext_info.Fields) {
field_id := binary.BigEndian.AppendUint64(ext_id, uint64(GetFieldTag(string(field_tag))))
field_item, err := tx.Get(field_id)
if err != nil {
return fmt.Errorf("Failed to find key for %s:%s(%x) - %w", ext_type, field_tag, field_id, err)
}
err = field_item.Value(func(val []byte) error {
value, _, err := DeserializeValue(ctx, val, field_info.Type)
if err != nil {
return err
}
ext.Elem().FieldByIndex(field_info.Index).Set(value)
return nil
})
if err != nil {
return err
}
}
node.Extensions[ext_type] = ext.Interface().(Extension)
}
return nil
})
if err != nil {
return nil, err
} else if node == nil {
return nil, fmt.Errorf("Tried to return nil *Node from BadgerDB.LoadNode without error")
}
return node, nil
}