diff --git a/context.go b/context.go index 7259265..ed9f712 100644 --- a/context.go +++ b/context.go @@ -8,43 +8,35 @@ import ( //Function to load an extension from bytes type ExtensionLoadFunc func(*Context, []byte) (Extension, error) -// Information about a loaded extension + +// Information about a registered extension type ExtensionInfo struct { Load ExtensionLoadFunc Type ExtType Data interface{} } -// Information about a loaded node type +// Information about a registered node type type NodeInfo struct { Type NodeType Extensions []ExtType } -// A Context is all the data needed to run a graphvent +// A Context stores all the data to run a graphvent process type Context struct { // DB is the database connection used to load and write nodes DB * badger.DB - // Log is an interface used to record events happening + // Logging interface Log Logger - // A mapping between type hashes and their corresponding extension definitions + // Map between database extension hashes and the registered info Extensions map[uint64]ExtensionInfo - // A mapping between type hashes and their corresponding node definitions + // Map between database type hashes and the registered info Types map[uint64]NodeInfo - // All loaded Nodes + // Routing map to all the nodes local to this context Nodes map[NodeID]*Node } -func (ctx *Context) ExtByType(ext_type ExtType) *ExtensionInfo { - type_hash := ext_type.Hash() - ext, ok := ctx.Extensions[type_hash] - if ok == true { - return &ext - } else { - return nil - } -} - +// Register a NodeType to the context, with the list of extensions it requires func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) error { type_hash := node_type.Hash() _, exists := ctx.Types[type_hash] @@ -94,6 +86,7 @@ func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFun return nil } +// Route a Signal to dest. Currently only local context routing is supported func (ctx *Context) Send(source NodeID, dest NodeID, signal Signal) error { target, exists := ctx.Nodes[dest] if exists == false { @@ -110,7 +103,7 @@ func (ctx *Context) Send(source NodeID, dest NodeID, signal Signal) error { return nil } -// Create a new Context with all the library content added +// Create a new Context with the base library content added func NewContext(db * badger.DB, log Logger) (*Context, error) { ctx := &Context{ DB: db, diff --git a/node.go b/node.go index c35e96a..0244499 100644 --- a/node.go +++ b/node.go @@ -1,7 +1,6 @@ package graphvent import ( - "sync" "time" "reflect" "github.com/google/uuid" @@ -10,17 +9,15 @@ import ( "encoding/binary" "encoding/json" "crypto/sha512" - "crypto/ecdsa" - "crypto/elliptic" + "sync/atomic" ) -// IDs are how nodes are uniquely identified, and can be serialized for the database +// A NodeID uniquely identifies a Node type NodeID uuid.UUID func (id NodeID) MarshalJSON() ([]byte, error) { str := id.String() return json.Marshal(&str) } - func (id *NodeID) UnmarshalJSON(bytes []byte) error { var id_str string err := json.Unmarshal(bytes, &id_str) @@ -32,6 +29,7 @@ func (id *NodeID) UnmarshalJSON(bytes []byte) error { return err } +// Base NodeID, used as a special value var ZeroUUID = uuid.UUID{} var ZeroID = NodeID(ZeroUUID) @@ -44,12 +42,14 @@ func (id NodeID) String() string { return (uuid.UUID)(id).String() } +// Create an ID from a fixed length byte array // Ignore the error since we're enforcing 16 byte length at compile time func IDFromBytes(bytes [16]byte) NodeID { id, _ := uuid.FromBytes(bytes[:]) return NodeID(id) } +// Parse an ID from a string func ParseID(str string) (NodeID, error) { id_uuid, err := uuid.Parse(str) if err != nil { @@ -58,45 +58,43 @@ func ParseID(str string) (NodeID, error) { return NodeID(id_uuid), nil } -func KeyID(pub *ecdsa.PublicKey) NodeID { - ser := elliptic.Marshal(pub.Curve, pub.X, pub.Y) - str := uuid.NewHash(sha512.New(), ZeroUUID, ser, 3) - return NodeID(str) -} - // Generate a random NodeID func RandID() NodeID { return NodeID(uuid.New()) } +// A Serializable has a type that can be used to map to it, and a function to serialize the current state type Serializable[I comparable] interface { Type() I Serialize() ([]byte, error) } +// Extensions are data attached to nodes that process signals type Extension interface { Serializable[ExtType] Process(context *Context, source NodeID, node *Node, signal Signal) } +// A QueuedSignal is a Signal that has been Queued to trigger at a set time type QueuedSignal struct { Signal Signal Time time.Time } +// Default message channel size for nodes const NODE_MSG_CHAN_DEFAULT = 1024 -// Nodes represent an addressible group of extensions +// Nodes represent a group of extensions that can be collectively addressed type Node struct { ID NodeID Type NodeType - Lock sync.RWMutex Extensions map[ExtType]Extension + // Channel for this node to receive messages from the Context MsgChan chan Msg + // Channel for this node to process delayed signals TimeoutChan <-chan time.Time - LoopLock sync.Mutex - Active bool + Active atomic.Bool SignalQueue []QueuedSignal NextSignal *QueuedSignal @@ -146,10 +144,10 @@ type Msg struct { // Main Loop for Threads, starts a write context, so cannot be called from a write or read context func NodeLoop(ctx *Context, node *Node) error { - node.LoopLock.Lock() - defer node.LoopLock.Unlock() - - node.Active = true + started := node.Active.CompareAndSwap(false, true) + if started == false { + return fmt.Errorf("%s is already started, will not start again", node.ID) + } for true { var signal Signal var source NodeID @@ -176,6 +174,11 @@ func NodeLoop(ctx *Context, node *Node) error { } node.Process(ctx, source, signal) } + + stopped := node.Active.CompareAndSwap(true, false) + if stopped == false { + panic("BAD_STATE: stopping already stopped node") + } return nil } @@ -194,8 +197,9 @@ func GetCtx[T Extension, C any](ctx *Context) (C, error) { var zero T var zero_ctx C ext_type := zero.Type() - ext_info := ctx.ExtByType(ext_type) - if ext_info == nil { + type_hash := ext_type.Hash() + ext_info, ok := ctx.Extensions[type_hash] + if ok == false { return zero_ctx, fmt.Errorf("%s is not an extension in ctx", ext_type) } diff --git a/policy.go b/policy.go index 5c80318..f2e098d 100644 --- a/policy.go +++ b/policy.go @@ -267,7 +267,10 @@ func LoadACLExt(ctx *Context, data []byte) (Extension, error) { policies := make([]Policy, len(j.Policies)) i := 0 - acl_ctx := ctx.ExtByType(ACLExtType).Data.(*ACLExtContext) + acl_ctx, err := GetCtx[*ACLExt, *ACLExtContext](ctx) + if err != nil { + return nil, err + } for name, ser := range(j.Policies) { policy_def, exists := acl_ctx.Types[PolicyType(name)] if exists == false {