graphvent/node.go

1065 lines
29 KiB
Go

2023-07-09 14:30:30 -06:00
package graphvent
import (
"time"
"errors"
2023-07-25 21:43:15 -06:00
"reflect"
2023-07-09 14:30:30 -06:00
"github.com/google/uuid"
badger "github.com/dgraph-io/badger/v3"
"fmt"
"encoding/binary"
2023-07-20 23:19:10 -06:00
"encoding/json"
"sync/atomic"
"crypto"
2023-08-06 12:47:47 -06:00
"crypto/ed25519"
"crypto/sha512"
"crypto/rand"
"crypto/x509"
2023-07-09 14:30:30 -06:00
)
// Magic value and fixed header sizes of the serialized node database format.
const (
	// Magic first four bytes of serialized DB content, stored big endian
	NODE_DB_MAGIC = 0x2491df14

	// Total length of the node database header, has magic to verify and type_hash to map to load function
	NODE_DB_HEADER_LEN = 32

	// Length of the header written in front of each serialized extension
	EXTENSION_DB_HEADER_LEN = 16

	// Length of the header written in front of each serialized queued signal
	QSIGNAL_DB_HEADER_LEN = 24

	// Length of the header read in front of each serialized policy
	POLICY_DB_HEADER_LEN = 16
)
var (
	// ZeroUUID is the zero value of uuid.UUID.
	ZeroUUID uuid.UUID

	// ZeroID is the base NodeID, used as a special value.
	ZeroID = NodeID(ZeroUUID)
)
// NodeID uniquely identifies a Node; its underlying type is uuid.UUID.
type NodeID uuid.UUID
2023-07-30 10:09:04 -06:00
func (id *NodeID) MarshalJSON() ([]byte, error) {
str := ""
if id != nil {
str = id.String()
}
2023-07-24 17:07:27 -06:00
return json.Marshal(&str)
}
// UnmarshalJSON decodes a JSON string and parses it with ParseID into the receiver.
// A JSON decoding error leaves the receiver untouched; a ParseID error is returned
// after the receiver has been overwritten with ParseID's result.
func (id *NodeID) UnmarshalJSON(bytes []byte) error {
	var text string
	if err := json.Unmarshal(bytes, &text); err != nil {
		return err
	}

	parsed, err := ParseID(text)
	*id = parsed
	return err
}
2023-07-09 14:30:30 -06:00
// Serialize returns the bytes produced by uuid.UUID.MarshalBinary for this ID.
// The error result of MarshalBinary is deliberately discarded.
func (id NodeID) Serialize() []byte {
	as_uuid := uuid.UUID(id)
	raw, _ := as_uuid.MarshalBinary()
	return raw
}
// String returns the textual form of the ID, as produced by uuid.UUID.String.
func (id NodeID) String() string {
	as_uuid := uuid.UUID(id)
	return as_uuid.String()
}
// Create an ID from a fixed length byte array
2023-07-25 21:43:15 -06:00
// Ignore the error since we're enforcing 16 byte length at compile time
func IDFromBytes(bytes [16]byte) NodeID {
id, _ := uuid.FromBytes(bytes[:])
return NodeID(id)
}
// Parse an ID from a string
func ParseID(str string) (NodeID, error) {
id_uuid, err := uuid.Parse(str)
if err != nil {
return NodeID{}, err
}
return NodeID(id_uuid), nil
2023-07-09 14:30:30 -06:00
}
2023-07-10 22:31:43 -06:00
// Generate a random NodeID
2023-07-09 14:30:30 -06:00
func RandID() NodeID {
return NodeID(uuid.New())
2023-07-09 14:30:30 -06:00
}
// A Serializable has a type that can be used to map to it, and a function to serialize` the current state
2023-07-25 21:43:15 -06:00
type Serializable[I comparable] interface {
2023-08-01 20:55:15 -06:00
Serialize()([]byte,error)
Deserialize(*Context,[]byte)error
2023-07-25 21:43:15 -06:00
Type() I
2023-07-09 14:30:30 -06:00
}
// Extensions are data attached to nodes that process signals
2023-07-25 21:43:15 -06:00
type Extension interface {
Serializable[ExtType]
Field(string)interface{}
Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages
2023-07-20 23:19:10 -06:00
}
// A QueuedSignal is a Signal that has been Queued to trigger at a set time
type QueuedSignal struct {
2023-07-30 01:29:15 -06:00
Signal
time.Time
}
2023-08-10 23:43:10 -06:00
// PendingACL is the state nodeLoop keeps for a received signal whose ACL check
// returned Pending, keyed in Node.PendingACLs by the ID of that signal.
// Field order matters: the struct is also built with positional literals.
type PendingACL struct {
	// Number of policy messages that have not been answered yet
	Counter int
	// ID of the ACL timeout signal queued for this check
	TimeoutID uuid.UUID
	// Permission of the held signal
	Action Tree
	// ID derived from the key that signed the held signal
	Principal NodeID
	// Messages sent out on behalf of the pending policies
	Messages Messages
	// Signals received so far in response to Messages
	Responses []Signal
	// The held signal, processed if the check ends in Allow
	Signal Signal
	// Source of the held signal, also the destination of a denial
	Source NodeID
}
// PendingSignal links a message sent during a pending ACL check back to that
// check; nodeLoop stores it in Node.PendingSignals under the sent signal's ID.
// Field order matters: the struct is also built with positional literals.
type PendingSignal struct {
	// Type of the policy that produced the sent message
	Policy PolicyType
	// Set once a response to the sent message has been handled
	Found bool
	// ID of the held signal, i.e. the key into Node.PendingACLs
	ID uuid.UUID
}
// Default message channel size for nodes
// Nodes represent a group of extensions that can be collectively addressed
2023-07-25 21:43:15 -06:00
type Node struct {
2023-08-06 12:47:47 -06:00
Key ed25519.PrivateKey
2023-07-25 21:43:15 -06:00
ID NodeID
2023-07-26 00:18:11 -06:00
Type NodeType
Extensions map[ExtType]Extension
Policies map[PolicyType]Policy
2023-08-10 23:43:10 -06:00
PendingACLs map[uuid.UUID]PendingACL
PendingSignals map[uuid.UUID]PendingSignal
// Channel for this node to receive messages from the Context
MsgChan chan *Message
// Size of MsgChan
BufferSize uint32
// Channel for this node to process delayed signals
TimeoutChan <-chan time.Time
Active atomic.Bool
SignalQueue []QueuedSignal
NextSignal *QueuedSignal
}
2023-08-10 23:43:10 -06:00
// RuleResult is the outcome of checking an action against the policies of a node.
type RuleResult int

const (
	// Allow: the action is permitted
	Allow RuleResult = iota
	// Deny: no policy permitted the action
	Deny
	// Pending: at least one policy needs responses before it can decide
	Pending
)
func (node *Node) Allows(principal_id NodeID, action Tree)(map[PolicyType]Messages, RuleResult) {
pends := map[PolicyType]Messages{}
for policy_type, policy := range(node.Policies) {
msgs, resp := policy.Allows(principal_id, action, node)
if resp == Allow {
return nil, Allow
} else if resp == Pending {
pends[policy_type] = msgs
}
}
2023-08-10 23:43:10 -06:00
if len(pends) != 0 {
return pends, Pending
}
return nil, Deny
}
// QueueSignal appends signal to the node's SignalQueue to trigger at the given
// time, then recomputes NextSignal and TimeoutChan with SoonestSignal.
func (node *Node) QueueSignal(time time.Time, signal Signal) {
	queued := QueuedSignal{Signal: signal, Time: time}
	node.SignalQueue = append(node.SignalQueue, queued)
	node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue)
}
2023-08-10 23:43:10 -06:00
func (node *Node) DequeueSignal(id uuid.UUID) error {
idx := -1
for i, q := range(node.SignalQueue) {
if q.Signal.ID() == id {
idx = i
break
}
}
if idx == -1 {
return fmt.Errorf("%s is not in SignalQueue", id)
}
node.SignalQueue[idx] = node.SignalQueue[len(node.SignalQueue)-1]
node.SignalQueue = node.SignalQueue[:len(node.SignalQueue)-1]
node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue)
2023-08-10 23:43:10 -06:00
return nil
}
// ClearSignalQueue drops every queued signal and clears NextSignal and TimeoutChan.
func (node *Node) ClearSignalQueue() {
	node.NextSignal, node.TimeoutChan = nil, nil
	node.SignalQueue = []QueuedSignal{}
}
func SoonestSignal(signals []QueuedSignal) (*QueuedSignal, <-chan time.Time) {
var soonest_signal *QueuedSignal
var soonest_time time.Time
2023-07-30 11:29:58 -06:00
for i, signal := range(signals) {
2023-07-30 11:25:03 -06:00
if signal.Time.Compare(soonest_time) == -1 || soonest_signal == nil {
2023-07-30 11:29:58 -06:00
soonest_signal = &signals[i]
soonest_time = signal.Time
}
}
if soonest_signal != nil {
2023-07-30 11:12:47 -06:00
return soonest_signal, time.After(time.Until(soonest_signal.Time))
} else {
return nil, nil
}
}
// runNode logs the start of the node, runs nodeLoop until it returns, and logs
// the stop. An error returned by nodeLoop causes a panic.
func runNode(ctx *Context, node *Node) {
	ctx.Log.Logf("node", "RUN_START: %s", node.ID)
	if err := nodeLoop(ctx, node); err != nil {
		panic(err)
	}
	ctx.Log.Logf("node", "RUN_STOP: %s", node.ID)
}
// ReadFields collects the requested fields from the node's extensions.
// reqs maps an extension type to the field names wanted from it; the result has
// the same shape, with each name mapped to the value returned by the
// extension's Field method. When the node lacks a requested extension, every
// field requested from it is mapped to an error value instead.
func (node *Node) ReadFields(reqs map[ExtType][]string) map[ExtType]map[string]interface{} {
	result := make(map[ExtType]map[string]interface{}, len(reqs))
	for ext_type, names := range reqs {
		// The lookup does not depend on the field name, so do it once per extension
		ext, present := node.Extensions[ext_type]

		values := make(map[string]interface{}, len(names))
		for _, name := range names {
			if present {
				values[name] = ext.Field(name)
			} else {
				values[name] = fmt.Errorf("%s does not have %s extension", node.ID, ext_type)
			}
		}
		result[ext_type] = values
	}
	return result
}
// Main Loop for Threads, starts a write context, so cannot be called from a write or read context
func nodeLoop(ctx *Context, node *Node) error {
started := node.Active.CompareAndSwap(false, true)
if started == false {
return fmt.Errorf("%s is already started, will not start again", node.ID)
}
2023-08-06 12:47:47 -06:00
// Perform startup actions
node.Process(ctx, ZeroID, &StartSignal)
for true {
var signal Signal
var source NodeID
select {
case msg := <- node.MsgChan:
ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal)
ser, err := msg.Signal.Serialize()
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", node.ID, msg.Signal)
continue
}
sig_data := append(msg.Dest.Serialize(), msg.Source.Serialize()...)
sig_data = append(sig_data, ser...)
validated := ed25519.Verify(msg.Principal, sig_data, msg.Signature)
if validated == false {
ctx.Log.Logf("signal", "SIGNAL_VERIFY_ERR: %s - %+v", node.ID, msg)
continue
}
2023-08-10 23:43:10 -06:00
princ_id := KeyID(msg.Principal)
if princ_id != node.ID {
pends, resp := node.Allows(princ_id, msg.Signal.Permission())
if resp == Deny {
ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %s", princ_id, node.ID, msg.Signal.Permission())
msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(msg.Signal.ID(), "acl denied"), source)
ctx.Send(msgs)
continue
} else if resp == Pending {
ctx.Log.Logf("policy", "SIGNAL_POLICY_PENDING: %s->%s - %s - %+v", princ_id, node.ID, msg.Signal.Permission(), pends)
timeout_signal := NewACLTimeoutSignal(msg.Signal.ID())
node.QueueSignal(time.Now().Add(100*time.Millisecond), timeout_signal)
msgs := Messages{}
for policy_type, sigs := range(pends) {
for _, m := range(sigs) {
msgs = append(msgs, m)
node.PendingSignals[m.Signal.ID()] = PendingSignal{policy_type, false, msg.Signal.ID()}
}
}
node.PendingACLs[msg.Signal.ID()] = PendingACL{len(msgs), timeout_signal.ID(), msg.Signal.Permission(), princ_id, msgs, []Signal{}, msg.Signal, msg.Source}
ctx.Send(msgs)
continue
} else if resp == Allow {
ctx.Log.Logf("policy", "SIGNAL_POLICY_ALLOW: %s->%s - %s", princ_id, node.ID, msg.Signal.Permission())
}
} else {
ctx.Log.Logf("policy", "SIGNAL_POLICY_SELF: %s - %s", node.ID, msg.Signal.Permission())
}
signal = msg.Signal
source = msg.Source
case <-node.TimeoutChan:
signal = node.NextSignal.Signal
source = node.ID
2023-07-30 11:07:41 -06:00
t := node.NextSignal.Time
2023-07-30 01:29:15 -06:00
i := -1
for j, queued := range(node.SignalQueue) {
2023-08-10 23:43:10 -06:00
if queued.Signal.ID() == node.NextSignal.Signal.ID() {
2023-07-30 01:29:15 -06:00
i = j
break
}
}
if i == -1 {
panic("node.NextSignal not in node.SignalQueue")
}
l := len(node.SignalQueue)
node.SignalQueue[i] = node.SignalQueue[l-1]
node.SignalQueue = node.SignalQueue[:(l-1)]
node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue)
2023-07-30 11:02:22 -06:00
if node.NextSignal == nil {
2023-08-10 23:43:10 -06:00
ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL nil@%+v", node.ID, signal.Type(), t, node.TimeoutChan)
2023-07-30 11:02:22 -06:00
} else {
2023-08-10 23:43:10 -06:00
ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL: %s@%s", node.ID, signal.Type(), t, node.NextSignal, node.NextSignal.Time)
2023-07-30 11:02:22 -06:00
}
}
2023-08-10 23:43:10 -06:00
ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue)
info, waiting := node.PendingSignals[signal.ReqID()]
if waiting == true {
if info.Found == false {
info.Found = true
node.PendingSignals[signal.ReqID()] = info
ctx.Log.Logf("pending", "FOUND_PENDING_SIGNAL: %s - %s", node.ID, signal)
req_info, exists := node.PendingACLs[info.ID]
if exists == true {
req_info.Counter -= 1
req_info.Responses = append(req_info.Responses, signal)
// TODO: call the right policy ParseResponse to check if the updated state passes the ACL check
allowed := node.Policies[info.Policy].ContinueAllows(req_info, signal)
if allowed == Allow {
ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal)
signal = req_info.Signal
source = req_info.Source
err := node.DequeueSignal(req_info.TimeoutID)
if err != nil {
panic("dequeued a passed signal")
}
delete(node.PendingACLs, info.ID)
} else if req_info.Counter == 0 {
ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal)
// Send the denied response
msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(req_info.Signal.ID(), "ACL_DENIED"), req_info.Source)
err := ctx.Send(msgs)
if err != nil {
ctx.Log.Logf("signal", "SEND_ERR: %s", err)
}
err = node.DequeueSignal(req_info.TimeoutID)
if err != nil {
panic("dequeued a passed signal")
}
delete(node.PendingACLs, info.ID)
} else {
node.PendingACLs[info.ID] = req_info
continue
}
}
}
}
2023-08-06 12:47:47 -06:00
// Handle special signal types
if signal.Type() == StopSignalType {
msgs := Messages{}
2023-08-10 23:43:10 -06:00
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "stopped"), source)
ctx.Send(msgs)
node.Process(ctx, node.ID, NewStatusSignal("stopped", node.ID))
break
} else if signal.Type() == ReadSignalType {
read_signal, ok := signal.(*ReadSignal)
if ok == false {
2023-08-10 23:43:10 -06:00
ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal)
} else {
result := node.ReadFields(read_signal.Extensions)
msgs := Messages{}
2023-08-10 23:43:10 -06:00
msgs = msgs.Add(node.ID, node.Key, NewReadResultSignal(read_signal.ID(), node.ID, node.Type, result), source)
ctx.Send(msgs)
}
}
node.Process(ctx, source, signal)
2023-07-31 16:37:32 -06:00
// assume that processing a signal means that this nodes state changed
// TODO: remove a lot of database writes by only writing when things change,
// so need to have Process return whether or not state changed
err := WriteNode(ctx, node)
if err != nil {
panic(err)
}
}
stopped := node.Active.CompareAndSwap(true, false)
if stopped == false {
panic("BAD_STATE: stopping already stopped node")
}
return nil
}
// Message is a signed Signal addressed from one node to another.
type Message struct {
	// ID of the node the message is sent from
	Source NodeID
	// ID of the node the message is sent to
	Dest NodeID
	// Public key belonging to the private key that produced Signature
	Principal ed25519.PublicKey
	// The signal being delivered
	Signal Signal
	// Signature over the serialized Dest, Source and Signal, in that order
	Signature []byte
}

// Messages is a list of messages, built up with Add.
type Messages []*Message
2023-08-10 23:43:10 -06:00
func (msgs Messages) Add(source NodeID, principal ed25519.PrivateKey, signal Signal, dest NodeID) Messages {
msg, err := NewMessage(dest, source, principal, signal)
if err != nil {
2023-08-10 23:43:10 -06:00
panic(err)
} else {
msgs = append(msgs, msg)
}
return msgs
}
// NewMessage builds a Message carrying signal from source to dest. The
// signature is made with principal over the serialized dest ID, source ID and
// signal, concatenated in that order, and the public half of principal is
// stored as the message's Principal.
// An error from serializing the signal or from signing is returned as is.
func NewMessage(dest NodeID, source NodeID, principal ed25519.PrivateKey, signal Signal) (*Message, error) {
	signal_bytes, err := signal.Serialize()
	if err != nil {
		return nil, err
	}

	payload := dest.Serialize()
	payload = append(payload, source.Serialize()...)
	payload = append(payload, signal_bytes...)

	signature, err := principal.Sign(rand.Reader, payload, crypto.Hash(0))
	if err != nil {
		return nil, err
	}

	msg := &Message{
		Dest:      dest,
		Source:    source,
		Principal: principal.Public().(ed25519.PublicKey),
		Signal:    signal,
		Signature: signature,
	}
	return msg, nil
}
// Process hands signal to every extension of the node, gathers the messages
// they return, and sends them through the Context in a single call.
// The result of that Send is returned.
func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error {
	ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, signal.Type())

	outgoing := Messages{}
	for ext_type, ext := range node.Extensions {
		ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type)
		//TODO: add extension and node info to log
		// Appending a nil result adds nothing, so it needs no special case
		outgoing = append(outgoing, ext.Process(ctx, node, source, signal)...)
	}

	return ctx.Send(outgoing)
}
2023-07-26 11:56:10 -06:00
func GetCtx[T Extension, C any](ctx *Context) (C, error) {
var zero T
var zero_ctx C
ext_type := zero.Type()
type_hash := Hash(ext_type)
ext_info, ok := ctx.Extensions[type_hash]
if ok == false {
2023-07-26 11:56:10 -06:00
return zero_ctx, fmt.Errorf("%s is not an extension in ctx", ext_type)
}
ext_ctx, ok := ext_info.Data.(C)
if ok == false {
return zero_ctx, fmt.Errorf("context for %s is %+v, not %+v", ext_type, reflect.TypeOf(ext_info.Data), reflect.TypeOf(zero))
}
return ext_ctx, nil
}
2023-07-25 21:43:15 -06:00
func GetExt[T Extension](node *Node) (T, error) {
var zero T
ext_type := zero.Type()
ext, exists := node.Extensions[ext_type]
2023-07-25 21:43:15 -06:00
if exists == false {
2023-07-26 11:56:10 -06:00
return zero, fmt.Errorf("%s does not have %s extension - %+v", node.ID, ext_type, node.Extensions)
2023-07-24 16:04:56 -06:00
}
2023-07-24 01:41:47 -06:00
2023-07-25 21:43:15 -06:00
ret, ok := ext.(T)
if ok == false {
return zero, fmt.Errorf("%s in %s is wrong type(%+v), expecting %+v", ext_type, node.ID, reflect.TypeOf(ext), reflect.TypeOf(zero))
}
2023-07-20 23:19:10 -06:00
2023-07-25 21:43:15 -06:00
return ret, nil
2023-07-20 23:19:10 -06:00
}
2023-07-25 21:43:15 -06:00
func (node *Node) Serialize() ([]byte, error) {
extensions := make([]ExtensionDB, len(node.Extensions))
2023-08-06 12:47:47 -06:00
qsignals := make([]QSignalDB, len(node.SignalQueue))
policies := make([]PolicyDB, len(node.Policies))
2023-08-06 12:47:47 -06:00
key_bytes, err := x509.MarshalPKCS8PrivateKey(node.Key)
if err != nil {
return nil, err
}
2023-07-25 21:43:15 -06:00
node_db := NodeDB{
Header: NodeDBHeader{
Magic: NODE_DB_MAGIC,
TypeHash: Hash(node.Type),
KeyLength: uint32(len(key_bytes)),
BufferSize: node.BufferSize,
2023-07-25 21:43:15 -06:00
NumExtensions: uint32(len(extensions)),
NumPolicies: uint32(len(policies)),
NumQueuedSignals: uint32(len(node.SignalQueue)),
2023-07-25 21:43:15 -06:00
},
Extensions: extensions,
Policies: policies,
2023-08-06 12:47:47 -06:00
QueuedSignals: qsignals,
KeyBytes: key_bytes,
2023-07-25 21:43:15 -06:00
}
2023-07-09 14:30:30 -06:00
2023-07-24 16:04:56 -06:00
i := 0
for ext_type, info := range(node.Extensions) {
2023-07-25 21:43:15 -06:00
ser, err := info.Serialize()
if err != nil {
return nil, err
}
node_db.Extensions[i] = ExtensionDB{
Header: ExtensionDBHeader{
TypeHash: Hash(ext_type),
2023-07-25 21:43:15 -06:00
Length: uint64(len(ser)),
},
Data: ser,
}
2023-07-24 16:04:56 -06:00
i += 1
}
2023-08-06 12:47:47 -06:00
for i, qsignal := range(node.SignalQueue) {
ser, err := qsignal.Signal.Serialize()
if err != nil {
return nil, err
}
node_db.QueuedSignals[i] = QSignalDB{
QSignalDBHeader{
qsignal.Time,
Hash(qsignal.Signal.Type()),
uint64(len(ser)),
},
ser,
}
}
i = 0
for _, policy := range(node.Policies) {
ser, err := policy.Serialize()
if err != nil {
return nil, err
}
node_db.Policies[i] = PolicyDB{
PolicyDBHeader{
Hash(policy.Type()),
uint64(len(ser)),
},
ser,
}
}
2023-07-25 21:43:15 -06:00
return node_db.Serialize(), nil
}
2023-08-06 12:47:47 -06:00
// KeyID derives the NodeID that belongs to an ed25519 public key by passing
// the key to uuid.NewHash with a SHA-512 hash, ZeroUUID as the namespace and 3
// as the version.
func KeyID(pub ed25519.PublicKey) NodeID {
	return NodeID(uuid.NewHash(sha512.New(), ZeroUUID, pub, 3))
}
// Create a new node in memory and start it's event loop
// TODO: Change panics to errors
func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_size uint32, policies map[PolicyType]Policy, extensions ...Extension) *Node {
var err error
2023-08-06 12:47:47 -06:00
var public ed25519.PublicKey
if key == nil {
2023-08-06 12:47:47 -06:00
public, key, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
2023-08-06 12:47:47 -06:00
} else {
public = key.Public().(ed25519.PublicKey)
}
2023-08-06 12:47:47 -06:00
id := KeyID(public)
_, exists := ctx.Node(id)
if exists == true {
panic("Attempted to create an existing node")
}
def, exists := ctx.Types[Hash(node_type)]
2023-07-27 11:33:11 -06:00
if exists == false {
panic("Node type %s not registered in Context")
}
ext_map := map[ExtType]Extension{}
for _, ext := range(extensions) {
_, exists := ext_map[ext.Type()]
if exists == true {
panic("Cannot add the same extension to a node twice")
}
ext_map[ext.Type()] = ext
}
for _, required_ext := range(def.Extensions) {
_, exists := ext_map[required_ext]
if exists == false {
panic(fmt.Sprintf("%s requires %s", node_type, required_ext))
}
}
if policies == nil {
policies = map[PolicyType]Policy{}
}
node := &Node{
Key: key,
2023-07-25 21:43:15 -06:00
ID: id,
2023-07-26 00:18:11 -06:00
Type: node_type,
2023-07-27 11:33:11 -06:00
Extensions: ext_map,
Policies: policies,
2023-08-10 23:43:10 -06:00
//TODO serialize/deserialize these
PendingACLs: map[uuid.UUID]PendingACL{},
PendingSignals: map[uuid.UUID]PendingSignal{},
MsgChan: make(chan *Message, buffer_size),
BufferSize: buffer_size,
SignalQueue: []QueuedSignal{},
2023-07-25 21:43:15 -06:00
}
ctx.AddNode(id, node)
err = WriteNode(ctx, node)
if err != nil {
panic(err)
}
node.Process(ctx, ZeroID, &NewSignal)
2023-08-06 12:47:47 -06:00
go runNode(ctx, node)
2023-07-27 11:33:11 -06:00
return node
2023-07-09 14:30:30 -06:00
}
// PolicyDBHeader precedes the data of each serialized policy.
type PolicyDBHeader struct {
	// Hash of the policy type
	TypeHash uint64
	// Length in bytes of the policy data that follows
	Length uint64
}
// PolicyDB is a serialized policy together with its header.
type PolicyDB struct {
	Header PolicyDBHeader
	// Serialized policy, Header.Length bytes long
	Data []byte
}
2023-07-25 21:43:15 -06:00
// QSignalDBHeader precedes the data of each serialized queued signal.
type QSignalDBHeader struct {
	// Time the signal is queued for; stored with millisecond resolution
	Time time.Time
	// Hash of the signal type
	TypeHash uint64
	// Length in bytes of the signal data that follows
	Length uint64
}
// QSignalDB is a serialized queued signal together with its header.
type QSignalDB struct {
	Header QSignalDBHeader
	// Serialized signal, Header.Length bytes long
	Data []byte
}
// ExtensionDBHeader precedes the data of each serialized extension.
type ExtensionDBHeader struct {
	// Hash of the extension type
	TypeHash uint64
	// Length in bytes of the extension data that follows
	Length uint64
}
// ExtensionDB is a serialized extension together with its header.
type ExtensionDB struct {
	Header ExtensionDBHeader
	// Serialized extension, Header.Length bytes long
	Data []byte
}
2023-07-10 22:31:43 -06:00
// A NodeDBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node.
// The fields are listed in the order they are written: six big endian uint32
// values followed by one big endian uint64.
type NodeDBHeader struct {
	// Must equal NODE_DB_MAGIC
	Magic uint32
	// Number of serialized extensions
	NumExtensions uint32
	// Number of serialized policies
	NumPolicies uint32
	// Number of serialized queued signals
	NumQueuedSignals uint32
	// Size of the node's message channel
	BufferSize uint32
	// Length in bytes of the encoded private key that follows the header
	KeyLength uint32
	// Hash of the node type
	TypeHash uint64
}
type NodeDB struct {
Header NodeDBHeader
Extensions []ExtensionDB
Policies []PolicyDB
QueuedSignals []QSignalDB
KeyBytes []byte
2023-07-25 21:43:15 -06:00
}
//TODO: add size safety checks
func NewNodeDB(data []byte) (NodeDB, error) {
var zero NodeDB
ptr := 0
magic := binary.BigEndian.Uint32(data[0:4])
num_extensions := binary.BigEndian.Uint32(data[4:8])
num_policies := binary.BigEndian.Uint32(data[8:12])
num_queued_signals := binary.BigEndian.Uint32(data[12:16])
buffer_size := binary.BigEndian.Uint32(data[16:20])
key_length := binary.BigEndian.Uint32(data[20:24])
node_type_hash := binary.BigEndian.Uint64(data[24:32])
2023-07-25 21:43:15 -06:00
ptr += NODE_DB_HEADER_LEN
if magic != NODE_DB_MAGIC {
return zero, fmt.Errorf("header has incorrect magic 0x%x", magic)
}
key_bytes := make([]byte, key_length)
n := copy(key_bytes, data[ptr:(ptr+int(key_length))])
if n != int(key_length) {
return zero, fmt.Errorf("not enough key bytes: %d", n)
}
ptr += int(key_length)
2023-07-25 21:43:15 -06:00
extensions := make([]ExtensionDB, num_extensions)
for i, _ := range(extensions) {
cur := data[ptr:]
type_hash := binary.BigEndian.Uint64(cur[0:8])
length := binary.BigEndian.Uint64(cur[8:16])
data_start := uint64(EXTENSION_DB_HEADER_LEN)
data_end := data_start + length
ext_data := cur[data_start:data_end]
extensions[i] = ExtensionDB{
Header: ExtensionDBHeader{
TypeHash: type_hash,
Length: length,
},
Data: ext_data,
}
ptr += int(EXTENSION_DB_HEADER_LEN + length)
}
policies := make([]PolicyDB, num_policies)
for i, _ := range(policies) {
cur := data[ptr:]
type_hash := binary.BigEndian.Uint64(cur[0:8])
length := binary.BigEndian.Uint64(cur[8:16])
data_start := uint64(POLICY_DB_HEADER_LEN)
data_end := data_start + length
policy_data := cur[data_start:data_end]
policies[i] = PolicyDB{
PolicyDBHeader{
type_hash,
length,
},
policy_data,
}
ptr += int(POLICY_DB_HEADER_LEN + length)
}
2023-08-06 12:47:47 -06:00
queued_signals := make([]QSignalDB, num_queued_signals)
for i, _ := range(queued_signals) {
2023-08-06 12:47:47 -06:00
cur := data[ptr:]
// TODO: load a header for each with the signal type and the signal length, so that it can be deserialized and incremented
// Right now causes segfault because any saved signal is loaded as nil
2023-08-10 23:43:10 -06:00
unix_milli := binary.BigEndian.Uint64(cur[0:8])
type_hash := binary.BigEndian.Uint64(cur[8:16])
signal_size := binary.BigEndian.Uint64(cur[16:24])
2023-08-06 12:47:47 -06:00
signal_data := cur[QSIGNAL_DB_HEADER_LEN:(QSIGNAL_DB_HEADER_LEN+signal_size)]
queued_signals[i] = QSignalDB{
QSignalDBHeader{
time.UnixMilli(int64(unix_milli)),
type_hash,
signal_size,
},
signal_data,
}
ptr += QSIGNAL_DB_HEADER_LEN + int(signal_size)
}
2023-07-25 21:43:15 -06:00
return NodeDB{
Header: NodeDBHeader{
Magic: magic,
2023-07-26 00:18:11 -06:00
TypeHash: node_type_hash,
BufferSize: buffer_size,
KeyLength: key_length,
2023-07-25 21:43:15 -06:00
NumExtensions: num_extensions,
NumQueuedSignals: num_queued_signals,
2023-07-25 21:43:15 -06:00
},
KeyBytes: key_bytes,
2023-07-25 21:43:15 -06:00
Extensions: extensions,
QueuedSignals: queued_signals,
2023-07-25 21:43:15 -06:00
}, nil
2023-07-09 14:30:30 -06:00
}
2023-07-25 21:43:15 -06:00
func (header NodeDBHeader) Serialize() []byte {
2023-07-09 14:30:30 -06:00
if header.Magic != NODE_DB_MAGIC {
panic(fmt.Sprintf("Serializing header with invalid magic %0x", header.Magic))
}
ret := make([]byte, NODE_DB_HEADER_LEN)
binary.BigEndian.PutUint32(ret[0:4], header.Magic)
2023-07-25 21:43:15 -06:00
binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions)
binary.BigEndian.PutUint32(ret[8:12], header.NumPolicies)
binary.BigEndian.PutUint32(ret[12:16], header.NumQueuedSignals)
binary.BigEndian.PutUint32(ret[16:20], header.BufferSize)
binary.BigEndian.PutUint32(ret[20:24], header.KeyLength)
binary.BigEndian.PutUint64(ret[24:32], header.TypeHash)
2023-07-09 14:30:30 -06:00
return ret
}
2023-07-25 21:43:15 -06:00
func (node NodeDB) Serialize() []byte {
ser := node.Header.Serialize()
ser = append(ser, node.KeyBytes...)
2023-07-25 21:43:15 -06:00
for _, extension := range(node.Extensions) {
ser = append(ser, extension.Serialize()...)
2023-07-09 14:30:30 -06:00
}
2023-08-06 12:47:47 -06:00
for _, qsignal := range(node.QueuedSignals) {
ser = append(ser, qsignal.Serialize()...)
}
2023-07-25 21:43:15 -06:00
return ser
}
2023-08-06 12:47:47 -06:00
func (header QSignalDBHeader) Serialize() []byte {
ret := make([]byte, QSIGNAL_DB_HEADER_LEN)
2023-08-10 23:43:10 -06:00
binary.BigEndian.PutUint64(ret[0:8], uint64(header.Time.UnixMilli()))
binary.BigEndian.PutUint64(ret[8:16], header.TypeHash)
binary.BigEndian.PutUint64(ret[16:24], header.Length)
2023-08-06 12:47:47 -06:00
return ret
}
// Serialize returns the serialized header followed by the signal data.
func (qsignal QSignalDB) Serialize() []byte {
	out := make([]byte, 0, QSIGNAL_DB_HEADER_LEN+len(qsignal.Data))
	out = append(out, qsignal.Header.Serialize()...)
	return append(out, qsignal.Data...)
}
2023-07-25 21:43:15 -06:00
// Serialize writes the header into EXTENSION_DB_HEADER_LEN bytes: the type
// hash followed by the data length, both as big endian uint64 values.
func (header ExtensionDBHeader) Serialize() []byte {
	var buf [EXTENSION_DB_HEADER_LEN]byte
	binary.BigEndian.PutUint64(buf[:8], header.TypeHash)
	binary.BigEndian.PutUint64(buf[8:16], header.Length)
	return buf[:]
}
// Serialize returns the serialized header followed by the extension data.
func (extension ExtensionDB) Serialize() []byte {
	out := make([]byte, 0, EXTENSION_DB_HEADER_LEN+len(extension.Data))
	out = append(out, extension.Header.Serialize()...)
	return append(out, extension.Data...)
}
2023-07-27 15:49:21 -06:00
// WriteNode serializes node and stores the result in the database of ctx under
// the node's serialized ID, inside a single update transaction.
// Errors from serialization and from the database are returned as is.
func WriteNode(ctx *Context, node *Node) error {
	ctx.Log.Logf("db", "DB_WRITE: %s", node.ID)

	value, err := node.Serialize()
	if err != nil {
		return err
	}

	key := node.ID.Serialize()
	ctx.Log.Logf("db", "DB_WRITE_ID: %+v", key)

	store := func(txn *badger.Txn) error {
		return txn.Set(key, value)
	}
	return ctx.DB.Update(store)
}
2023-07-09 14:30:30 -06:00
2023-07-25 21:43:15 -06:00
func LoadNode(ctx * Context, id NodeID) (*Node, error) {
ctx.Log.Logf("db", "LOADING_NODE: %s", id)
2023-07-09 14:30:30 -06:00
var bytes []byte
err := ctx.DB.View(func(txn *badger.Txn) error {
id_bytes := id.Serialize()
ctx.Log.Logf("db", "DB_READ_ID: %+v", id_bytes)
item, err := txn.Get(id_bytes)
2023-07-09 14:30:30 -06:00
if err != nil {
return err
}
return item.Value(func(val []byte) error {
bytes = append([]byte{}, val...)
return nil
})
})
if errors.Is(err, badger.ErrKeyNotFound) {
return nil, NodeNotFoundError
}else if err != nil {
2023-07-25 21:43:15 -06:00
return nil, err
2023-07-09 14:30:30 -06:00
}
2023-07-25 21:43:15 -06:00
// Parse the bytes from the DB
node_db, err := NewNodeDB(bytes)
if err != nil {
return nil, err
2023-07-09 14:30:30 -06:00
}
policies := make(map[PolicyType]Policy, node_db.Header.NumPolicies)
for _, policy_db := range(node_db.Policies) {
policy_info, exists := ctx.Policies[policy_db.Header.TypeHash]
if exists == false {
return nil, fmt.Errorf("0x%x is not a known policy type", policy_db.Header.TypeHash)
}
policy, err := policy_info.Load(ctx, policy_db.Data)
if err != nil {
return nil, err
}
policies[policy_info.Type] = policy
}
2023-08-06 12:47:47 -06:00
key_raw, err := x509.ParsePKCS8PrivateKey(node_db.KeyBytes)
if err != nil {
return nil, err
}
2023-08-06 12:47:47 -06:00
var key ed25519.PrivateKey
switch k := key_raw.(type) {
case ed25519.PrivateKey:
key = k
default:
return nil, fmt.Errorf("Wrong type for private key loaded: %s - %s", id, reflect.TypeOf(k))
}
2023-08-06 12:47:47 -06:00
key_id := KeyID(key.Public().(ed25519.PublicKey))
if key_id != id {
return nil, fmt.Errorf("KeyID(%s) != %s", key_id, id)
}
2023-07-26 00:18:11 -06:00
node_type, known := ctx.Types[node_db.Header.TypeHash]
if known == false {
return nil, fmt.Errorf("Tried to load node %s of type 0x%x, which is not a known node type", id, node_db.Header.TypeHash)
}
2023-08-06 12:47:47 -06:00
signal_queue := make([]QueuedSignal, node_db.Header.NumQueuedSignals)
for i, qsignal := range(node_db.QueuedSignals) {
sig_info, exists := ctx.Signals[qsignal.Header.TypeHash]
if exists == false {
return nil, fmt.Errorf("0x%x is not a known signal type", qsignal.Header.TypeHash)
}
signal, err := sig_info.Load(ctx, qsignal.Data)
if err != nil {
return nil, err
}
2023-08-10 23:43:10 -06:00
signal_queue[i] = QueuedSignal{signal, qsignal.Header.Time}
2023-08-06 12:47:47 -06:00
}
next_signal, timeout_chan := SoonestSignal(signal_queue)
node := &Node{
Key: key,
ID: key_id,
2023-07-27 11:33:11 -06:00
Type: node_type.Type,
Extensions: map[ExtType]Extension{},
Policies: policies,
MsgChan: make(chan *Message, node_db.Header.BufferSize),
BufferSize: node_db.Header.BufferSize,
TimeoutChan: timeout_chan,
2023-08-06 12:47:47 -06:00
SignalQueue: signal_queue,
NextSignal: next_signal,
2023-07-27 11:33:11 -06:00
}
ctx.AddNode(id, node)
2023-07-09 14:30:30 -06:00
2023-07-26 11:56:10 -06:00
found_extensions := []ExtType{}
2023-07-25 21:43:15 -06:00
// Parse each of the extensions from the db
for _, ext_db := range(node_db.Extensions) {
type_hash := ext_db.Header.TypeHash
def, known := ctx.Extensions[type_hash]
if known == false {
return nil, fmt.Errorf("%s tried to load extension 0x%x, which is not a known extension type", id, type_hash)
2023-07-09 14:30:30 -06:00
}
ctx.Log.Logf("db", "DB_EXTENSION_LOADING: %s/%s", id, def.Type)
2023-07-25 21:43:15 -06:00
extension, err := def.Load(ctx, ext_db.Data)
2023-07-09 14:30:30 -06:00
if err != nil {
return nil, err
}
node.Extensions[def.Type] = extension
2023-07-26 11:56:10 -06:00
found_extensions = append(found_extensions, def.Type)
ctx.Log.Logf("db", "DB_EXTENSION_LOADED: %s/%s - %+v", id, def.Type, extension)
2023-07-26 11:56:10 -06:00
}
missing_extensions := []ExtType{}
for _, ext := range(node_type.Extensions) {
found := false
for _, found_ext := range(found_extensions) {
if found_ext == ext {
found = true
break
}
}
if found == false {
missing_extensions = append(missing_extensions, ext)
}
}
if len(missing_extensions) > 0 {
return nil, fmt.Errorf("DB_LOAD_MISSING_EXTENSIONS: %s - %+v - %+v", id, node_type, missing_extensions)
}
extra_extensions := []ExtType{}
for _, found_ext := range(found_extensions) {
found := false
for _, ext := range(node_type.Extensions) {
if ext == found_ext {
found = true
break
}
}
if found == false {
extra_extensions = append(extra_extensions, found_ext)
}
}
if len(extra_extensions) > 0 {
ctx.Log.Logf("db", "DB_LOAD_EXTRA_EXTENSIONS: %s - %+v - %+v", id, node_type, extra_extensions)
2023-07-09 14:30:30 -06:00
}
2023-07-25 21:43:15 -06:00
ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id)
go runNode(ctx, node)
2023-07-09 14:30:30 -06:00
return node, nil
}
// IDMap returns a copy of m in which every NodeID key is replaced by its
// String() form. The result is never nil.
func IDMap[S any, T map[NodeID]S](m T) map[string]S {
	by_string := make(map[string]S, len(m))
	for id := range m {
		by_string[id.String()] = m[id]
	}
	return by_string
}
// LoadIDMap returns a copy of m in which every string key is replaced by the
// NodeID that ParseID returns for it. When any key fails to parse, the result
// is nil and the parse error is returned.
func LoadIDMap[S any, T map[string]S](m T) (map[NodeID]S, error) {
	by_id := make(map[NodeID]S, len(m))
	for str := range m {
		id, err := ParseID(str)
		if err != nil {
			return nil, err
		}
		by_id[id] = m[str]
	}
	return by_id, nil
}