Added comments and switched from a bool/mutex combo to prevent double-start to an atomic bool

gql_cataclysm
noah metz 2023-07-27 16:06:56 -06:00
parent 7965f8fbe6
commit 6b375245df
3 changed files with 40 additions and 40 deletions

@ -8,43 +8,35 @@ import (
//Function to load an extension from bytes //Function to load an extension from bytes
type ExtensionLoadFunc func(*Context, []byte) (Extension, error) type ExtensionLoadFunc func(*Context, []byte) (Extension, error)
// Information about a loaded extension
// Information about a registered extension
type ExtensionInfo struct { type ExtensionInfo struct {
Load ExtensionLoadFunc Load ExtensionLoadFunc
Type ExtType Type ExtType
Data interface{} Data interface{}
} }
// Information about a loaded node type // Information about a registered node type
type NodeInfo struct { type NodeInfo struct {
Type NodeType Type NodeType
Extensions []ExtType Extensions []ExtType
} }
// A Context is all the data needed to run a graphvent // A Context stores all the data to run a graphvent process
type Context struct { type Context struct {
// DB is the database connection used to load and write nodes // DB is the database connection used to load and write nodes
DB * badger.DB DB * badger.DB
// Log is an interface used to record events happening // Logging interface
Log Logger Log Logger
// A mapping between type hashes and their corresponding extension definitions // Map between database extension hashes and the registered info
Extensions map[uint64]ExtensionInfo Extensions map[uint64]ExtensionInfo
// A mapping between type hashes and their corresponding node definitions // Map between database type hashes and the registered info
Types map[uint64]NodeInfo Types map[uint64]NodeInfo
// All loaded Nodes // Routing map to all the nodes local to this context
Nodes map[NodeID]*Node Nodes map[NodeID]*Node
} }
func (ctx *Context) ExtByType(ext_type ExtType) *ExtensionInfo { // Register a NodeType to the context, with the list of extensions it requires
type_hash := ext_type.Hash()
ext, ok := ctx.Extensions[type_hash]
if ok == true {
return &ext
} else {
return nil
}
}
func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) error { func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) error {
type_hash := node_type.Hash() type_hash := node_type.Hash()
_, exists := ctx.Types[type_hash] _, exists := ctx.Types[type_hash]
@ -94,6 +86,7 @@ func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFun
return nil return nil
} }
// Route a Signal to dest. Currently only local context routing is supported
func (ctx *Context) Send(source NodeID, dest NodeID, signal Signal) error { func (ctx *Context) Send(source NodeID, dest NodeID, signal Signal) error {
target, exists := ctx.Nodes[dest] target, exists := ctx.Nodes[dest]
if exists == false { if exists == false {
@ -110,7 +103,7 @@ func (ctx *Context) Send(source NodeID, dest NodeID, signal Signal) error {
return nil return nil
} }
// Create a new Context with all the library content added // Create a new Context with the base library content added
func NewContext(db * badger.DB, log Logger) (*Context, error) { func NewContext(db * badger.DB, log Logger) (*Context, error) {
ctx := &Context{ ctx := &Context{
DB: db, DB: db,

@ -1,7 +1,6 @@
package graphvent package graphvent
import ( import (
"sync"
"time" "time"
"reflect" "reflect"
"github.com/google/uuid" "github.com/google/uuid"
@ -10,17 +9,15 @@ import (
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"crypto/sha512" "crypto/sha512"
"crypto/ecdsa" "sync/atomic"
"crypto/elliptic"
) )
// IDs are how nodes are uniquely identified, and can be serialized for the database // A NodeID uniquely identifies a Node
type NodeID uuid.UUID type NodeID uuid.UUID
func (id NodeID) MarshalJSON() ([]byte, error) { func (id NodeID) MarshalJSON() ([]byte, error) {
str := id.String() str := id.String()
return json.Marshal(&str) return json.Marshal(&str)
} }
func (id *NodeID) UnmarshalJSON(bytes []byte) error { func (id *NodeID) UnmarshalJSON(bytes []byte) error {
var id_str string var id_str string
err := json.Unmarshal(bytes, &id_str) err := json.Unmarshal(bytes, &id_str)
@ -32,6 +29,7 @@ func (id *NodeID) UnmarshalJSON(bytes []byte) error {
return err return err
} }
// Base NodeID, used as a special value
var ZeroUUID = uuid.UUID{} var ZeroUUID = uuid.UUID{}
var ZeroID = NodeID(ZeroUUID) var ZeroID = NodeID(ZeroUUID)
@ -44,12 +42,14 @@ func (id NodeID) String() string {
return (uuid.UUID)(id).String() return (uuid.UUID)(id).String()
} }
// Create an ID from a fixed length byte array
// Ignore the error since we're enforcing 16 byte length at compile time // Ignore the error since we're enforcing 16 byte length at compile time
func IDFromBytes(bytes [16]byte) NodeID { func IDFromBytes(bytes [16]byte) NodeID {
id, _ := uuid.FromBytes(bytes[:]) id, _ := uuid.FromBytes(bytes[:])
return NodeID(id) return NodeID(id)
} }
// Parse an ID from a string
func ParseID(str string) (NodeID, error) { func ParseID(str string) (NodeID, error) {
id_uuid, err := uuid.Parse(str) id_uuid, err := uuid.Parse(str)
if err != nil { if err != nil {
@ -58,45 +58,43 @@ func ParseID(str string) (NodeID, error) {
return NodeID(id_uuid), nil return NodeID(id_uuid), nil
} }
func KeyID(pub *ecdsa.PublicKey) NodeID {
ser := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
str := uuid.NewHash(sha512.New(), ZeroUUID, ser, 3)
return NodeID(str)
}
// Generate a random NodeID // Generate a random NodeID
func RandID() NodeID { func RandID() NodeID {
return NodeID(uuid.New()) return NodeID(uuid.New())
} }
// A Serializable has a type that can be used to map to it, and a function to serialize the current state
type Serializable[I comparable] interface { type Serializable[I comparable] interface {
Type() I Type() I
Serialize() ([]byte, error) Serialize() ([]byte, error)
} }
// Extensions are data attached to nodes that process signals
type Extension interface { type Extension interface {
Serializable[ExtType] Serializable[ExtType]
Process(context *Context, source NodeID, node *Node, signal Signal) Process(context *Context, source NodeID, node *Node, signal Signal)
} }
// A QueuedSignal is a Signal that has been Queued to trigger at a set time
type QueuedSignal struct { type QueuedSignal struct {
Signal Signal Signal Signal
Time time.Time Time time.Time
} }
// Default message channel size for nodes
const NODE_MSG_CHAN_DEFAULT = 1024 const NODE_MSG_CHAN_DEFAULT = 1024
// Nodes represent an addressible group of extensions // Nodes represent a group of extensions that can be collectively addressed
type Node struct { type Node struct {
ID NodeID ID NodeID
Type NodeType Type NodeType
Lock sync.RWMutex
Extensions map[ExtType]Extension Extensions map[ExtType]Extension
// Channel for this node to receive messages from the Context
MsgChan chan Msg MsgChan chan Msg
// Channel for this node to process delayed signals
TimeoutChan <-chan time.Time TimeoutChan <-chan time.Time
LoopLock sync.Mutex Active atomic.Bool
Active bool
SignalQueue []QueuedSignal SignalQueue []QueuedSignal
NextSignal *QueuedSignal NextSignal *QueuedSignal
@ -146,10 +144,10 @@ type Msg struct {
// Main Loop for Threads, starts a write context, so cannot be called from a write or read context // Main Loop for Threads, starts a write context, so cannot be called from a write or read context
func NodeLoop(ctx *Context, node *Node) error { func NodeLoop(ctx *Context, node *Node) error {
node.LoopLock.Lock() started := node.Active.CompareAndSwap(false, true)
defer node.LoopLock.Unlock() if started == false {
return fmt.Errorf("%s is already started, will not start again", node.ID)
node.Active = true }
for true { for true {
var signal Signal var signal Signal
var source NodeID var source NodeID
@ -176,6 +174,11 @@ func NodeLoop(ctx *Context, node *Node) error {
} }
node.Process(ctx, source, signal) node.Process(ctx, source, signal)
} }
stopped := node.Active.CompareAndSwap(true, false)
if stopped == false {
panic("BAD_STATE: stopping already stopped node")
}
return nil return nil
} }
@ -194,8 +197,9 @@ func GetCtx[T Extension, C any](ctx *Context) (C, error) {
var zero T var zero T
var zero_ctx C var zero_ctx C
ext_type := zero.Type() ext_type := zero.Type()
ext_info := ctx.ExtByType(ext_type) type_hash := ext_type.Hash()
if ext_info == nil { ext_info, ok := ctx.Extensions[type_hash]
if ok == false {
return zero_ctx, fmt.Errorf("%s is not an extension in ctx", ext_type) return zero_ctx, fmt.Errorf("%s is not an extension in ctx", ext_type)
} }

@ -267,7 +267,10 @@ func LoadACLExt(ctx *Context, data []byte) (Extension, error) {
policies := make([]Policy, len(j.Policies)) policies := make([]Policy, len(j.Policies))
i := 0 i := 0
acl_ctx := ctx.ExtByType(ACLExtType).Data.(*ACLExtContext) acl_ctx, err := GetCtx[*ACLExt, *ACLExtContext](ctx)
if err != nil {
return nil, err
}
for name, ser := range(j.Policies) { for name, ser := range(j.Policies) {
policy_def, exists := acl_ctx.Types[PolicyType(name)] policy_def, exists := acl_ctx.Types[PolicyType(name)]
if exists == false { if exists == false {