Added comments and switched from a bool/mutex combo to prevent double-start to an atomic bool

gql_cataclysm
noah metz 2023-07-27 16:06:56 -06:00
parent 7965f8fbe6
commit 6b375245df
3 changed files with 40 additions and 40 deletions

@ -8,43 +8,35 @@ import (
//Function to load an extension from bytes
type ExtensionLoadFunc func(*Context, []byte) (Extension, error)
// Information about a loaded extension
// Information about a registered extension
type ExtensionInfo struct {
Load ExtensionLoadFunc
Type ExtType
Data interface{}
}
// Information about a loaded node type
// Information about a registered node type
type NodeInfo struct {
Type NodeType
Extensions []ExtType
}
// A Context is all the data needed to run a graphvent
// A Context stores all the data to run a graphvent process
type Context struct {
// DB is the database connection used to load and write nodes
DB * badger.DB
// Log is an interface used to record events happening
// Logging interface
Log Logger
// A mapping between type hashes and their corresponding extension definitions
// Map between database extension hashes and the registered info
Extensions map[uint64]ExtensionInfo
// A mapping between type hashes and their corresponding node definitions
// Map between database type hashes and the registered info
Types map[uint64]NodeInfo
// All loaded Nodes
// Routing map to all the nodes local to this context
Nodes map[NodeID]*Node
}
func (ctx *Context) ExtByType(ext_type ExtType) *ExtensionInfo {
type_hash := ext_type.Hash()
ext, ok := ctx.Extensions[type_hash]
if ok == true {
return &ext
} else {
return nil
}
}
// Register a NodeType to the context, with the list of extensions it requires
func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) error {
type_hash := node_type.Hash()
_, exists := ctx.Types[type_hash]
@ -94,6 +86,7 @@ func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFun
return nil
}
// Route a Signal to dest. Currently only local context routing is supported
func (ctx *Context) Send(source NodeID, dest NodeID, signal Signal) error {
target, exists := ctx.Nodes[dest]
if exists == false {
@ -110,7 +103,7 @@ func (ctx *Context) Send(source NodeID, dest NodeID, signal Signal) error {
return nil
}
// Create a new Context with all the library content added
// Create a new Context with the base library content added
func NewContext(db * badger.DB, log Logger) (*Context, error) {
ctx := &Context{
DB: db,

@ -1,7 +1,6 @@
package graphvent
import (
"sync"
"time"
"reflect"
"github.com/google/uuid"
@ -10,17 +9,15 @@ import (
"encoding/binary"
"encoding/json"
"crypto/sha512"
"crypto/ecdsa"
"crypto/elliptic"
"sync/atomic"
)
// IDs are how nodes are uniquely identified, and can be serialized for the database
// A NodeID uniquely identifies a Node
type NodeID uuid.UUID
func (id NodeID) MarshalJSON() ([]byte, error) {
str := id.String()
return json.Marshal(&str)
}
func (id *NodeID) UnmarshalJSON(bytes []byte) error {
var id_str string
err := json.Unmarshal(bytes, &id_str)
@ -32,6 +29,7 @@ func (id *NodeID) UnmarshalJSON(bytes []byte) error {
return err
}
// Base NodeID, used as a special value
var ZeroUUID = uuid.UUID{}
var ZeroID = NodeID(ZeroUUID)
@ -44,12 +42,14 @@ func (id NodeID) String() string {
return (uuid.UUID)(id).String()
}
// Create an ID from a fixed length byte array
// Ignore the error since we're enforcing 16 byte length at compile time
func IDFromBytes(bytes [16]byte) NodeID {
id, _ := uuid.FromBytes(bytes[:])
return NodeID(id)
}
// Parse an ID from a string
func ParseID(str string) (NodeID, error) {
id_uuid, err := uuid.Parse(str)
if err != nil {
@ -58,45 +58,43 @@ func ParseID(str string) (NodeID, error) {
return NodeID(id_uuid), nil
}
func KeyID(pub *ecdsa.PublicKey) NodeID {
ser := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
str := uuid.NewHash(sha512.New(), ZeroUUID, ser, 3)
return NodeID(str)
}
// Generate a random NodeID
func RandID() NodeID {
return NodeID(uuid.New())
}
// A Serializable has a type that can be used to map to it, and a function to serialize the current state
type Serializable[I comparable] interface {
Type() I
Serialize() ([]byte, error)
}
// Extensions are data attached to nodes that process signals
type Extension interface {
Serializable[ExtType]
Process(context *Context, source NodeID, node *Node, signal Signal)
}
// A QueuedSignal is a Signal that has been Queued to trigger at a set time
type QueuedSignal struct {
Signal Signal
Time time.Time
}
// Default message channel size for nodes
const NODE_MSG_CHAN_DEFAULT = 1024
// Nodes represent an addressible group of extensions
// Nodes represent a group of extensions that can be collectively addressed
type Node struct {
ID NodeID
Type NodeType
Lock sync.RWMutex
Extensions map[ExtType]Extension
// Channel for this node to receive messages from the Context
MsgChan chan Msg
// Channel for this node to process delayed signals
TimeoutChan <-chan time.Time
LoopLock sync.Mutex
Active bool
Active atomic.Bool
SignalQueue []QueuedSignal
NextSignal *QueuedSignal
@ -146,10 +144,10 @@ type Msg struct {
// Main Loop for Threads, starts a write context, so cannot be called from a write or read context
func NodeLoop(ctx *Context, node *Node) error {
node.LoopLock.Lock()
defer node.LoopLock.Unlock()
node.Active = true
started := node.Active.CompareAndSwap(false, true)
if started == false {
return fmt.Errorf("%s is already started, will not start again", node.ID)
}
for true {
var signal Signal
var source NodeID
@ -176,6 +174,11 @@ func NodeLoop(ctx *Context, node *Node) error {
}
node.Process(ctx, source, signal)
}
stopped := node.Active.CompareAndSwap(true, false)
if stopped == false {
panic("BAD_STATE: stopping already stopped node")
}
return nil
}
@ -194,8 +197,9 @@ func GetCtx[T Extension, C any](ctx *Context) (C, error) {
var zero T
var zero_ctx C
ext_type := zero.Type()
ext_info := ctx.ExtByType(ext_type)
if ext_info == nil {
type_hash := ext_type.Hash()
ext_info, ok := ctx.Extensions[type_hash]
if ok == false {
return zero_ctx, fmt.Errorf("%s is not an extension in ctx", ext_type)
}

@ -267,7 +267,10 @@ func LoadACLExt(ctx *Context, data []byte) (Extension, error) {
policies := make([]Policy, len(j.Policies))
i := 0
acl_ctx := ctx.ExtByType(ACLExtType).Data.(*ACLExtContext)
acl_ctx, err := GetCtx[*ACLExt, *ACLExtContext](ctx)
if err != nil {
return nil, err
}
for name, ser := range(j.Policies) {
policy_def, exists := acl_ctx.Types[PolicyType(name)]
if exists == false {