Cleanup and move away from capnp to custom TLV serialization

gql_cataclysm
noah metz 2023-08-31 19:50:32 -06:00
parent 7bed89701d
commit 4daec4d601
24 changed files with 1192 additions and 1606 deletions

3
.gitignore vendored

@ -1,12 +1,15 @@
# Ignore everything # Ignore everything
* *
!/go-capnp
# But not these files... # But not these files...
!/.gitignore !/.gitignore
!*.go !*.go
*.capnp.go
!go.sum !go.sum
!go.mod !go.mod
!*.capnp
!README.md !README.md
!LICENSE !LICENSE

@ -0,0 +1,9 @@
.PHONY: schema test
test: schema
clear && go test
schema: signal/signal.capnp.go
%.capnp.go: %.capnp
capnp compile -I ./go-capnp/std -ogo $<

@ -1,124 +1,127 @@
package graphvent package graphvent
import ( import (
badger "github.com/dgraph-io/badger/v3"
"fmt"
"sync"
"errors"
"runtime"
"crypto/sha512"
"crypto/ecdh" "crypto/ecdh"
"crypto/sha512"
"encoding/binary" "encoding/binary"
) "errors"
"fmt"
"reflect"
"runtime"
"sync"
// A Type can be Hashed by Hash badger "github.com/dgraph-io/badger/v3"
type TypeName interface { )
String() string
Prefix() string
}
// Hashed a Type to a uint64 func Hash(base string, name string) uint64 {
func Hash(t TypeName) uint64 { digest := append([]byte(base), 0x00)
hash := sha512.Sum512([]byte(fmt.Sprintf("%s%s", t.Prefix(), t.String()))) digest = append(digest, []byte(name)...)
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)]) hash := sha512.Sum512(digest)
return binary.BigEndian.Uint64(hash[0:8])
} }
// NodeType identifies the 'class' of a node type ExtType uint64
type NodeType string type NodeType uint64
func (node NodeType) Prefix() string { return "NODE: " } type SignalType uint64
func (node NodeType) String() string { return string(node) } type PolicyType uint64
type SerializedType uint64
// ExtType identifies an extension on a node
type ExtType string
func (ext ExtType) Prefix() string { return "EXTENSION: " }
func (ext ExtType) String() string { return string(ext) }
//Function to load an extension from bytes func NewExtType(name string) ExtType {
type ExtensionLoadFunc func(*Context,[]byte) (Extension, error) return ExtType(Hash(ExtTypeBase, name))
func LoadExtension[T any, E interface {
*T
Extension
}](ctx *Context, data []byte) (Extension, error) {
e := E(new(T))
err := e.Deserialize(ctx, data)
if err != nil {
return nil, err
} }
return e, nil func NewNodeType(name string) NodeType {
return NodeType(Hash(NodeTypeBase, name))
} }
type PolicyType string func NewSignalType(name string) SignalType {
func (policy PolicyType) Prefix() string { return "POLICY: " } return SignalType(Hash(SignalTypeBase, name))
func (policy PolicyType) String() string { return string(policy) }
type PolicyLoadFunc func(*Context,[]byte) (Policy, error)
func LoadPolicy[T any, P interface {
*T
Policy
}](ctx *Context, data []byte) (Policy, error) {
p := P(new(T))
err := p.Deserialize(ctx, data)
if err != nil {
return nil, err
} }
return p, nil
func NewPolicyType(name string) PolicyType {
return PolicyType(Hash(PolicyTypeBase, name))
} }
type PolicyInfo struct { func NewSerializedType(name string) SerializedType {
Load PolicyLoadFunc return SerializedType(Hash(SerializedTypeBase, name))
Type PolicyType
} }
// ExtType and NodeType constants
const ( const (
ListenerExtType = ExtType("LISTENER") ExtTypeBase = "ExtType"
LockableExtType = ExtType("LOCKABLE") NodeTypeBase = "NodeType"
GQLExtType = ExtType("GQL") SignalTypeBase = "SignalType"
GroupExtType = ExtType("GROUP") PolicyTypeBase = "PolicyType"
ECDHExtType = ExtType("ECDH") SerializedTypeBase = "SerializedType"
FieldNameBase = "FieldName"
GQLNodeType = NodeType("GQL")
) )
var ( var (
ListenerExtType = NewExtType("LISTENER")
LockableExtType = NewExtType("LOCKABLE")
GQLExtType = NewExtType("GQL")
GroupExtType = NewExtType("GROUP")
ECDHExtType = NewExtType("ECDH")
GQLNodeType = NewNodeType("GQL")
StopSignalType = NewSignalType("STOP")
CreateSignalType = NewSignalType("CREATE")
StartSignalType = NewSignalType("START")
ErrorSignalType = NewSignalType("ERROR")
StatusSignalType = NewSignalType("STATUS")
LinkSignalType = NewSignalType("LINK")
LockSignalType = NewSignalType("LOCK")
ReadSignalType = NewSignalType("READ")
ReadResultSignalType = NewSignalType("READ_RESULT")
ACLTimeoutSignalType = NewSignalType("ACL_TIMEOUT")
MemberOfPolicyType = NewPolicyType("USER_OF")
RequirementOfPolicyType = NewPolicyType("REQUIEMENT_OF")
PerNodePolicyType = NewPolicyType("PER_NODE")
AllNodesPolicyType = NewPolicyType("ALL_NODES")
StructType = NewSerializedType("struct")
SliceType = NewSerializedType("slice")
ArrayType = NewSerializedType("array")
PointerType = NewSerializedType("pointer")
MapType = NewSerializedType("map")
ErrorType = NewSerializedType("error")
ExtensionType = NewSerializedType("extension")
StringType = NewSerializedType("string")
NodeKeyType = NewSerializedType("node_key")
NodeNotFoundError = errors.New("Node not found in DB") NodeNotFoundError = errors.New("Node not found in DB")
ECDH = ecdh.X25519() ECDH = ecdh.X25519()
) )
type SignalLoadFunc func(*Context,[]byte) (Signal, error) type ExtensionInfo struct {
func LoadSignal[T any, S interface{ Type reflect.Type
*T Data interface{}
Signal
}](ctx *Context, data []byte) (Signal, error) {
s := S(new(T))
err := s.Deserialize(ctx, data)
if err != nil {
return nil, err
} }
return s, nil type NodeInfo struct {
Extensions []ExtType
} }
type SignalInfo struct { type TypeSerialize func(*Context,interface{}) ([]byte, error)
Load SignalLoadFunc type TypeDeserialize func(*Context,[]byte) (interface{}, error)
Type SignalType type TypeInfo struct {
Type reflect.Type
Serialize TypeSerialize
Deserialize TypeDeserialize
} }
// Information about a registered extension type Int int
type ExtensionInfo struct { func (i Int) MarshalBinary() ([]byte, error) {
// Function used to load extensions of this type from the database ret := make([]byte, 8)
Load ExtensionLoadFunc binary.BigEndian.PutUint64(ret, uint64(i))
Type ExtType return ret, nil
// Extra context data shared between nodes of this class
Data interface{}
} }
// Information about a registered node type type String string
type NodeInfo struct { func (str String) MarshalBinary() ([]byte, error) {
Type NodeType return []byte(str), nil
// Required extensions to be a valid node of this class
Extensions []ExtType
} }
// A Context stores all the data to run a graphvent process // A Context stores all the data to run a graphvent process
@ -128,101 +131,132 @@ type Context struct {
// Logging interface // Logging interface
Log Logger Log Logger
// Map between database extension hashes and the registered info // Map between database extension hashes and the registered info
Extensions map[uint64]ExtensionInfo Extensions map[ExtType]ExtensionInfo
ExtensionTypes map[reflect.Type]ExtType
// Map between databse policy hashes and the registered info // Map between databse policy hashes and the registered info
Policies map[uint64]PolicyInfo Policies map[PolicyType]reflect.Type
PolicyTypes map[reflect.Type]PolicyType
// Map between serialized signal hashes and the registered info // Map between serialized signal hashes and the registered info
Signals map[uint64]SignalInfo Signals map[SignalType]reflect.Type
SignalTypes map[reflect.Type]SignalType
// Map between database type hashes and the registered info // Map between database type hashes and the registered info
Types map[uint64]*NodeInfo Nodes map[NodeType]NodeInfo
// Map between go types and registered info
Types map[SerializedType]TypeInfo
TypeReflects map[reflect.Type]SerializedType
// Routing map to all the nodes local to this context // Routing map to all the nodes local to this context
NodesLock sync.RWMutex nodeMapLock sync.RWMutex
Nodes map[NodeID]*Node nodeMap map[NodeID]*Node
} }
// Register a NodeType to the context, with the list of extensions it requires // Register a NodeType to the context, with the list of extensions it requires
func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) error { func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) error {
type_hash := Hash(node_type) _, exists := ctx.Nodes[node_type]
_, exists := ctx.Types[type_hash]
if exists == true { if exists == true {
return fmt.Errorf("Cannot register node type %s, type already exists in context", node_type) return fmt.Errorf("Cannot register node type %+v, type already exists in context", node_type)
} }
ext_found := map[ExtType]bool{} ext_found := map[ExtType]bool{}
for _, extension := range(extensions) { for _, extension := range(extensions) {
_, in_ctx := ctx.Extensions[Hash(extension)] _, in_ctx := ctx.Extensions[extension]
if in_ctx == false { if in_ctx == false {
return fmt.Errorf("Cannot register node type %s, required extension %s not in context", node_type, extension) return fmt.Errorf("Cannot register node type %+v, required extension %+v not in context", node_type, extension)
} }
_, duplicate := ext_found[extension] _, duplicate := ext_found[extension]
if duplicate == true { if duplicate == true {
return fmt.Errorf("Duplicate extension %s found in extension list", extension) return fmt.Errorf("Duplicate extension %+v found in extension list", extension)
} }
ext_found[extension] = true ext_found[extension] = true
} }
ctx.Types[type_hash] = &NodeInfo{ ctx.Nodes[node_type] = NodeInfo{
Type: node_type,
Extensions: extensions, Extensions: extensions,
} }
return nil return nil
} }
func RegisterSignal[T any, S interface { func (ctx *Context) RegisterPolicy(reflect_type reflect.Type, policy_type PolicyType) error {
*T _, exists := ctx.Policies[policy_type]
Signal
}](ctx *Context, signal_type SignalType) error {
type_hash := Hash(signal_type)
_, exists := ctx.Signals[type_hash]
if exists == true { if exists == true {
return fmt.Errorf("Cannot register signal of type %s, type already exists in context", signal_type) return fmt.Errorf("Cannot register policy of type %+v, type already exists in context", policy_type)
} }
ctx.Signals[type_hash] = SignalInfo{ ctx.Policies[policy_type] = reflect_type
Load: LoadSignal[T, S], ctx.PolicyTypes[reflect_type] = policy_type
Type: signal_type, return nil
} }
func (ctx *Context)RegisterSignal(reflect_type reflect.Type, signal_type SignalType) error {
_, exists := ctx.Signals[signal_type]
if exists == true {
return fmt.Errorf("Cannot register signal of type %+v, type already exists in context", signal_type)
}
ctx.Signals[signal_type] = reflect_type
ctx.SignalTypes[reflect_type] = signal_type
return nil return nil
} }
// Add a node to a context, returns an error if the def is invalid or already exists in the context // Add a node to a context, returns an error if the def is invalid or already exists in the context
func RegisterExtension[T any, E interface{ func (ctx *Context)RegisterExtension(reflect_type reflect.Type, ext_type ExtType, data interface{}) error {
*T _, exists := ctx.Extensions[ext_type]
Extension
}](ctx *Context, data interface{}) error {
var zero E
ext_type := zero.Type()
type_hash := Hash(ext_type)
_, exists := ctx.Extensions[type_hash]
if exists == true { if exists == true {
return fmt.Errorf("Cannot register extension of type %s, type already exists in context", ext_type) return fmt.Errorf("Cannot register extension of type %+v, type already exists in context", ext_type)
} }
ctx.Extensions[type_hash] = ExtensionInfo{ ctx.Extensions[ext_type] = ExtensionInfo{
Load: LoadExtension[T,E], Type: reflect_type,
Type: ext_type,
Data: data, Data: data,
} }
ctx.ExtensionTypes[reflect_type] = ext_type
return nil
}
func (ctx *Context)RegisterType(reflect_type reflect.Type, ctx_type SerializedType, serialize TypeSerialize, deserialize TypeDeserialize) error {
_, exists := ctx.Types[ctx_type]
if exists == true {
return fmt.Errorf("Cannot register field of type %+v, type already exists in context", ctx_type)
}
_, exists = ctx.TypeReflects[reflect_type]
if exists == true {
return fmt.Errorf("Cannot register field with type %+v, type already registered in context", reflect_type)
}
if deserialize == nil {
return fmt.Errorf("Cannot register field without deserialize function")
}
if serialize == nil {
return fmt.Errorf("Cannot register field without serialize function")
}
ctx.Types[ctx_type] = TypeInfo{
Type: reflect_type,
Serialize: serialize,
Deserialize: deserialize,
}
ctx.TypeReflects[reflect_type] = ctx_type
return nil return nil
} }
func (ctx *Context) AddNode(id NodeID, node *Node) { func (ctx *Context) AddNode(id NodeID, node *Node) {
ctx.NodesLock.Lock() ctx.nodeMapLock.Lock()
ctx.Nodes[id] = node ctx.nodeMap[id] = node
ctx.NodesLock.Unlock() ctx.nodeMapLock.Unlock()
} }
func (ctx *Context) Node(id NodeID) (*Node, bool) { func (ctx *Context) Node(id NodeID) (*Node, bool) {
ctx.NodesLock.RLock() ctx.nodeMapLock.RLock()
node, exists := ctx.Nodes[id] node, exists := ctx.nodeMap[id]
ctx.NodesLock.RUnlock() ctx.nodeMapLock.RUnlock()
return node, exists return node, exists
} }
// Get a node from the context, or load from the database if not loaded // Get a node from the context, or load from the database if not loaded
func (ctx *Context) GetNode(id NodeID) (*Node, error) { func (ctx *Context) getNode(id NodeID) (*Node, error) {
target, exists := ctx.Node(id) target, exists := ctx.Node(id)
if exists == false { if exists == false {
@ -241,7 +275,7 @@ func (ctx *Context) Send(messages Messages) error {
if msg.Dest == ZeroID { if msg.Dest == ZeroID {
panic("Can't send to null ID") panic("Can't send to null ID")
} }
target, err := ctx.GetNode(msg.Dest) target, err := ctx.getNode(msg.Dest)
if err == nil { if err == nil {
select { select {
case target.MsgChan <- msg: case target.MsgChan <- msg:
@ -262,55 +296,311 @@ func (ctx *Context) Send(messages Messages) error {
return nil return nil
} }
type defaultKind struct {
Type SerializedType
Serialize func(interface{})([]byte, error)
Deserialize func([]byte)(interface{}, error)
}
var defaultKinds = map[reflect.Kind]defaultKind{
reflect.Int: {
Deserialize: func(data []byte)(interface{}, error){
if len(data) != 8 {
return nil, fmt.Errorf("invalid length: %d/8", len(data))
}
return int(binary.BigEndian.Uint64(data)), nil
},
Serialize: func(val interface{})([]byte, error){
i, ok := val.(int)
if ok == false {
return nil, fmt.Errorf("invalid type %+v", reflect.TypeOf(val))
} else {
bytes := make([]byte, 8)
binary.BigEndian.PutUint64(bytes, uint64(i))
return bytes, nil
}
},
},
}
type SerializedValue struct {
TypeStack []uint64
Data []byte
}
func (field SerializedValue) MarshalBinary() ([]byte, error) {
data := []byte{}
for _, t := range(field.TypeStack) {
t_ser := make([]byte, 8)
binary.BigEndian.PutUint64(t_ser, uint64(t))
data = append(data, t_ser...)
}
data = append(data, field.Data...)
return data, nil
}
func RecurseTypes(ctx *Context, t reflect.Type) ([]uint64, []reflect.Kind, error) {
var ctx_type uint64 = 0x00
ctype, exists := ctx.TypeReflects[t]
if exists == true {
ctx_type = uint64(ctype)
}
var new_types []uint64
var new_kinds []reflect.Kind
kind := t.Kind()
switch kind {
case reflect.Array:
if ctx_type == 0x00 {
ctx_type = uint64(ArrayType)
}
elem_types, elem_kinds, err := RecurseTypes(ctx, t.Elem())
if err != nil {
return nil, nil, err
}
new_types = append(new_types, ctx_type)
new_types = append(new_types, elem_types...)
new_kinds = append(new_kinds, reflect.Array)
new_kinds = append(new_kinds, elem_kinds...)
case reflect.Map:
if ctx_type == 0x00 {
ctx_type = uint64(MapType)
}
key_types, key_kinds, err := RecurseTypes(ctx, t.Key())
if err != nil {
return nil, nil, err
}
elem_types, elem_kinds, err := RecurseTypes(ctx, t.Elem())
if err != nil {
return nil, nil, err
}
new_types = append(new_types, ctx_type)
new_types = append(new_types, key_types...)
new_types = append(new_types, elem_types...)
new_kinds = append(new_kinds, reflect.Map)
new_kinds = append(new_kinds, key_kinds...)
new_kinds = append(new_kinds, elem_kinds...)
case reflect.Slice:
if ctx_type == 0x00 {
ctx_type = uint64(SliceType)
}
elem_types, elem_kinds, err := RecurseTypes(ctx, t.Elem())
if err != nil {
return nil, nil, err
}
new_types = append(new_types, ctx_type)
new_types = append(new_types, elem_types...)
new_kinds = append(new_kinds, reflect.Slice)
new_kinds = append(new_kinds, elem_kinds...)
case reflect.Pointer:
if ctx_type == 0x00 {
ctx_type = uint64(PointerType)
}
elem_types, elem_kinds, err := RecurseTypes(ctx, t.Elem())
if err != nil {
return nil, nil, err
}
new_types = append(new_types, ctx_type)
new_types = append(new_types, elem_types...)
new_kinds = append(new_kinds, reflect.Pointer)
new_kinds = append(new_kinds, elem_kinds...)
case reflect.String:
if ctx_type == 0x00 {
ctx_type = uint64(StringType)
}
new_types = append(new_types, ctx_type)
new_kinds = append(new_kinds, reflect.String)
default:
return nil, nil, fmt.Errorf("unhandled kind: %+v - %+v", kind, t)
}
return new_types, new_kinds, nil
}
func serializeValue(ctx *Context, kind_stack []reflect.Kind, value reflect.Value) ([]byte, error) {
kind := kind_stack[len(kind_stack) - 1]
switch kind {
default:
return nil, fmt.Errorf("unhandled kind: %+v", kind)
}
}
func SerializeValue(ctx *Context, value reflect.Value) (SerializedValue, error) {
if value.IsValid() == false {
return SerializedValue{}, fmt.Errorf("Cannot serialize invalid value: %+v", value)
}
type_stack, kind_stack, err := RecurseTypes(ctx, value.Type())
if err != nil {
return SerializedValue{}, err
}
bytes, err := serializeValue(ctx, kind_stack, value)
if err != nil {
return SerializedValue{}, err
}
return SerializedValue{
type_stack,
bytes,
}, nil
}
/*
default:
kind_def, handled := defaultKinds[kind]
if handled == false {
ctx_type, handled := ctx.TypeReflects[value.Type()]
if handled == false {
err = fmt.Errorf("%+v is not a handled reflect type", value.Type())
break
}
type_info, handled := ctx.Types[ctx_type]
if handled == false {
err = fmt.Errorf("%+v is not a handled reflect type(INTERNAL_ERROR)", value.Type())
break
}
field_ser, err := type_info.Serialize(ctx, value.Interface())
if err != nil {
err = fmt.Errorf(err.Error())
break
}
ret = SerializedValue{
[]uint64{uint64(ctx_type)},
field_ser,
}
}
field_ser, err := kind_def.Serialize(value.Interface())
if err != nil {
err = fmt.Errorf(err.Error())
} else {
ret = SerializedValue{
[]uint64{uint64(kind_def.Type)},
field_ser,
}
}
*/
func SerializeField(ctx *Context, ext Extension, field_name string) (SerializedValue, error) {
if ext == nil {
return SerializedValue{}, fmt.Errorf("Cannot get fields on nil Extension")
}
ext_value := reflect.ValueOf(ext).Elem()
field := ext_value.FieldByName(field_name)
if field.IsValid() == false {
return SerializedValue{}, fmt.Errorf("%s is not a field in %+v", field_name, ext)
} else {
return SerializeValue(ctx, field)
}
}
func SerializeSignal(ctx *Context, signal Signal, ctx_type SignalType) (SerializedValue, error) {
return SerializedValue{}, nil
}
func SerializeExtension(ctx *Context, ext Extension, ctx_type ExtType) (SerializedValue, error) {
if ext == nil {
return SerializedValue{}, fmt.Errorf("Cannot serialize nil Extension ")
}
ext_type := reflect.TypeOf(ext).Elem()
ext_value := reflect.ValueOf(ext).Elem()
m := map[string]SerializedValue{}
for _, field := range(reflect.VisibleFields(ext_type)) {
ext_tag, tagged_ext := field.Tag.Lookup("ext")
if tagged_ext == false {
continue
} else {
field_value := ext_value.FieldByIndex(field.Index)
var err error
m[ext_tag], err = SerializeValue(ctx, field_value)
if err != nil {
return SerializedValue{}, err
}
}
}
map_value := reflect.ValueOf(m)
map_ser, err := SerializeValue(ctx, map_value)
if err != nil {
return SerializedValue{}, err
}
return SerializedValue{
append([]uint64{uint64(ctx_type)}, map_ser.TypeStack...),
map_ser.Data,
}, nil
}
func DeserializeValue(ctx *Context, value SerializedValue) (interface{}, error) {
// TODO: do the opposite of SerializeValue.
// 1) Check the type to handle special types(array, list, map, pointer)
// 2) Check if the type is registered in the context, handle if so
// 3) Check if the type is a default type, handle if so
// 4) Return error if we don't know how to deserialize the type
return nil, fmt.Errorf("Undefined")
}
// Create a new Context with the base library content added // Create a new Context with the base library content added
func NewContext(db * badger.DB, log Logger) (*Context, error) { func NewContext(db * badger.DB, log Logger) (*Context, error) {
ctx := &Context{ ctx := &Context{
DB: db, DB: db,
Log: log, Log: log,
Extensions: map[uint64]ExtensionInfo{}, Policies: map[PolicyType]reflect.Type{},
Types: map[uint64]*NodeInfo{}, PolicyTypes: map[reflect.Type]PolicyType{},
Signals: map[uint64]SignalInfo{}, Extensions: map[ExtType]ExtensionInfo{},
Nodes: map[NodeID]*Node{}, ExtensionTypes: map[reflect.Type]ExtType{},
Signals: map[SignalType]reflect.Type{},
SignalTypes: map[reflect.Type]SignalType{},
Nodes: map[NodeType]NodeInfo{},
nodeMap: map[NodeID]*Node{},
Types: map[SerializedType]TypeInfo{},
TypeReflects: map[reflect.Type]SerializedType{},
} }
var err error var err error
err = RegisterExtension[LockableExt,*LockableExt](ctx, nil) err = ctx.RegisterExtension(reflect.TypeOf((*LockableExt)(nil)), LockableExtType, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = RegisterExtension[ListenerExt,*ListenerExt](ctx, nil) err = ctx.RegisterExtension(reflect.TypeOf((*ListenerExt)(nil)), ListenerExtType, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = RegisterExtension[ECDHExt,*ECDHExt](ctx, nil) err = ctx.RegisterExtension(reflect.TypeOf((*GroupExt)(nil)), GroupExtType, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = RegisterExtension[GroupExt,*GroupExt](ctx, nil) gql_ctx := NewGQLExtContext()
err = ctx.RegisterExtension(reflect.TypeOf((*GQLExt)(nil)), GQLExtType, gql_ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
gql_ctx := NewGQLExtContext() err = ctx.RegisterSignal(reflect.TypeOf((*StopSignal)(nil)), StopSignalType)
err = RegisterExtension[GQLExt,*GQLExt](ctx, gql_ctx) if err != nil {
return nil, err
}
err = ctx.RegisterSignal(reflect.TypeOf((*CreateSignal)(nil)), CreateSignalType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = RegisterSignal[BaseSignal, *BaseSignal](ctx, StopSignalType) err = ctx.RegisterSignal(reflect.TypeOf((*StartSignal)(nil)), StartSignalType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = RegisterSignal[BaseSignal, *BaseSignal](ctx, NewSignalType) err = ctx.RegisterSignal(reflect.TypeOf((*ReadSignal)(nil)), ReadSignalType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = RegisterSignal[BaseSignal, *BaseSignal](ctx, StartSignalType) err = ctx.RegisterSignal(reflect.TypeOf((*ReadResultSignal)(nil)), ReadResultSignalType)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -1,167 +0,0 @@
package graphvent
import (
"fmt"
"time"
"encoding/json"
"crypto/ecdsa"
"crypto/x509"
"crypto/ecdh"
)
type ECDHState struct {
ECKey *ecdh.PrivateKey
SharedSecret []byte
}
type ECDHStateJSON struct {
ECKey []byte `json:"ec_key"`
SharedSecret []byte `json:"shared_secret"`
}
func (state *ECDHState) MarshalJSON() ([]byte, error) {
var key_bytes []byte
var err error
if state.ECKey != nil {
key_bytes, err = x509.MarshalPKCS8PrivateKey(state.ECKey)
if err != nil {
return nil, err
}
}
return json.Marshal(&ECDHStateJSON{
ECKey: key_bytes,
SharedSecret: state.SharedSecret,
})
}
func (state *ECDHState) UnmarshalJSON(data []byte) error {
var j ECDHStateJSON
err := json.Unmarshal(data, &j)
if err != nil {
return err
}
state.SharedSecret = j.SharedSecret
if len(j.ECKey) == 0 {
state.ECKey = nil
} else {
tmp_key, err := x509.ParsePKCS8PrivateKey(j.ECKey)
if err != nil {
return err
}
ecdsa_key, ok := tmp_key.(*ecdsa.PrivateKey)
if ok == false {
return fmt.Errorf("Parsed wrong key type from DB for ECDHState")
}
state.ECKey, err = ecdsa_key.ECDH()
if err != nil {
return err
}
}
return nil
}
type ECDHMap map[NodeID]ECDHState
func (m ECDHMap) MarshalJSON() ([]byte, error) {
tmp := map[string]ECDHState{}
for id, state := range(m) {
tmp[id.String()] = state
}
return json.Marshal(tmp)
}
type ECDHExt struct {
ECDHStates ECDHMap
}
func NewECDHExt() *ECDHExt {
return &ECDHExt{
ECDHStates: ECDHMap{},
}
}
func ResolveFields[T Extension](t T, name string, field_funcs map[string]func(T)interface{})interface{} {
var zero T
field_func, ok := field_funcs[name]
if ok == false {
return fmt.Errorf("%s is not a field of %s", name, zero.Type())
}
return field_func(t)
}
func (ext *ECDHExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*ECDHExt)interface{}{
"ecdh_states": func(ext *ECDHExt) interface{} {
return ext.ECDHStates
},
})
}
func (ext *ECDHExt) HandleECDHSignal(log Logger, node *Node, signal *ECDHSignal) Messages {
source := KeyID(signal.EDDSA)
messages := Messages{}
switch signal.Str {
case "req":
state, exists := ext.ECDHStates[source]
if exists == false {
state = ECDHState{nil, nil}
}
resp, shared_secret, err := NewECDHRespSignal(node, signal)
if err == nil {
state.SharedSecret = shared_secret
ext.ECDHStates[source] = state
log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
messages = messages.Add(node.ID, node.Key, &resp, source)
} else {
log.Logf("ecdh", "ECDH_REQ_ERR: %s", err)
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), err.Error()), source)
}
case "resp":
state, exists := ext.ECDHStates[source]
if exists == false || state.ECKey == nil {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "no_req"), source)
} else {
err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW)
if err == nil {
shared_secret, err := state.ECKey.ECDH(signal.ECDH)
if err == nil {
state.SharedSecret = shared_secret
state.ECKey = nil
ext.ECDHStates[source] = state
log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
}
}
}
default:
log.Logf("ecdh", "unknown echd state %s", signal.Str)
}
return messages
}
func (ext *ECDHExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages {
switch signal.Type() {
case ECDHSignalType:
sig := signal.(*ECDHSignal)
return ext.HandleECDHSignal(ctx.Log, node, sig)
}
return nil
}
func (ext *ECDHExt) Type() ExtType {
return ECDHExtType
}
func (ext *ECDHExt) Serialize() ([]byte, error) {
return json.Marshal(ext)
}
func (ext *ECDHExt) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, &ext)
}

@ -1,6 +1,8 @@
module github.com/mekkanized/graphvent module github.com/mekkanized/graphvent
go 1.20 go 1.21.0
replace github.com/mekkanized/graphvent/signal v0.0.0 => ./signal
require ( require (
github.com/dgraph-io/badger/v3 v3.2103.5 github.com/dgraph-io/badger/v3 v3.2103.5
@ -11,6 +13,7 @@ require (
) )
require ( require (
capnproto.org/go/capnp/v3 v3.0.0-alpha-29 // indirect
github.com/cespare/xxhash v1.1.0 // indirect github.com/cespare/xxhash v1.1.0 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/dgraph-io/badger/v4 v4.1.0 // indirect github.com/dgraph-io/badger/v4 v4.1.0 // indirect
@ -28,8 +31,11 @@ require (
github.com/klauspost/compress v1.12.3 // indirect github.com/klauspost/compress v1.12.3 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect github.com/mattn/go-isatty v0.0.14 // indirect
github.com/mekkanized/graphvent/signal v0.0.0 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
go.opencensus.io v0.22.5 // indirect go.opencensus.io v0.22.5 // indirect
golang.org/x/net v0.7.0 // indirect golang.org/x/net v0.7.0 // indirect
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 // indirect
golang.org/x/sys v0.6.0 // indirect golang.org/x/sys v0.6.0 // indirect
zenhack.net/go/util v0.0.0-20230414204917-531d38494cf5 // indirect
) )

@ -1,3 +1,5 @@
capnproto.org/go/capnp/v3 v3.0.0-alpha-29 h1:ICLhiy4Jmp0d7hLQO+HzFAVIft/oxpPAUPV8tqx+eUE=
capnproto.org/go/capnp/v3 v3.0.0-alpha-29/go.mod h1:+ysMHvOh1EWNOyorxJWs1omhRFiDoKxKkWQACp54jKM=
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
@ -134,6 +136,7 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -171,3 +174,5 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
zenhack.net/go/util v0.0.0-20230414204917-531d38494cf5 h1:yksDCGMVzyn3vlyf0GZ3huiF5FFaMGQpQ3UJvR0EoGA=
zenhack.net/go/util v0.0.0-20230414204917-531d38494cf5/go.mod h1:1LtNdPAs8WH+BTcQiZAOo2MIKD/5jyK/u7sZ9ZPe5SE=

@ -205,7 +205,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ
if err != nil { if err != nil {
return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from id_bytes %+v", id_bytes) return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from id_bytes %+v", id_bytes)
} }
auth_id := NodeID{auth_uuid} auth_id := NodeID(auth_uuid)
key_bytes, err := base64.StdEncoding.DecodeString(key_b64) key_bytes, err := base64.StdEncoding.DecodeString(key_b64)
if err != nil { if err != nil {
@ -234,7 +234,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ
Ext: gql_ext, Ext: gql_ext,
Chans: map[uuid.UUID]chan Signal{}, Chans: map[uuid.UUID]chan Signal{},
Context: ctx, Context: ctx,
GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext), GQLContext: ctx.Extensions[GQLExtType].Data.(*GQLExtContext),
Server: server, Server: server,
User: key_id, User: key_id,
Key: key, Key: key,
@ -270,7 +270,7 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW
query := GQLPayload{} query := GQLPayload{}
json.Unmarshal(str, &query) json.Unmarshal(str, &query)
gql_context := ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext) gql_context := ctx.Extensions[GQLExtType].Data.(*GQLExtContext)
params := graphql.Params{ params := graphql.Params{
Schema: gql_context.Schema, Schema: gql_context.Schema,
@ -401,7 +401,7 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon
} }
} else if msg.Type == "subscribe" { } else if msg.Type == "subscribe" {
ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload) ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload)
gql_context := ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext) gql_context := ctx.Extensions[GQLExtType].Data.(*GQLExtContext)
params := graphql.Params{ params := graphql.Params{
Schema: gql_context.Schema, Schema: gql_context.Schema,
Context: req_ctx, Context: req_ctx,
@ -543,7 +543,7 @@ func BuildSchema(ctx *GQLExtContext) (graphql.Schema, error) {
return graphql.NewSchema(schemaConfig) return graphql.NewSchema(schemaConfig)
} }
func RegisterField[T any](ctx *GQLExtContext, gql_type graphql.Type, gql_name string, ext_type ExtType, acl_name string, resolve_fn func(graphql.ResolveParams, T)(interface{}, error)) error { func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string, ext_type ExtType, acl_name string, resolve_fn func(graphql.ResolveParams, SerializedValue)(interface{}, error)) error {
if ctx == nil { if ctx == nil {
return fmt.Errorf("ctx is nil") return fmt.Errorf("ctx is nil")
} }
@ -561,21 +561,19 @@ func RegisterField[T any](ctx *GQLExtContext, gql_type graphql.Type, gql_name st
return ResolveNodeResult(p, func(p graphql.ResolveParams, result NodeResult) (interface{}, error) { return ResolveNodeResult(p, func(p graphql.ResolveParams, result NodeResult) (interface{}, error) {
ext, exists := result.Result.Extensions[ext_type] ext, exists := result.Result.Extensions[ext_type]
if exists == false { if exists == false {
return nil, fmt.Errorf("%s is not in the extensions of the result", ext_type) return nil, fmt.Errorf("%+v is not in the extensions of the result", ext_type)
} }
val_if, exists := ext[acl_name] val_ser, exists := ext[acl_name]
if exists == false { if exists == false {
return nil, fmt.Errorf("%s is not in the fields of %s in the result", acl_name, ext_type) return nil, fmt.Errorf("%s is not in the fields of %+v in the result", acl_name, ext_type)
} }
var zero T if val_ser.TypeStack[0] == uint64(ErrorType) {
val, ok := val_if.(T) return nil, fmt.Errorf(string(val_ser.Data))
if ok == false {
return nil, fmt.Errorf("%s.%s is not %s", ext_type, acl_name, reflect.TypeOf(zero))
} }
return resolve_fn(p, val) return resolve_fn(p, val_ser)
}) })
} }
@ -681,8 +679,8 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in
for field_name, field := range(self_fields) { for field_name, field := range(self_fields) {
self_field := field self_field := field
err := RegisterField(ctx, ctx_interface.Interface, field_name, self_field.Extension, self_field.ACLName, err := ctx.RegisterField(ctx_interface.Interface, field_name, self_field.Extension, self_field.ACLName,
func(p graphql.ResolveParams, val interface{})(interface{}, error) { func(p graphql.ResolveParams, val SerializedValue)(interface{}, error) {
ctx, err := PrepResolve(p) ctx, err := PrepResolve(p)
if err != nil { if err != nil {
return nil, err return nil, err
@ -715,7 +713,7 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in
for field_name, field := range(list_fields) { for field_name, field := range(list_fields) {
list_field := field list_field := field
resolve_fn := func(p graphql.ResolveParams, val interface{})(interface{}, error) { resolve_fn := func(p graphql.ResolveParams, val SerializedValue)(interface{}, error) {
ctx, err := PrepResolve(p) ctx, err := PrepResolve(p)
if err != nil { if err != nil {
return nil, err return nil, err
@ -736,7 +734,7 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in
return nodes, nil return nodes, nil
} }
err := RegisterField(ctx, ctx_interface.List, field_name, list_field.Extension, list_field.ACLName, resolve_fn) err := ctx.RegisterField(ctx_interface.List, field_name, list_field.Extension, list_field.ACLName, resolve_fn)
if err != nil { if err != nil {
return err return err
} }
@ -764,7 +762,7 @@ func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, name string, inte
_, exists := ctx.NodeTypes[node_type] _, exists := ctx.NodeTypes[node_type]
if exists == true { if exists == true {
return fmt.Errorf("%s already in GQLExtContext.NodeTypes", node_type) return fmt.Errorf("%+v already in GQLExtContext.NodeTypes", node_type)
} }
node_interfaces, err := GQLInterfaces(ctx, interface_names) node_interfaces, err := GQLInterfaces(ctx, interface_names)
@ -830,19 +828,20 @@ func NewGQLExtContext() *GQLExtContext {
panic(err) panic(err)
} }
err = RegisterField(&context, context.Interfaces["Node"].List, "Members", GroupExtType, "members", err = context.RegisterField(context.Interfaces["Node"].List, "Members", GroupExtType, "members",
func(p graphql.ResolveParams, val map[NodeID]string)(interface{}, error) { func(p graphql.ResolveParams, val SerializedValue)(interface{}, error) {
ctx, err := PrepResolve(p) ctx, err := PrepResolve(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node_list := make([]NodeID, len(val)) /*node_list := make([]NodeID, len(val))
i := 0 i := 0
for id, _ := range(val) { for id, _ := range(val) {
node_list[i] = id node_list[i] = id
i += 1 i += 1
} }
*/
node_list := []NodeID{}
nodes, err := ResolveNodes(ctx, p, node_list) nodes, err := ResolveNodes(ctx, p, node_list)
if err != nil { if err != nil {
return nil, err return nil, err
@ -895,7 +894,7 @@ func NewGQLExtContext() *GQLExtContext {
panic(err) panic(err)
} }
err = RegisterField(&context, graphql.String, "Listen", GQLExtType, "listen", func(p graphql.ResolveParams, listen string) (interface{}, error) { err = context.RegisterField(graphql.String, "Listen", GQLExtType, "listen", func(p graphql.ResolveParams, listen SerializedValue) (interface{}, error) {
return listen, nil return listen, nil
}) })
if err != nil { if err != nil {
@ -1000,14 +999,6 @@ type GQLExt struct {
Listen string `json:"listen"` Listen string `json:"listen"`
} }
func (ext *GQLExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*GQLExt)interface{}{
"listen": func(ext *GQLExt) interface{} {
return ext.Listen
},
})
}
func (ext *GQLExt) FindResponseChannel(req_id uuid.UUID) chan Signal { func (ext *GQLExt) FindResponseChannel(req_id uuid.UUID) chan Signal {
ext.resolver_response_lock.RLock() ext.resolver_response_lock.RLock()
response_chan, _ := ext.resolver_response[req_id] response_chan, _ := ext.resolver_response[req_id]
@ -1036,11 +1027,10 @@ func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal {
func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages {
// Process ReadResultSignalType by forwarding it to the waiting resolver // Process ReadResultSignalType by forwarding it to the waiting resolver
messages := Messages{} switch sig := signal.(type) {
if signal.Type() == ErrorSignalType { case *ErrorSignal:
// TODO: Forward to resolver if waiting for it // TODO: Forward to resolver if waiting for it
sig := signal.(*ErrorSignal) response_chan := ext.FreeResponseChannel(sig.Header().ReqID)
response_chan := ext.FreeResponseChannel(sig.ReqID())
if response_chan != nil { if response_chan != nil {
select { select {
case response_chan <- sig: case response_chan <- sig:
@ -1052,9 +1042,8 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
} else { } else {
ctx.Log.Logf("gql", "received error signal response %+v with no mapped resolver", sig) ctx.Log.Logf("gql", "received error signal response %+v with no mapped resolver", sig)
} }
} else if signal.Type() == ReadResultSignalType { case *ReadResultSignal:
sig := signal.(*ReadResultSignal) response_chan := ext.FindResponseChannel(sig.ReqID)
response_chan := ext.FindResponseChannel(sig.ReqID())
if response_chan != nil { if response_chan != nil {
select { select {
case response_chan <- sig: case response_chan <- sig:
@ -1065,23 +1054,23 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
} else { } else {
ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig) ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig)
} }
} else if signal.Type() == StartSignalType { case *StartSignal:
ctx.Log.Logf("gql", "starting gql server %s", node.ID) ctx.Log.Logf("gql", "starting gql server %s", node.ID)
err := ext.StartGQLServer(ctx, node) err := ext.StartGQLServer(ctx, node)
if err == nil { if err == nil {
node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID)) node.QueueSignal(time.Now(), NewStatusSignal(node.ID, "server_started"))
} else { } else {
ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err) ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err)
} }
} }
return messages return nil
} }
func (ext *GQLExt) Type() ExtType { func (ext *GQLExt) Type() ExtType {
return GQLExtType return GQLExtType
} }
func (ext *GQLExt) Serialize() ([]byte, error) { func (ext *GQLExt) MarshalBinary() ([]byte, error) {
return json.Marshal(ext) return json.Marshal(ext)
} }

@ -36,7 +36,7 @@ func NodeInterfaceDefaultIsType(required_extensions []ExtType) func(graphql.IsTy
return false return false
} }
node_type_def, exists := ctx.Context.Types[Hash(node.Result.NodeType)] node_type_def, exists := ctx.Context.Nodes[node.Result.NodeType]
if exists == false { if exists == false {
return false return false
} else { } else {
@ -73,7 +73,7 @@ func NodeInterfaceResolveType(required_extensions []ExtType, default_type **grap
gql_type, exists := ctx.GQLContext.NodeTypes[node.Result.NodeType] gql_type, exists := ctx.GQLContext.NodeTypes[node.Result.NodeType]
ctx.Context.Log.Logf("gql", "GQL_INTERFACE_RESOLVE_TYPE(%+v): %+v - %t - %+v - %+v", node, gql_type, exists, required_extensions, *default_type) ctx.Context.Log.Logf("gql", "GQL_INTERFACE_RESOLVE_TYPE(%+v): %+v - %t - %+v - %+v", node, gql_type, exists, required_extensions, *default_type)
if exists == false { if exists == false {
node_type_def, exists := ctx.Context.Types[Hash(node.Result.NodeType)] node_type_def, exists := ctx.Context.Nodes[node.Result.NodeType]
if exists == false { if exists == false {
return nil return nil
} else { } else {

@ -51,16 +51,16 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
// Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error
read_signal := NewReadSignal(ext_fields) read_signal := NewReadSignal(ext_fields)
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(ctx.Server.ID, ctx.Key, read_signal, id) msgs = msgs.Add(ctx.Context, ctx.Server.ID, ctx.Key, read_signal, id)
response_chan := ctx.Ext.GetResponseChannel(read_signal.ID()) response_chan := ctx.Ext.GetResponseChannel(read_signal.ID)
resp_channels[read_signal.ID()] = response_chan resp_channels[read_signal.ID] = response_chan
node_ids[read_signal.ID()] = id node_ids[read_signal.ID] = id
// TODO: Send all at once instead of createing Messages for each // TODO: Send all at once instead of createing Messages for each
err = ctx.Context.Send(msgs) err = ctx.Context.Send(msgs)
if err != nil { if err != nil {
ctx.Ext.FreeResponseChannel(read_signal.ID()) ctx.Ext.FreeResponseChannel(read_signal.ID)
return nil, err return nil, err
} }
} }
@ -68,8 +68,8 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
responses := []NodeResult{} responses := []NodeResult{}
for sig_id, response_chan := range(resp_channels) { for sig_id, response_chan := range(resp_channels) {
// Wait for the response, returning an error on timeout // Wait for the response, returning an error on timeout
response, err := WaitForSignal(response_chan, time.Millisecond*100, ReadResultSignalType, func(sig *ReadResultSignal)bool{ response, err := WaitForSignal(response_chan, time.Millisecond*100, func(sig *ReadResultSignal)bool{
return sig.ReqID() == sig_id return sig.ReqID == sig_id
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -81,6 +81,6 @@ func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) {
func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) { func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) {
return ResolveNodeResult(p, func(p graphql.ResolveParams, node NodeResult) (interface{}, error) { return ResolveNodeResult(p, func(p graphql.ResolveParams, node NodeResult) (interface{}, error) {
return Hash(node.Result.NodeType), nil return uint64(node.Result.NodeType), nil
}) })
} }

@ -21,7 +21,7 @@ import (
func TestGQLServer(t *testing.T) { func TestGQLServer(t *testing.T) {
ctx := logTestContext(t, []string{"test", "gql", "policy", "pending"}) ctx := logTestContext(t, []string{"test", "gql", "policy", "pending"})
TestNodeType := NodeType("TEST") TestNodeType := NewNodeType("TEST")
err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType}) err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType})
fatalErr(t, err) fatalErr(t, err)
@ -30,33 +30,33 @@ func TestGQLServer(t *testing.T) {
gql_id := KeyID(pub) gql_id := KeyID(pub)
group_policy_1 := NewAllNodesPolicy(Tree{ group_policy_1 := NewAllNodesPolicy(Tree{
ReadSignalType.String(): Tree{ uint64(ReadSignalType): Tree{
GroupExtType.String(): Tree{ uint64(GroupExtType): Tree{
"members": Tree{}, Hash(FieldNameBase, "members"): Tree{},
}, },
}, },
ReadResultSignalType.String(): nil, uint64(ReadResultSignalType): nil,
ErrorSignalType.String(): nil, uint64(ErrorSignalType): nil,
}) })
group_policy_2 := NewMemberOfPolicy(map[NodeID]Tree{ group_policy_2 := NewMemberOfPolicy(map[NodeID]Tree{
gql_id: Tree{ gql_id: Tree{
LinkSignalType.String(): nil, uint64(LinkSignalType): nil,
LockSignalType.String(): nil, uint64(LockSignalType): nil,
StatusSignalType.String(): nil, uint64(StatusSignalType): nil,
ReadSignalType.String(): nil, uint64(ReadSignalType): nil,
}, },
}) })
user_policy_1 := NewAllNodesPolicy(Tree{ user_policy_1 := NewAllNodesPolicy(Tree{
ReadResultSignalType.String(): nil, uint64(ReadResultSignalType): nil,
ErrorSignalType.String(): nil, uint64(ErrorSignalType): nil,
}) })
user_policy_2 := NewMemberOfPolicy(map[NodeID]Tree{ user_policy_2 := NewMemberOfPolicy(map[NodeID]Tree{
gql_id: Tree{ gql_id: Tree{
LinkSignalType.String(): nil, uint64(LinkSignalType): nil,
ReadSignalType.String(): nil, uint64(ReadSignalType): nil,
}, },
}) })
@ -80,8 +80,8 @@ func TestGQLServer(t *testing.T) {
ctx.Log.Logf("test", "GQL: %s", gql.ID) ctx.Log.Logf("test", "GQL: %s", gql.ID)
ctx.Log.Logf("test", "NODE: %s", n1.ID) ctx.Log.Logf("test", "NODE: %s", n1.ID)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StatusSignal) bool {
return sig.Str == "server_started" return sig.Status == "server_started"
}) })
fatalErr(t, err) fatalErr(t, err)
@ -107,7 +107,9 @@ func TestGQLServer(t *testing.T) {
}, },
} }
auth_username := base64.StdEncoding.EncodeToString(n1.ID.Serialize()) n1_id_bytes, err := n1.ID.MarshalBinary()
fatalErr(t, err)
auth_username := base64.StdEncoding.EncodeToString(n1_id_bytes)
key_bytes, err := x509.MarshalPKCS8PrivateKey(n1.Key) key_bytes, err := x509.MarshalPKCS8PrivateKey(n1.Key)
fatalErr(t, err) fatalErr(t, err)
auth_password := base64.StdEncoding.EncodeToString(key_bytes) auth_password := base64.StdEncoding.EncodeToString(key_bytes)
@ -196,11 +198,11 @@ func TestGQLServer(t *testing.T) {
SubGQL(sub_1) SubGQL(sub_1)
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(gql.ID, gql.Key, &StopSignal, gql.ID) msgs = msgs.Add(ctx, gql.ID, gql.Key, NewStopSignal(), gql.ID)
err = ctx.Send(msgs) err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StatusSignal) bool {
return sig.Str == "stopped" return sig.Status == "stopped"
}) })
fatalErr(t, err) fatalErr(t, err)
} }
@ -208,7 +210,7 @@ func TestGQLServer(t *testing.T) {
func TestGQLDB(t *testing.T) { func TestGQLDB(t *testing.T) {
ctx := logTestContext(t, []string{"test"}) ctx := logTestContext(t, []string{"test"})
TestUserNodeType := NodeType("TEST_USER") TestUserNodeType := NewNodeType("TEST_USER")
err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{})
fatalErr(t, err) fatalErr(t, err)
u1 := NewNode(ctx, nil, TestUserNodeType, 10, nil) u1 := NewNode(ctx, nil, TestUserNodeType, 10, nil)
@ -225,33 +227,31 @@ func TestGQLDB(t *testing.T) {
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(gql.ID, gql.Key, &StopSignal, gql.ID) msgs = msgs.Add(ctx, gql.ID, gql.Key, NewStopSignal(), gql.ID)
err = ctx.Send(msgs) err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StatusSignal) bool {
return sig.Str == "stopped" && sig.NodeID == gql.ID return sig.Status == "stopped" && sig.Source == gql.ID
}) })
fatalErr(t, err) fatalErr(t, err)
ser1, err := gql.Serialize() ser1, err := gql.Serialize(ctx)
ser2, err := u1.Serialize() ser2, err := u1.Serialize(ctx)
ser3, err := StopSignal.Serialize()
ctx.Log.Logf("test", "SER_1: \n%s\n\n", ser1) ctx.Log.Logf("test", "SER_1: \n%s\n\n", ser1)
ctx.Log.Logf("test", "SER_2: \n%s\n\n", ser2) ctx.Log.Logf("test", "SER_2: \n%s\n\n", ser2)
ctx.Log.Logf("test", "SER_3: \n%s\n\n", ser3)
// Clear all loaded nodes from the context so it loads them from the database // Clear all loaded nodes from the context so it loads them from the database
ctx.Nodes = map[NodeID]*Node{} ctx.nodeMap = map[NodeID]*Node{}
gql_loaded, err := LoadNode(ctx, gql.ID) gql_loaded, err := LoadNode(ctx, gql.ID)
fatalErr(t, err) fatalErr(t, err)
listener_ext, err = GetExt[*ListenerExt](gql_loaded) listener_ext, err = GetExt[*ListenerExt](gql_loaded, GQLExtType)
fatalErr(t, err) fatalErr(t, err)
msgs = Messages{} msgs = Messages{}
msgs = msgs.Add(gql_loaded.ID, gql_loaded.Key, &StopSignal, gql_loaded.ID) msgs = msgs.Add(ctx, gql_loaded.ID, gql_loaded.Key, NewStopSignal(), gql_loaded.ID)
err = ctx.Send(msgs) err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StatusSignal) bool {
return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID return sig.Status == "stopped" && sig.Source == gql_loaded.ID
}) })
fatalErr(t, err) fatalErr(t, err)
} }

@ -16,10 +16,10 @@ func AddNodeFields(object *graphql.Object) {
}) })
} }
func NewGQLNodeType(node_type NodeType, interfaces []*graphql.Interface, init func(*Type)) *Type { func NewGQLNodeType(gql_name string, node_type NodeType, interfaces []*graphql.Interface, init func(*Type)) *Type {
var gql Type var gql Type
gql.Type = graphql.NewObject(graphql.ObjectConfig{ gql.Type = graphql.NewObject(graphql.ObjectConfig{
Name: string(node_type), Name: gql_name,
Interfaces: interfaces, Interfaces: interfaces,
IsTypeOf: func(p graphql.IsTypeOfParams) bool { IsTypeOf: func(p graphql.IsTypeOfParams) bool {
node, ok := p.Value.(NodeResult) node, ok := p.Value.(NodeResult)

@ -6,7 +6,7 @@ import (
badger "github.com/dgraph-io/badger/v3" badger "github.com/dgraph-io/badger/v3"
) )
const SimpleListenerNodeType = NodeType("SIMPLE_LISTENER") var SimpleListenerNodeType = NewNodeType("SIMPLE_LISTENER")
func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) { func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) {
listener_extension := NewListenerExt(buffer) listener_extension := NewListenerExt(buffer)

@ -12,18 +12,10 @@ func (ext *GroupExt) Type() ExtType {
return GroupExtType return GroupExtType
} }
func (ext *GroupExt) Serialize() ([]byte, error) { func (ext *GroupExt) MarshalBinary() ([]byte, error) {
return json.Marshal(ext) return json.Marshal(ext)
} }
func (ext *GroupExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*GroupExt)interface{}{
"members": func(ext *GroupExt) interface{} {
return ext.Members
},
})
}
func NewGroupExt(members map[NodeID]string) *GroupExt { func NewGroupExt(members map[NodeID]string) *GroupExt {
if members == nil { if members == nil {
members = map[NodeID]string{} members = map[NodeID]string{}

@ -18,19 +18,8 @@ func NewListenerExt(buffer int) *ListenerExt {
} }
} }
func (ext *ListenerExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*ListenerExt)interface{}{
"buffer": func(ext *ListenerExt) interface{} {
return ext.Buffer
},
"chan": func(ext *ListenerExt) interface{} {
return ext.Chan
},
})
}
// Simple load function, unmarshal the buffer int from json // Simple load function, unmarshal the buffer int from json
func (ext *ListenerExt) Deserialize(ctx *Context, data []byte) error { func (ext *ListenerExt) DeserializeListenerExt(ctx *Context, data []byte) error {
err := json.Unmarshal(data, &ext.Buffer) err := json.Unmarshal(data, &ext.Buffer)
ext.Chan = make(chan Signal, ext.Buffer) ext.Chan = make(chan Signal, ext.Buffer)
return err return err
@ -51,6 +40,6 @@ func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal
return nil return nil
} }
func (ext *ListenerExt) Serialize() ([]byte, error) { func (ext *ListenerExt) MarshalBinary() ([]byte, error) {
return json.Marshal(ext.Buffer) return json.Marshal(ext.Buffer)
} }

@ -1,7 +1,6 @@
package graphvent package graphvent
import ( import (
"encoding/binary"
"github.com/google/uuid" "github.com/google/uuid"
) )
@ -15,119 +14,17 @@ const (
) )
type LockableExt struct{ type LockableExt struct{
State ReqState State ReqState `ext:""`
ReqID uuid.UUID ReqID *uuid.UUID `ext:""`
Owner *NodeID Owner *NodeID `ext:""`
PendingOwner *NodeID PendingOwner *NodeID `ext:""`
Requirements map[NodeID]ReqState Requirements map[NodeID]ReqState `ext:""`
}
func (ext *LockableExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*LockableExt)interface{}{
"owner": func(ext *LockableExt) interface{} {
return ext.Owner
},
"pending_owner": func(ext *LockableExt) interface{} {
return ext.PendingOwner
},
"requirements": func(ext *LockableExt) interface{} {
return ext.Requirements
},
})
} }
func (ext *LockableExt) Type() ExtType { func (ext *LockableExt) Type() ExtType {
return LockableExtType return LockableExtType
} }
func (ext *LockableExt) Serialize() ([]byte, error) {
ret := make([]byte, 9 + (16 * 2) + (17 * len(ext.Requirements)))
if ext.Owner != nil {
bytes, err := ext.Owner.MarshalBinary()
if err != nil {
return nil, err
}
copy(ret[0:16], bytes)
}
if ext.PendingOwner != nil {
bytes, err := ext.PendingOwner.MarshalBinary()
if err != nil {
return nil, err
}
copy(ret[16:32], bytes)
}
binary.BigEndian.PutUint64(ret[32:40], uint64(len(ext.Requirements)))
ret[40] = byte(ext.State)
cur := 41
for req, state := range(ext.Requirements) {
bytes, err := req.MarshalBinary()
if err != nil {
return nil, err
}
copy(ret[cur:cur+16], bytes)
ret[cur+16] = byte(state)
cur += 17
}
return ret, nil
}
func (ext *LockableExt) Deserialize(ctx *Context, data []byte) error {
cur := 0
all_zero := true
for _, b := range(data[cur:cur+16]) {
if all_zero == true && b != 0x00 {
all_zero = false
}
}
if all_zero == false {
tmp, err := IDFromBytes(data[cur:cur+16])
if err != nil {
return err
}
ext.Owner = &tmp
}
cur += 16
all_zero = true
for _, b := range(data[cur:cur+16]) {
if all_zero == true && b != 0x00 {
all_zero = false
}
}
if all_zero == false {
tmp, err := IDFromBytes(data[cur:cur+16])
if err != nil {
return err
}
ext.PendingOwner = &tmp
}
cur += 16
num_requirements := int(binary.BigEndian.Uint64(data[cur:cur+8]))
cur += 8
ext.State = ReqState(data[cur])
cur += 1
if num_requirements != 0 {
ext.Requirements = map[NodeID]ReqState{}
}
for i := 0; i < num_requirements; i++ {
id, err := IDFromBytes(data[cur:cur+16])
if err != nil {
return err
}
cur += 16
state := ReqState(data[cur])
cur += 1
ext.Requirements[id] = state
}
return nil
}
func NewLockableExt(requirements []NodeID) *LockableExt { func NewLockableExt(requirements []NodeID) *LockableExt {
var reqs map[NodeID]ReqState = nil var reqs map[NodeID]ReqState = nil
if requirements != nil { if requirements != nil {
@ -148,21 +45,21 @@ func NewLockableExt(requirements []NodeID) *LockableExt {
func UnlockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) { func UnlockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) {
msgs := Messages{} msgs := Messages{}
signal := NewLockSignal("unlock") signal := NewLockSignal("unlock")
msgs = msgs.Add(owner.ID, owner.Key, signal, target) msgs = msgs.Add(ctx, owner.ID, owner.Key, signal, target)
return signal.ID(), ctx.Send(msgs) return signal.Header().ID, ctx.Send(msgs)
} }
// Send the signal to lock a node from itself // Send the signal to lock a node from itself
func LockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) { func LockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) {
msgs := Messages{} msgs := Messages{}
signal := NewLockSignal("lock") signal := NewLockSignal("lock")
msgs = msgs.Add(owner.ID, owner.Key, signal, target) msgs = msgs.Add(ctx, owner.ID, owner.Key, signal, target)
return signal.ID(), ctx.Send(msgs) return signal.Header().ID, ctx.Send(msgs)
} }
func (ext *LockableExt) HandleErrorSignal(log Logger, node *Node, source NodeID, signal *ErrorSignal) Messages { func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) Messages {
str := signal.Error str := signal.Error
log.Logf("lockable", "ERROR_SIGNAL: %s->%s %+v", source, node.ID, str) ctx.Log.Logf("lockable", "ERROR_SIGNAL: %s->%s %+v", source, node.ID, str)
msgs := Messages {} msgs := Messages {}
switch str { switch str {
@ -173,7 +70,7 @@ func (ext *LockableExt) HandleErrorSignal(log Logger, node *Node, source NodeID,
for id, state := range(ext.Requirements) { for id, state := range(ext.Requirements) {
if state == Locked { if state == Locked {
ext.Requirements[id] = Unlocking ext.Requirements[id] = Unlocking
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), id) msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlock"), id)
} }
} }
} }
@ -185,51 +82,48 @@ func (ext *LockableExt) HandleErrorSignal(log Logger, node *Node, source NodeID,
return msgs return msgs
} }
func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID, signal *IDStringSignal) Messages { func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID, signal *LinkSignal) Messages {
id := signal.NodeID
action := signal.Str
msgs := Messages {} msgs := Messages {}
if ext.State == Unlocked { if ext.State == Unlocked {
switch action { switch signal.Action {
case "add": case "add":
_, exists := ext.Requirements[id] _, exists := ext.Requirements[signal.NodeID]
if exists == true { if exists == true {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_requirement"), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "already_requirement"), source)
} else { } else {
if ext.Requirements == nil { if ext.Requirements == nil {
ext.Requirements = map[NodeID]ReqState{} ext.Requirements = map[NodeID]ReqState{}
} }
ext.Requirements[id] = Unlocked ext.Requirements[signal.NodeID] = Unlocked
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "req_added"), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "req_added"), source)
} }
case "remove": case "remove":
_, exists := ext.Requirements[id] _, exists := ext.Requirements[signal.NodeID]
if exists == false { if exists == false {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_requirement"), source)
} else { } else {
delete(ext.Requirements, id) delete(ext.Requirements, signal.NodeID)
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "req_removed"), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "req_removed"), source)
} }
default: default:
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "unknown_action"), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "unknown_action"), source)
} }
} else { } else {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocked"), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_unlocked"), source)
} }
return msgs return msgs
} }
// Handle a LockSignal and update the extensions owner/requirement states // Handle a LockSignal and update the extensions owner/requirement states
func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) Messages { func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) Messages {
state := signal.Str ctx.Log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal.State)
log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, state)
msgs := Messages{} msgs := Messages{}
switch state { switch signal.State {
case "locked": case "locked":
state, found := ext.Requirements[source] state, found := ext.Requirements[source]
if found == false { if found == false {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_requirement"), source)
} else if state == Locking { } else if state == Locking {
if ext.State == Locking { if ext.State == Locking {
ext.Requirements[source] = Locked ext.Requirements[source] = Locked
@ -245,19 +139,19 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
if locked == reqs { if locked == reqs {
ext.State = Locked ext.State = Locked
ext.Owner = ext.PendingOwner ext.Owner = ext.PendingOwner
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("locked"), *ext.Owner) msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("locked"), *ext.Owner)
} else { } else {
log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, reqs) ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, reqs)
} }
} else if ext.State == AbortingLock { } else if ext.State == AbortingLock {
ext.Requirements[source] = Unlocking ext.Requirements[source] = Unlocking
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlock"), source)
} }
} }
case "unlocked": case "unlocked":
state, found := ext.Requirements[source] state, found := ext.Requirements[source]
if found == false { if found == false {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_requirement"), source)
} else if state == Unlocking { } else if state == Unlocking {
ext.Requirements[source] = Unlocked ext.Requirements[source] = Unlocked
reqs := 0 reqs := 0
@ -274,13 +168,14 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
ext.State = Unlocked ext.State = Unlocked
if old_state == Unlocking { if old_state == Unlocking {
ext.Owner = ext.PendingOwner ext.Owner = ext.PendingOwner
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlocked"), *ext.Owner) ext.ReqID = nil
msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlocked"), *ext.Owner)
} else if old_state == AbortingLock { } else if old_state == AbortingLock {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(ext.ReqID, "not_unlocked"), *ext.PendingOwner) msgs = msgs.Add(ctx ,node.ID, node.Key, NewErrorSignal(*ext.ReqID, "not_unlocked"), *ext.PendingOwner)
ext.PendingOwner = ext.Owner ext.PendingOwner = ext.Owner
} }
} else { } else {
log.Logf("lockable", "PARTIAL UNLOCK: %s - %d/%d", node.ID, unlocked, reqs) ctx.Log.Logf("lockable", "PARTIAL UNLOCK: %s - %d/%d", node.ID, unlocked, reqs)
} }
} }
case "lock": case "lock":
@ -290,23 +185,24 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
new_owner := source new_owner := source
ext.PendingOwner = &new_owner ext.PendingOwner = &new_owner
ext.Owner = &new_owner ext.Owner = &new_owner
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("locked"), new_owner) msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("locked"), new_owner)
} else { } else {
ext.State = Locking ext.State = Locking
ext.ReqID = signal.ID() id := signal.ID
ext.ReqID = &id
new_owner := source new_owner := source
ext.PendingOwner = &new_owner ext.PendingOwner = &new_owner
for id, state := range(ext.Requirements) { for id, state := range(ext.Requirements) {
if state != Unlocked { if state != Unlocked {
log.Logf("lockable", "REQ_NOT_UNLOCKED_WHEN_LOCKING") ctx.Log.Logf("lockable", "REQ_NOT_UNLOCKED_WHEN_LOCKING")
} }
ext.Requirements[id] = Locking ext.Requirements[id] = Locking
lock_signal := NewLockSignal("lock") lock_signal := NewLockSignal("lock")
msgs = msgs.Add(node.ID, node.Key, lock_signal, id) msgs = msgs.Add(ctx, node.ID, node.Key, lock_signal, id)
} }
} }
} else { } else {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocked"), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_unlocked"), source)
} }
case "unlock": case "unlock":
if ext.State == Locked { if ext.State == Locked {
@ -315,25 +211,26 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
new_owner := source new_owner := source
ext.PendingOwner = nil ext.PendingOwner = nil
ext.Owner = nil ext.Owner = nil
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlocked"), new_owner) msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlocked"), new_owner)
} else if source == *ext.Owner { } else if source == *ext.Owner {
ext.State = Unlocking ext.State = Unlocking
ext.ReqID = signal.ID() id := signal.ID
ext.ReqID = &id
ext.PendingOwner = nil ext.PendingOwner = nil
for id, state := range(ext.Requirements) { for id, state := range(ext.Requirements) {
if state != Locked { if state != Locked {
log.Logf("lockable", "REQ_NOT_LOCKED_WHEN_UNLOCKING") ctx.Log.Logf("lockable", "REQ_NOT_LOCKED_WHEN_UNLOCKING")
} }
ext.Requirements[id] = Unlocking ext.Requirements[id] = Unlocking
lock_signal := NewLockSignal("unlock") lock_signal := NewLockSignal("unlock")
msgs = msgs.Add(node.ID, node.Key, lock_signal, id) msgs = msgs.Add(ctx, node.ID, node.Key, lock_signal, id)
} }
} }
} else { } else {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locked"), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_locked"), source)
} }
default: default:
log.Logf("lockable", "LOCK_ERR: unkown state %s", state) ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", signal.State)
} }
return msgs return msgs
} }
@ -342,25 +239,25 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
// LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state // LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state
func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages {
messages := Messages{} messages := Messages{}
switch signal.Direction() { switch signal.Header().Direction {
case Up: case Up:
if ext.Owner != nil { if ext.Owner != nil {
if *ext.Owner != node.ID { if *ext.Owner != node.ID {
messages = messages.Add(node.ID, node.Key, signal, *ext.Owner) messages = messages.Add(ctx, node.ID, node.Key, signal, *ext.Owner)
} }
} }
case Down: case Down:
for requirement, _ := range(ext.Requirements) { for requirement, _ := range(ext.Requirements) {
messages = messages.Add(node.ID, node.Key, signal, requirement) messages = messages.Add(ctx, node.ID, node.Key, signal, requirement)
} }
case Direct: case Direct:
switch signal.Type() { switch sig := signal.(type) {
case LinkSignalType: case *LinkSignal:
messages = ext.HandleLinkSignal(ctx.Log, node, source, signal.(*IDStringSignal)) messages = ext.HandleLinkSignal(ctx, node, source, sig)
case LockSignalType: case *LockSignal:
messages = ext.HandleLockSignal(ctx.Log, node, source, signal.(*StringSignal)) messages = ext.HandleLockSignal(ctx, node, source, sig)
case ErrorSignalType: case *ErrorSignal:
messages = ext.HandleErrorSignal(ctx.Log, node, source, signal.(*ErrorSignal)) messages = ext.HandleErrorSignal(ctx, node, source, sig)
default: default:
} }
default: default:

@ -7,7 +7,7 @@ import (
"crypto/rand" "crypto/rand"
) )
const TestLockableType = NodeType("TEST_LOCKABLE") var TestLockableType = NewNodeType("TEST_LOCKABLE")
func lockableTestContext(t *testing.T, logs []string) *Context { func lockableTestContext(t *testing.T, logs []string) *Context {
ctx := logTestContext(t, logs) ctx := logTestContext(t, logs)
@ -43,57 +43,24 @@ func TestLink(t *testing.T) {
) )
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(l1.ID, l1.Key, NewLinkSignal("add", l2.ID), l1.ID) msgs = msgs.Add(ctx, l1.ID, l1.Key, NewLinkSignal("add", l2.ID), l1.ID)
err = ctx.Send(msgs) err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, ErrorSignalType, func(sig *ErrorSignal) bool { _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, func(sig *ErrorSignal) bool {
return sig.Error == "req_added" return sig.Error == "req_added"
}) })
fatalErr(t, err) fatalErr(t, err)
msgs = Messages{} msgs = Messages{}
s := NewBaseSignal("TEST", Down) msgs = msgs.Add(ctx, l1.ID, l1.Key, NewLinkSignal("remove", l2.ID), l1.ID)
msgs = msgs.Add(l1.ID, l1.Key, &s, l1.ID)
err = ctx.Send(msgs) err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, "TEST", func(sig *BaseSignal) bool { _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, func(sig *ErrorSignal) bool {
return sig.ID() == s.ID()
})
fatalErr(t, err)
_, err = WaitForSignal(l2_listener.Chan, time.Millisecond*10, "TEST", func(sig *BaseSignal) bool {
return sig.ID() == s.ID()
})
fatalErr(t, err)
msgs = Messages{}
msgs = msgs.Add(l1.ID, l1.Key, NewLinkSignal("remove", l2.ID), l1.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, ErrorSignalType, func(sig *ErrorSignal) bool {
return sig.Error == "req_removed" return sig.Error == "req_removed"
}) })
fatalErr(t, err) fatalErr(t, err)
msgs = Messages{}
s = NewBaseSignal("TEST", Down)
msgs = msgs.Add(l1.ID, l1.Key, &s, l1.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, "TEST", func(sig *BaseSignal) bool {
return sig.ID() == s.ID()
})
fatalErr(t, err)
select {
case <- l2_listener.Chan:
t.Fatal("Recevied message on l2 after removing link")
default:
}
} }
func Test10KLink(t *testing.T) { func Test10KLink(t *testing.T) {
@ -104,7 +71,7 @@ func Test10KLink(t *testing.T) {
listener_id := KeyID(l_pub) listener_id := KeyID(l_pub)
child_policy := NewPerNodePolicy(map[NodeID]Tree{ child_policy := NewPerNodePolicy(map[NodeID]Tree{
listener_id: Tree{ listener_id: Tree{
LockSignalType.String(): nil, uint64(LockSignalType): nil,
}, },
}) })
NewLockable := func()(*Node) { NewLockable := func()(*Node) {
@ -125,7 +92,7 @@ func Test10KLink(t *testing.T) {
ctx.Log.Logf("test", "CREATED_10K") ctx.Log.Logf("test", "CREATED_10K")
l_policy := NewAllNodesPolicy(Tree{ l_policy := NewAllNodesPolicy(Tree{
LockSignalType.String(): nil, uint64(LockSignalType): nil,
}) })
listener := NewListenerExt(100000) listener := NewListenerExt(100000)
node := NewNode(ctx, listener_key, TestLockableType, 10000, node := NewNode(ctx, listener_key, TestLockableType, 10000,
@ -140,14 +107,14 @@ func Test10KLink(t *testing.T) {
_, err = LockLockable(ctx, node, node.ID) _, err = LockLockable(ctx, node, node.ID)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(listener.Chan, time.Millisecond*1000, LockSignalType, func(sig *StringSignal) bool { _, err = WaitForSignal(listener.Chan, time.Millisecond*1000, func(sig *LockSignal) bool {
return sig.Str == "locked" return sig.State == "locked"
}) })
fatalErr(t, err) fatalErr(t, err)
for _, _ = range(reqs) { for _, _ = range(reqs) {
_, err := WaitForSignal(listener.Chan, time.Millisecond*100, LockSignalType, func(sig *StringSignal) bool { _, err := WaitForSignal(listener.Chan, time.Millisecond*100, func(sig *LockSignal) bool {
return sig.Str == "locked" return sig.State == "locked"
}) })
fatalErr(t, err) fatalErr(t, err)
} }
@ -178,36 +145,36 @@ func TestLock(t *testing.T) {
l0, l0_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID}) l0, l0_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID})
l1, l1_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID}) l1, l1_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID})
locked := func(sig *StringSignal) bool { locked := func(sig *LockSignal) bool {
return sig.Str == "locked" return sig.State == "locked"
} }
unlocked := func(sig *StringSignal) bool { unlocked := func(sig *LockSignal) bool {
return sig.Str == "unlocked" return sig.State == "unlocked"
} }
_, err := LockLockable(ctx, l0, l5.ID) _, err := LockLockable(ctx, l0, l5.ID)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, LockSignalType, locked) _, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, locked)
fatalErr(t, err) fatalErr(t, err)
id, err := LockLockable(ctx, l1, l1.ID) id, err := LockLockable(ctx, l1, l1.ID)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, ErrorSignalType, func(sig *ErrorSignal) bool { _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, func(sig *ErrorSignal) bool {
return sig.Error == "not_unlocked" && sig.ReqID() == id return sig.Error == "not_unlocked" && sig.Header().ReqID == id
}) })
fatalErr(t, err) fatalErr(t, err)
_, err = UnlockLockable(ctx, l0, l5.ID) _, err = UnlockLockable(ctx, l0, l5.ID)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, LockSignalType, unlocked) _, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, unlocked)
fatalErr(t, err) fatalErr(t, err)
_, err = LockLockable(ctx, l1, l1.ID) _, err = LockLockable(ctx, l1, l1.ID)
fatalErr(t, err) fatalErr(t, err)
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, func(sig *StringSignal) bool { _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, func(sig *LockSignal) bool {
return sig.Str == "locked" return sig.State == "locked"
}) })
fatalErr(t, err) fatalErr(t, err)
} }

@ -29,22 +29,20 @@ const (
var ( var (
// Base NodeID, used as a special value // Base NodeID, used as a special value
ZeroUUID = uuid.UUID{} ZeroUUID = uuid.UUID{}
ZeroID = NodeID{ZeroUUID} ZeroID = NodeID(ZeroUUID)
) )
// A NodeID uniquely identifies a Node // A NodeID uniquely identifies a Node
type NodeID struct { type NodeID uuid.UUID
uuid.UUID func (id NodeID) MarshalBinary() ([]byte, error) {
return (uuid.UUID)(id).MarshalBinary()
} }
func (id NodeID) String() string {
func (id NodeID) Serialize() []byte { return (uuid.UUID)(id).String()
ser, _ := id.MarshalBinary()
return ser
} }
func IDFromBytes(bytes []byte) (NodeID, error) { func IDFromBytes(bytes []byte) (NodeID, error) {
id, err := uuid.FromBytes(bytes[:]) id, err := uuid.FromBytes(bytes)
return NodeID{id}, err return NodeID(id), err
} }
// Parse an ID from a string // Parse an ID from a string
@ -53,26 +51,17 @@ func ParseID(str string) (NodeID, error) {
if err != nil { if err != nil {
return NodeID{}, err return NodeID{}, err
} }
return NodeID{id_uuid}, nil return NodeID(id_uuid), nil
} }
// Generate a random NodeID // Generate a random NodeID
func RandID() NodeID { func RandID() NodeID {
return NodeID{uuid.New()} return NodeID(uuid.New())
}
// A Serializable has a type that can be used to map to it, and a function to serialize` the current state
type Serializable[I comparable] interface {
Serialize()([]byte,error)
Deserialize(*Context,[]byte)error
Type() I
} }
// Extensions are data attached to nodes that process signals // Extensions are data attached to nodes that process signals
type Extension interface { type Extension interface {
Serializable[ExtType] Process(*Context, *Node, NodeID, Signal) Messages
Field(string)interface{}
Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages
} }
// A QueuedSignal is a Signal that has been Queued to trigger at a set time // A QueuedSignal is a Signal that has been Queued to trigger at a set time
@ -130,10 +119,10 @@ const (
Pending Pending
) )
func (node *Node) Allows(principal_id NodeID, action Tree)(map[PolicyType]Messages, RuleResult) { func (node *Node) Allows(ctx *Context, principal_id NodeID, action Tree)(map[PolicyType]Messages, RuleResult) {
pends := map[PolicyType]Messages{} pends := map[PolicyType]Messages{}
for policy_type, policy := range(node.Policies) { for policy_type, policy := range(node.Policies) {
msgs, resp := policy.Allows(principal_id, action, node) msgs, resp := policy.Allows(ctx, principal_id, action, node)
if resp == Allow { if resp == Allow {
return nil, Allow return nil, Allow
} else if resp == Pending { } else if resp == Pending {
@ -154,7 +143,7 @@ func (node *Node) QueueSignal(time time.Time, signal Signal) {
func (node *Node) DequeueSignal(id uuid.UUID) error { func (node *Node) DequeueSignal(id uuid.UUID) error {
idx := -1 idx := -1
for i, q := range(node.SignalQueue) { for i, q := range(node.SignalQueue) {
if q.Signal.ID() == id { if q.Signal.Header().ID == id {
idx = i idx = i
break break
} }
@ -202,16 +191,43 @@ func runNode(ctx *Context, node *Node) {
ctx.Log.Logf("node", "RUN_STOP: %s", node.ID) ctx.Log.Logf("node", "RUN_STOP: %s", node.ID)
} }
func (node *Node) ReadFields(reqs map[ExtType][]string)map[ExtType]map[string]interface{} { type StringError string
exts := map[ExtType]map[string]interface{}{} func (err StringError) String() string {
return string(err)
}
func (err StringError) Error() string {
return err.String()
}
func (err StringError) MarshalBinary() ([]byte, error) {
return []byte(string(err)), nil
}
func NewErrorField(fstring string, args ...interface{}) SerializedValue {
str := StringError(fmt.Sprintf(fstring, args...))
str_ser, err := str.MarshalBinary()
if err != nil {
panic(err)
}
return SerializedValue{
TypeStack: []uint64{uint64(ErrorType)},
Data: str_ser,
}
}
func (node *Node) ReadFields(ctx *Context, reqs map[ExtType][]string)map[ExtType]map[string]SerializedValue {
exts := map[ExtType]map[string]SerializedValue{}
for ext_type, field_reqs := range(reqs) { for ext_type, field_reqs := range(reqs) {
fields := map[string]interface{}{} fields := map[string]SerializedValue{}
for _, req := range(field_reqs) { for _, req := range(field_reqs) {
ext, exists := node.Extensions[ext_type] ext, exists := node.Extensions[ext_type]
if exists == false { if exists == false {
fields[req] = fmt.Errorf("%s does not have %s extension", node.ID, ext_type) fields[req] = NewErrorField("%+v does not have %+v extension", node.ID, ext_type)
} else { } else {
fields[req] = ext.Field(req) f, err := SerializeField(ctx, ext, req)
if err != nil {
fields[req] = NewErrorField(err.Error())
} else {
fields[req] = f
}
} }
} }
exts[ext_type] = fields exts[ext_type] = fields
@ -227,21 +243,40 @@ func nodeLoop(ctx *Context, node *Node) error {
} }
// Perform startup actions // Perform startup actions
node.Process(ctx, ZeroID, &StartSignal) node.Process(ctx, ZeroID, NewStartSignal())
run := true
for true { for run == true {
var signal Signal var signal Signal
var source NodeID var source NodeID
select { select {
case msg := <- node.MsgChan: case msg := <- node.MsgChan:
ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal) ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal)
ser, err := msg.Signal.Serialize() signal_type, exists := ctx.SignalTypes[reflect.TypeOf(msg.Signal).Elem()]
if exists == false {
ctx.Log.Logf("signal", "SIGNAL_NOT_REGISTERED: %+v", reflect.TypeOf(msg.Signal).Elem())
}
signal_ser, err := SerializeSignal(ctx, signal, signal_type)
if err != nil { if err != nil {
ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", node.ID, msg.Signal) ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", err, msg.Signal)
}
ser, err := signal_ser.MarshalBinary()
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", err, signal_ser)
continue continue
} }
sig_data := append(msg.Dest.Serialize(), msg.Source.Serialize()...) dst_id_ser, err := msg.Dest.MarshalBinary()
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_DEST_ID_SER_ERR: %e", err)
continue
}
src_id_ser, err := msg.Source.MarshalBinary()
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_SRC_ID_SER_ERR: %e", err)
continue
}
sig_data := append(dst_id_ser, src_id_ser...)
sig_data = append(sig_data, ser...) sig_data = append(sig_data, ser...)
validated := ed25519.Verify(msg.Principal, sig_data, msg.Signature) validated := ed25519.Verify(msg.Principal, sig_data, msg.Signature)
if validated == false { if validated == false {
@ -251,26 +286,26 @@ func nodeLoop(ctx *Context, node *Node) error {
princ_id := KeyID(msg.Principal) princ_id := KeyID(msg.Principal)
if princ_id != node.ID { if princ_id != node.ID {
pends, resp := node.Allows(princ_id, msg.Signal.Permission()) pends, resp := node.Allows(ctx, princ_id, msg.Signal.Permission())
if resp == Deny { if resp == Deny {
ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %s", princ_id, node.ID, msg.Signal.Permission()) ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %s", princ_id, node.ID, msg.Signal.Permission())
ctx.Log.Logf("policy", "SIGNAL_POLICY_SOURCE: %s", msg.Source) ctx.Log.Logf("policy", "SIGNAL_POLICY_SOURCE: %s", msg.Source)
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(msg.Signal.ID(), "acl denied"), msg.Source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(msg.Signal.Header().ID, "acl denied"), msg.Source)
ctx.Send(msgs) ctx.Send(msgs)
continue continue
} else if resp == Pending { } else if resp == Pending {
ctx.Log.Logf("policy", "SIGNAL_POLICY_PENDING: %s->%s - %s - %+v", princ_id, node.ID, msg.Signal.Permission(), pends) ctx.Log.Logf("policy", "SIGNAL_POLICY_PENDING: %s->%s - %s - %+v", princ_id, node.ID, msg.Signal.Permission(), pends)
timeout_signal := NewACLTimeoutSignal(msg.Signal.ID()) timeout_signal := NewACLTimeoutSignal(msg.Signal.Header().ID)
node.QueueSignal(time.Now().Add(100*time.Millisecond), timeout_signal) node.QueueSignal(time.Now().Add(100*time.Millisecond), timeout_signal)
msgs := Messages{} msgs := Messages{}
for policy_type, sigs := range(pends) { for policy_type, sigs := range(pends) {
for _, m := range(sigs) { for _, m := range(sigs) {
msgs = append(msgs, m) msgs = append(msgs, m)
node.PendingSignals[m.Signal.ID()] = PendingSignal{policy_type, false, msg.Signal.ID()} node.PendingSignals[m.Signal.Header().ID] = PendingSignal{policy_type, false, msg.Signal.Header().ID}
} }
} }
node.PendingACLs[msg.Signal.ID()] = PendingACL{len(msgs), timeout_signal.ID(), msg.Signal.Permission(), princ_id, msgs, []Signal{}, msg.Signal, msg.Source} node.PendingACLs[msg.Signal.Header().ID] = PendingACL{len(msgs), timeout_signal.ID, msg.Signal.Permission(), princ_id, msgs, []Signal{}, msg.Signal, msg.Source}
ctx.Send(msgs) ctx.Send(msgs)
continue continue
} else if resp == Allow { } else if resp == Allow {
@ -290,7 +325,7 @@ func nodeLoop(ctx *Context, node *Node) error {
t := node.NextSignal.Time t := node.NextSignal.Time
i := -1 i := -1
for j, queued := range(node.SignalQueue) { for j, queued := range(node.SignalQueue) {
if queued.Signal.ID() == node.NextSignal.Signal.ID() { if queued.Signal.Header().ID == node.NextSignal.Signal.Header().ID {
i = j i = j
break break
} }
@ -304,26 +339,26 @@ func nodeLoop(ctx *Context, node *Node) error {
node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue) node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue)
if node.NextSignal == nil { if node.NextSignal == nil {
ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL nil@%+v", node.ID, signal.Type(), t, node.TimeoutChan) ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL nil@%+v", node.ID, signal, t, node.TimeoutChan)
} else { } else {
ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL: %s@%s", node.ID, signal.Type(), t, node.NextSignal, node.NextSignal.Time) ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL: %s@%s", node.ID, signal, t, node.NextSignal, node.NextSignal.Time)
} }
} }
ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue)
info, waiting := node.PendingSignals[signal.ReqID()] info, waiting := node.PendingSignals[signal.Header().ReqID]
if waiting == true { if waiting == true {
if info.Found == false { if info.Found == false {
info.Found = true info.Found = true
node.PendingSignals[signal.ReqID()] = info node.PendingSignals[signal.Header().ReqID] = info
ctx.Log.Logf("pending", "FOUND_PENDING_SIGNAL: %s - %s", node.ID, signal) ctx.Log.Logf("pending", "FOUND_PENDING_SIGNAL: %s - %s", node.ID, signal)
req_info, exists := node.PendingACLs[info.ID] req_info, exists := node.PendingACLs[info.ID]
if exists == true { if exists == true {
req_info.Counter -= 1 req_info.Counter -= 1
req_info.Responses = append(req_info.Responses, signal) req_info.Responses = append(req_info.Responses, signal)
allowed := node.Policies[info.Policy].ContinueAllows(req_info, signal) allowed := node.Policies[info.Policy].ContinueAllows(ctx, req_info, signal)
if allowed == Allow { if allowed == Allow {
ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal) ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal)
signal = req_info.Signal signal = req_info.Signal
@ -337,7 +372,7 @@ func nodeLoop(ctx *Context, node *Node) error {
ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal) ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal)
// Send the denied response // Send the denied response
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(req_info.Signal.ID(), "ACL_DENIED"), req_info.Source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(req_info.Signal.Header().ID, "ACL_DENIED"), req_info.Source)
err := ctx.Send(msgs) err := ctx.Send(msgs)
if err != nil { if err != nil {
ctx.Log.Logf("signal", "SEND_ERR: %s", err) ctx.Log.Logf("signal", "SEND_ERR: %s", err)
@ -355,25 +390,20 @@ func nodeLoop(ctx *Context, node *Node) error {
} }
} }
// Handle node signals switch sig := signal.(type) {
if signal.Type() == StopSignalType { case *StopSignal:
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "stopped"), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.ID, "stopped"), source)
ctx.Send(msgs) ctx.Send(msgs)
node.Process(ctx, node.ID, NewStatusSignal("stopped", node.ID)) node.Process(ctx, node.ID, NewStatusSignal(node.ID, "stopped"))
break run = false
} else if signal.Type() == ReadSignalType { case *ReadSignal:
read_signal, ok := signal.(*ReadSignal) result := node.ReadFields(ctx, sig.Extensions)
if ok == false {
ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal)
} else {
result := node.ReadFields(read_signal.Extensions)
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewReadResultSignal(read_signal.ID(), node.ID, node.Type, result), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewReadResultSignal(sig.ID, node.ID, node.Type, result), source)
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(read_signal.ID(), "read_done"), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.ID, "read_done"), source)
ctx.Send(msgs) ctx.Send(msgs)
} }
}
node.Process(ctx, source, signal) node.Process(ctx, source, signal)
// assume that processing a signal means that this nodes state changed // assume that processing a signal means that this nodes state changed
@ -401,8 +431,8 @@ type Message struct {
} }
type Messages []*Message type Messages []*Message
func (msgs Messages) Add(source NodeID, principal ed25519.PrivateKey, signal Signal, dest NodeID) Messages { func (msgs Messages) Add(ctx *Context, source NodeID, principal ed25519.PrivateKey, signal Signal, dest NodeID) Messages {
msg, err := NewMessage(dest, source, principal, signal) msg, err := NewMessage(ctx, dest, source, principal, signal)
if err != nil { if err != nil {
panic(err) panic(err)
} else { } else {
@ -411,13 +441,31 @@ func (msgs Messages) Add(source NodeID, principal ed25519.PrivateKey, signal Sig
return msgs return msgs
} }
func NewMessage(dest NodeID, source NodeID, principal ed25519.PrivateKey, signal Signal) (*Message, error) { func NewMessage(ctx *Context, dest NodeID, source NodeID, principal ed25519.PrivateKey, signal Signal) (*Message, error) {
ser, err := signal.Serialize() signal_type, exists := ctx.SignalTypes[reflect.TypeOf(signal)]
if exists == false {
return nil, fmt.Errorf("Cannot put %+v in a message, not a known signal type", reflect.TypeOf(signal))
}
signal_ser, err := SerializeSignal(ctx, signal, signal_type)
if err != nil {
return nil, err
}
ser, err := signal_ser.MarshalBinary()
if err != nil { if err != nil {
return nil, err return nil, err
} }
sig_data := append(dest.Serialize(), source.Serialize()...) dest_ser, err := dest.MarshalBinary()
if err != nil {
return nil, err
}
source_ser, err := source.MarshalBinary()
if err != nil {
return nil, err
}
sig_data := append(dest_ser, source_ser...)
sig_data = append(sig_data, ser...) sig_data = append(sig_data, ser...)
sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0)) sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0))
@ -435,7 +483,7 @@ func NewMessage(dest NodeID, source NodeID, principal ed25519.PrivateKey, signal
} }
func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error { func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error {
ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, signal.Type()) ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, signal)
messages := Messages{} messages := Messages{}
for ext_type, ext := range(node.Extensions) { for ext_type, ext := range(node.Extensions) {
ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type) ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type)
@ -449,120 +497,86 @@ func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error {
return ctx.Send(messages) return ctx.Send(messages)
} }
func GetCtx[T Extension, C any](ctx *Context) (C, error) { func GetCtx[T Extension, C any](ctx *Context, ext_type ExtType) (C, error) {
var zero T
var zero_ctx C var zero_ctx C
ext_type := zero.Type() ext_info, ok := ctx.Extensions[ext_type]
type_hash := Hash(ext_type)
ext_info, ok := ctx.Extensions[type_hash]
if ok == false { if ok == false {
return zero_ctx, fmt.Errorf("%s is not an extension in ctx", ext_type) return zero_ctx, fmt.Errorf("%+v is not an extension in ctx", ext_type)
} }
ext_ctx, ok := ext_info.Data.(C) ext_ctx, ok := ext_info.Data.(C)
if ok == false { if ok == false {
return zero_ctx, fmt.Errorf("context for %s is %+v, not %+v", ext_type, reflect.TypeOf(ext_info.Data), reflect.TypeOf(zero)) return zero_ctx, fmt.Errorf("context for %+v is %+v, not %+v", ext_type, reflect.TypeOf(ext_info.Data), reflect.TypeOf(zero_ctx))
} }
return ext_ctx, nil return ext_ctx, nil
} }
func GetExt[T Extension](node *Node) (T, error) { func GetExt[T Extension](node *Node, ext_type ExtType) (T, error) {
var zero T var zero T
ext_type := zero.Type()
ext, exists := node.Extensions[ext_type] ext, exists := node.Extensions[ext_type]
if exists == false { if exists == false {
return zero, fmt.Errorf("%s does not have %s extension - %+v", node.ID, ext_type, node.Extensions) return zero, fmt.Errorf("%+v does not have %+v extension - %+v", node.ID, ext_type, node.Extensions)
} }
ret, ok := ext.(T) ret, ok := ext.(T)
if ok == false { if ok == false {
return zero, fmt.Errorf("%s in %s is wrong type(%+v), expecting %+v", ext_type, node.ID, reflect.TypeOf(ext), reflect.TypeOf(zero)) return zero, fmt.Errorf("%+v in %+v is wrong type(%+v), expecting %+v", ext_type, node.ID, reflect.TypeOf(ext), reflect.TypeOf(zero))
} }
return ret, nil return ret, nil
} }
func (node *Node) Serialize() ([]byte, error) { func (node *Node) Serialize(ctx *Context) (SerializedValue, error) {
extensions := make([]ExtensionDB, len(node.Extensions)) if node == nil {
qsignals := make([]QSignalDB, len(node.SignalQueue)) return SerializedValue{}, fmt.Errorf("Cannot serialize nil Node")
policies := make([]PolicyDB, len(node.Policies)) }
node_bytes := make([]byte, 8 * 3)
binary.BigEndian.PutUint64(node_bytes[0:8], uint64(len(node.Extensions)))
binary.BigEndian.PutUint64(node_bytes[8:16], uint64(len(node.Policies)))
binary.BigEndian.PutUint64(node_bytes[16:24], uint64(len(node.SignalQueue)))
key_bytes, err := x509.MarshalPKCS8PrivateKey(node.Key) key_bytes, err := x509.MarshalPKCS8PrivateKey(node.Key)
if err != nil { if err != nil {
return nil, err return SerializedValue{}, err
} }
node_db := NodeDB{ key_val := SerializedValue{
Header: NodeDBHeader{ TypeStack: []uint64{uint64(NodeKeyType)},
Magic: NODE_DB_MAGIC, Data: key_bytes,
TypeHash: Hash(node.Type),
KeyLength: uint32(len(key_bytes)),
BufferSize: node.BufferSize,
NumExtensions: uint32(len(extensions)),
NumPolicies: uint32(len(policies)),
NumQueuedSignals: uint32(len(node.SignalQueue)),
},
Extensions: extensions,
Policies: policies,
QueuedSignals: qsignals,
KeyBytes: key_bytes,
} }
key_ser, err := key_val.MarshalBinary()
i := 0
for ext_type, info := range(node.Extensions) {
ser, err := info.Serialize()
if err != nil { if err != nil {
return nil, err return SerializedValue{}, err
}
node_db.Extensions[i] = ExtensionDB{
Header: ExtensionDBHeader{
TypeHash: Hash(ext_type),
Length: uint64(len(ser)),
},
Data: ser,
}
i += 1
} }
node_bytes = append(node_bytes, key_ser...)
for i, qsignal := range(node.SignalQueue) { for ext_type, ext := range(node.Extensions) {
ser, err := qsignal.Signal.Serialize() ctx.Log.Logf("serialize", "SERIALIZING_EXTENSION: %+v", ext)
ext_ser, err := SerializeExtension(ctx, ext, ext_type)
if err != nil { if err != nil {
return nil, err return SerializedValue{}, err
}
node_db.QueuedSignals[i] = QSignalDB{
QSignalDBHeader{
qsignal.Time,
Hash(qsignal.Signal.Type()),
uint64(len(ser)),
},
ser,
} }
} ext_bytes, err := ext_ser.MarshalBinary()
i = 0
for _, policy := range(node.Policies) {
ser, err := policy.Serialize()
if err != nil { if err != nil {
return nil, err return SerializedValue{}, err
} }
node_db.Policies[i] = PolicyDB{ node_bytes = append(node_bytes, ext_bytes...)
PolicyDBHeader{
Hash(policy.Type()),
uint64(len(ser)),
},
ser,
} }
node_value := SerializedValue{
TypeStack: []uint64{uint64(node.Type)},
Data: node_bytes,
} }
return node_db.Serialize(), nil return node_value, nil
} }
func KeyID(pub ed25519.PublicKey) NodeID { func KeyID(pub ed25519.PublicKey) NodeID {
str := uuid.NewHash(sha512.New(), ZeroUUID, pub, 3) id := uuid.NewHash(sha512.New(), ZeroUUID, pub, 3)
return NodeID{str} return NodeID(id)
} }
// Create a new node in memory and start it's event loop // Create a new node in memory and start it's event loop
@ -584,24 +598,28 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si
panic("Attempted to create an existing node") panic("Attempted to create an existing node")
} }
def, exists := ctx.Types[Hash(node_type)] def, exists := ctx.Nodes[node_type]
if exists == false { if exists == false {
panic("Node type %s not registered in Context") panic("Node type %s not registered in Context")
} }
ext_map := map[ExtType]Extension{} ext_map := map[ExtType]Extension{}
for _, ext := range(extensions) { for _, ext := range(extensions) {
_, exists := ext_map[ext.Type()] ext_type, exists := ctx.ExtensionTypes[reflect.TypeOf(ext)]
if exists == false {
panic(fmt.Sprintf("%+v is not a known Extension", reflect.TypeOf(ext)))
}
_, exists = ext_map[ext_type]
if exists == true { if exists == true {
panic("Cannot add the same extension to a node twice") panic("Cannot add the same extension to a node twice")
} }
ext_map[ext.Type()] = ext ext_map[ext_type] = ext
} }
for _, required_ext := range(def.Extensions) { for _, required_ext := range(def.Extensions) {
_, exists := ext_map[required_ext] _, exists := ext_map[required_ext]
if exists == false { if exists == false {
panic(fmt.Sprintf("%s requires %s", node_type, required_ext)) panic(fmt.Sprintf("%+v requires %+v", node_type, required_ext))
} }
} }
@ -610,9 +628,9 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si
} }
default_policy := NewAllNodesPolicy(Tree{ default_policy := NewAllNodesPolicy(Tree{
ErrorSignalType.String(): nil, uint64(ErrorSignalType): nil,
ReadResultSignalType.String(): nil, uint64(ReadResultSignalType): nil,
StatusSignalType.String(): nil, uint64(StatusSignalType): nil,
}) })
all_nodes_policy, exists := policies[AllNodesPolicyType] all_nodes_policy, exists := policies[AllNodesPolicyType]
@ -642,252 +660,32 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si
panic(err) panic(err)
} }
node.Process(ctx, ZeroID, &NewSignal) node.Process(ctx, ZeroID, NewCreateSignal())
go runNode(ctx, node) go runNode(ctx, node)
return node return node
} }
type PolicyDBHeader struct {
TypeHash uint64
Length uint64
}
type PolicyDB struct {
Header PolicyDBHeader
Data []byte
}
type QSignalDBHeader struct {
Time time.Time
TypeHash uint64
Length uint64
}
type QSignalDB struct {
Header QSignalDBHeader
Data []byte
}
type ExtensionDBHeader struct {
TypeHash uint64
Length uint64
}
type ExtensionDB struct {
Header ExtensionDBHeader
Data []byte
}
// A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node
type NodeDBHeader struct {
Magic uint32
NumExtensions uint32
NumPolicies uint32
NumQueuedSignals uint32
BufferSize uint32
KeyLength uint32
TypeHash uint64
}
type NodeDB struct {
Header NodeDBHeader
Extensions []ExtensionDB
Policies []PolicyDB
QueuedSignals []QSignalDB
KeyBytes []byte
}
//TODO: add size safety checks
func NewNodeDB(data []byte) (NodeDB, error) {
var zero NodeDB
ptr := 0
magic := binary.BigEndian.Uint32(data[0:4])
num_extensions := binary.BigEndian.Uint32(data[4:8])
num_policies := binary.BigEndian.Uint32(data[8:12])
num_queued_signals := binary.BigEndian.Uint32(data[12:16])
buffer_size := binary.BigEndian.Uint32(data[16:20])
key_length := binary.BigEndian.Uint32(data[20:24])
node_type_hash := binary.BigEndian.Uint64(data[24:32])
ptr += NODE_DB_HEADER_LEN
if magic != NODE_DB_MAGIC {
return zero, fmt.Errorf("header has incorrect magic 0x%x", magic)
}
key_bytes := make([]byte, key_length)
n := copy(key_bytes, data[ptr:(ptr+int(key_length))])
if n != int(key_length) {
return zero, fmt.Errorf("not enough key bytes: %d", n)
}
ptr += int(key_length)
extensions := make([]ExtensionDB, num_extensions)
for i, _ := range(extensions) {
cur := data[ptr:]
type_hash := binary.BigEndian.Uint64(cur[0:8])
length := binary.BigEndian.Uint64(cur[8:16])
data_start := uint64(EXTENSION_DB_HEADER_LEN)
data_end := data_start + length
ext_data := cur[data_start:data_end]
extensions[i] = ExtensionDB{
Header: ExtensionDBHeader{
TypeHash: type_hash,
Length: length,
},
Data: ext_data,
}
ptr += int(EXTENSION_DB_HEADER_LEN + length)
}
policies := make([]PolicyDB, num_policies)
for i, _ := range(policies) {
cur := data[ptr:]
type_hash := binary.BigEndian.Uint64(cur[0:8])
length := binary.BigEndian.Uint64(cur[8:16])
data_start := uint64(POLICY_DB_HEADER_LEN)
data_end := data_start + length
policy_data := cur[data_start:data_end]
policies[i] = PolicyDB{
PolicyDBHeader{
type_hash,
length,
},
policy_data,
}
ptr += int(POLICY_DB_HEADER_LEN + length)
}
queued_signals := make([]QSignalDB, num_queued_signals)
for i, _ := range(queued_signals) {
cur := data[ptr:]
// TODO: load a header for each with the signal type and the signal length, so that it can be deserialized and incremented
// Right now causes segfault because any saved signal is loaded as nil
unix_milli := binary.BigEndian.Uint64(cur[0:8])
type_hash := binary.BigEndian.Uint64(cur[8:16])
signal_size := binary.BigEndian.Uint64(cur[16:24])
signal_data := cur[QSIGNAL_DB_HEADER_LEN:(QSIGNAL_DB_HEADER_LEN+signal_size)]
queued_signals[i] = QSignalDB{
QSignalDBHeader{
time.UnixMilli(int64(unix_milli)),
type_hash,
signal_size,
},
signal_data,
}
ptr += QSIGNAL_DB_HEADER_LEN + int(signal_size)
}
return NodeDB{
Header: NodeDBHeader{
Magic: magic,
TypeHash: node_type_hash,
BufferSize: buffer_size,
KeyLength: key_length,
NumExtensions: num_extensions,
NumQueuedSignals: num_queued_signals,
},
KeyBytes: key_bytes,
Extensions: extensions,
QueuedSignals: queued_signals,
}, nil
}
func (header NodeDBHeader) Serialize() []byte {
if header.Magic != NODE_DB_MAGIC {
panic(fmt.Sprintf("Serializing header with invalid magic %0x", header.Magic))
}
ret := make([]byte, NODE_DB_HEADER_LEN)
binary.BigEndian.PutUint32(ret[0:4], header.Magic)
binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions)
binary.BigEndian.PutUint32(ret[8:12], header.NumPolicies)
binary.BigEndian.PutUint32(ret[12:16], header.NumQueuedSignals)
binary.BigEndian.PutUint32(ret[16:20], header.BufferSize)
binary.BigEndian.PutUint32(ret[20:24], header.KeyLength)
binary.BigEndian.PutUint64(ret[24:32], header.TypeHash)
return ret
}
func (node NodeDB) Serialize() []byte {
ser := node.Header.Serialize()
ser = append(ser, node.KeyBytes...)
for _, extension := range(node.Extensions) {
ser = append(ser, extension.Serialize()...)
}
for _, policy := range(node.Policies) {
ser = append(ser, policy.Serialize()...)
}
for _, qsignal := range(node.QueuedSignals) {
ser = append(ser, qsignal.Serialize()...)
}
return ser
}
func (header QSignalDBHeader) Serialize() []byte {
ret := make([]byte, QSIGNAL_DB_HEADER_LEN)
binary.BigEndian.PutUint64(ret[0:8], uint64(header.Time.UnixMilli()))
binary.BigEndian.PutUint64(ret[8:16], header.TypeHash)
binary.BigEndian.PutUint64(ret[16:24], header.Length)
return ret
}
func (qsignal QSignalDB) Serialize() []byte {
header_bytes := qsignal.Header.Serialize()
return append(header_bytes, qsignal.Data...)
}
func (header ExtensionDBHeader) Serialize() []byte {
ret := make([]byte, EXTENSION_DB_HEADER_LEN)
binary.BigEndian.PutUint64(ret[0:8], header.TypeHash)
binary.BigEndian.PutUint64(ret[8:16], header.Length)
return ret
}
func (extension ExtensionDB) Serialize() []byte {
header_bytes := extension.Header.Serialize()
return append(header_bytes, extension.Data...)
}
func (header PolicyDBHeader) Serialize() []byte {
ret := make([]byte, POLICY_DB_HEADER_LEN)
binary.BigEndian.PutUint64(ret[0:8], header.TypeHash)
binary.BigEndian.PutUint64(ret[0:8], header.Length)
return ret
}
func (policy PolicyDB) Serialize() []byte {
header_bytes := policy.Header.Serialize()
return append(header_bytes, policy.Data...)
}
// Write a node to the database // Write a node to the database
func WriteNode(ctx *Context, node *Node) error { func WriteNode(ctx *Context, node *Node) error {
ctx.Log.Logf("db", "DB_WRITE: %s", node.ID) ctx.Log.Logf("db", "DB_WRITE: %s", node.ID)
bytes, err := node.Serialize() node_serialized, err := node.Serialize(ctx)
if err != nil {
return err
}
bytes, err := node_serialized.MarshalBinary()
if err != nil { if err != nil {
return err return err
} }
ctx.Log.Logf("db_data", "DB_DATA: %+v", bytes) ctx.Log.Logf("db_data", "DB_DATA: %+v", bytes)
id_bytes := node.ID.Serialize() id_bytes, err := node.ID.MarshalBinary()
if err != nil {
return err
}
ctx.Log.Logf("db", "DB_WRITE_ID: %+v", id_bytes) ctx.Log.Logf("db", "DB_WRITE_ID: %+v", id_bytes)
return ctx.DB.Update(func(txn *badger.Txn) error { return ctx.DB.Update(func(txn *badger.Txn) error {
@ -895,11 +693,15 @@ func WriteNode(ctx *Context, node *Node) error {
}) })
} }
//TODO: fix after capnp
func LoadNode(ctx * Context, id NodeID) (*Node, error) { func LoadNode(ctx * Context, id NodeID) (*Node, error) {
ctx.Log.Logf("db", "LOADING_NODE: %s", id) ctx.Log.Logf("db", "LOADING_NODE: %s", id)
var bytes []byte var bytes []byte
err := ctx.DB.View(func(txn *badger.Txn) error { err := ctx.DB.View(func(txn *badger.Txn) error {
id_bytes := id.Serialize() id_bytes, err := id.MarshalBinary()
if err != nil {
return err
}
ctx.Log.Logf("db", "DB_READ_ID: %+v", id_bytes) ctx.Log.Logf("db", "DB_READ_ID: %+v", id_bytes)
item, err := txn.Get(id_bytes) item, err := txn.Get(id_bytes)
if err != nil { if err != nil {
@ -917,137 +719,18 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
return nil, err return nil, err
} }
// Parse the bytes from the DB num_extensions := binary.BigEndian.Uint64(bytes[0:8])
node_db, err := NewNodeDB(bytes) num_policies := binary.BigEndian.Uint64(bytes[8:16])
if err != nil { num_signals := binary.BigEndian.Uint64(bytes[16:24])
return nil, err print(num_extensions)
} print(num_policies)
print(num_signals)
policies := make(map[PolicyType]Policy, node_db.Header.NumPolicies)
for _, policy_db := range(node_db.Policies) {
policy_info, exists := ctx.Policies[policy_db.Header.TypeHash]
if exists == false {
return nil, fmt.Errorf("0x%x is not a known policy type", policy_db.Header.TypeHash)
}
policy, err := policy_info.Load(ctx, policy_db.Data)
if err != nil {
return nil, err
}
policies[policy_info.Type] = policy /*
}
key_raw, err := x509.ParsePKCS8PrivateKey(node_db.KeyBytes)
if err != nil {
return nil, err
}
var key ed25519.PrivateKey
switch k := key_raw.(type) {
case ed25519.PrivateKey:
key = k
default:
return nil, fmt.Errorf("Wrong type for private key loaded: %s - %s", id, reflect.TypeOf(k))
}
key_id := KeyID(key.Public().(ed25519.PublicKey))
if key_id != id {
return nil, fmt.Errorf("KeyID(%s) != %s", key_id, id)
}
node_type, known := ctx.Types[node_db.Header.TypeHash]
if known == false {
return nil, fmt.Errorf("Tried to load node %s of type 0x%x, which is not a known node type", id, node_db.Header.TypeHash)
}
signal_queue := make([]QueuedSignal, node_db.Header.NumQueuedSignals)
for i, qsignal := range(node_db.QueuedSignals) {
sig_info, exists := ctx.Signals[qsignal.Header.TypeHash]
if exists == false {
return nil, fmt.Errorf("0x%x is not a known signal type", qsignal.Header.TypeHash)
}
signal, err := sig_info.Load(ctx, qsignal.Data)
if err != nil {
return nil, err
}
signal_queue[i] = QueuedSignal{signal, qsignal.Header.Time}
}
next_signal, timeout_chan := SoonestSignal(signal_queue)
node := &Node{
Key: key,
ID: key_id,
Type: node_type.Type,
Extensions: map[ExtType]Extension{},
Policies: policies,
MsgChan: make(chan *Message, node_db.Header.BufferSize),
BufferSize: node_db.Header.BufferSize,
TimeoutChan: timeout_chan,
SignalQueue: signal_queue,
NextSignal: next_signal,
}
ctx.AddNode(id, node) ctx.AddNode(id, node)
found_extensions := []ExtType{}
// Parse each of the extensions from the db
for _, ext_db := range(node_db.Extensions) {
type_hash := ext_db.Header.TypeHash
def, known := ctx.Extensions[type_hash]
if known == false {
return nil, fmt.Errorf("%s tried to load extension 0x%x, which is not a known extension type", id, type_hash)
}
ctx.Log.Logf("db", "DB_EXTENSION_LOADING: %s/%s", id, def.Type)
extension, err := def.Load(ctx, ext_db.Data)
if err != nil {
return nil, err
}
node.Extensions[def.Type] = extension
found_extensions = append(found_extensions, def.Type)
ctx.Log.Logf("db", "DB_EXTENSION_LOADED: %s/%s - %+v", id, def.Type, extension)
}
missing_extensions := []ExtType{}
for _, ext := range(node_type.Extensions) {
found := false
for _, found_ext := range(found_extensions) {
if found_ext == ext {
found = true
break
}
}
if found == false {
missing_extensions = append(missing_extensions, ext)
}
}
if len(missing_extensions) > 0 {
return nil, fmt.Errorf("DB_LOAD_MISSING_EXTENSIONS: %s - %+v - %+v", id, node_type, missing_extensions)
}
extra_extensions := []ExtType{}
for _, found_ext := range(found_extensions) {
found := false
for _, ext := range(node_type.Extensions) {
if ext == found_ext {
found = true
break
}
}
if found == false {
extra_extensions = append(extra_extensions, found_ext)
}
}
if len(extra_extensions) > 0 {
ctx.Log.Logf("db", "DB_LOAD_EXTRA_EXTENSIONS: %s - %+v - %+v", id, node_type, extra_extensions)
}
ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id) ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id)
go runNode(ctx, node) go runNode(ctx, node)
*/
return node, nil return nil, nil
} }

@ -8,21 +8,21 @@ import (
) )
func TestNodeDB(t *testing.T) { func TestNodeDB(t *testing.T) {
ctx := logTestContext(t, []string{"signal", "node", "db"}) ctx := logTestContext(t, []string{"signal", "node", "db", "db_data", "serialize"})
node_type := NodeType("test") node_type := NewNodeType("test")
err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType})
fatalErr(t, err) fatalErr(t, err)
node := NewNode(ctx, nil, node_type, 10, nil, NewGroupExt(nil)) node := NewNode(ctx, nil, node_type, 10, nil, NewGroupExt(nil), NewLockableExt(nil))
ctx.Nodes = map[NodeID]*Node{} ctx.nodeMap = map[NodeID]*Node{}
_, err = ctx.GetNode(node.ID) _, err = ctx.getNode(node.ID)
fatalErr(t, err) fatalErr(t, err)
} }
func TestNodeRead(t *testing.T) { func TestNodeRead(t *testing.T) {
ctx := logTestContext(t, []string{"test"}) ctx := logTestContext(t, []string{"test"})
node_type := NodeType("TEST") node_type := NewNodeType("TEST")
err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType, ECDHExtType}) err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType, ECDHExtType})
fatalErr(t, err) fatalErr(t, err)
@ -38,27 +38,27 @@ func TestNodeRead(t *testing.T) {
ctx.Log.Logf("test", "N2: %s", n2_id) ctx.Log.Logf("test", "N2: %s", n2_id)
n1_policy := NewPerNodePolicy(map[NodeID]Tree{ n1_policy := NewPerNodePolicy(map[NodeID]Tree{
n2_id: Tree{ n2_id: {
ReadSignalType.String(): nil, uint64(ReadSignalType): nil,
}, },
}) })
n2_listener := NewListenerExt(10) n2_listener := NewListenerExt(10)
n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewGroupExt(nil), NewECDHExt(), n2_listener) n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewGroupExt(nil), n2_listener)
n1 := NewNode(ctx, n1_key, node_type, 10, map[PolicyType]Policy{ n1 := NewNode(ctx, n1_key, node_type, 10, map[PolicyType]Policy{
PerNodePolicyType: &n1_policy, PerNodePolicyType: &n1_policy,
}, NewGroupExt(nil), NewECDHExt()) }, NewGroupExt(nil))
read_sig := NewReadSignal(map[ExtType][]string{ read_sig := NewReadSignal(map[ExtType][]string{
GroupExtType: []string{"members"}, GroupExtType: {"members"},
}) })
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(n2.ID, n2.Key, read_sig, n1.ID) msgs = msgs.Add(ctx, n2.ID, n2.Key, read_sig, n1.ID)
err = ctx.Send(msgs) err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
res, err := WaitForSignal(n2_listener.Chan, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool { res, err := WaitForSignal(n2_listener.Chan, 10*time.Millisecond, func(sig *ReadResultSignal) bool {
return true return true
}) })
fatalErr(t, err) fatalErr(t, err)

@ -1,35 +1,26 @@
package graphvent package graphvent
import ( import (
"encoding/json"
)
const (
MemberOfPolicyType = PolicyType("USER_OF")
RequirementOfPolicyType = PolicyType("REQUIEMENT_OF")
PerNodePolicyType = PolicyType("PER_NODE")
AllNodesPolicyType = PolicyType("ALL_NODES")
) )
type Policy interface { type Policy interface {
Serializable[PolicyType] Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult)
Allows(principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) ContinueAllows(ctx *Context, current PendingACL, signal Signal)RuleResult
ContinueAllows(current PendingACL, signal Signal)RuleResult
// Merge with another policy of the same underlying type // Merge with another policy of the same underlying type
Merge(Policy) Policy Merge(Policy) Policy
// Make a copy of this policy // Make a copy of this policy
Copy() Policy Copy() Policy
} }
func (policy *AllNodesPolicy) Allows(principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { func (policy *AllNodesPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
return nil, policy.Rules.Allows(action) return nil, policy.Rules.Allows(action)
} }
func (policy *AllNodesPolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult { func (policy *AllNodesPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
return Deny return Deny
} }
func (policy *PerNodePolicy) Allows(principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { func (policy *PerNodePolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
for id, actions := range(policy.NodeRules) { for id, actions := range(policy.NodeRules) {
if id != principal_id { if id != principal_id {
continue continue
@ -39,7 +30,7 @@ func (policy *PerNodePolicy) Allows(principal_id NodeID, action Tree, node *Node
return nil, Deny return nil, Deny
} }
func (policy *PerNodePolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult { func (policy *PerNodePolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
return Deny return Deny
} }
@ -57,7 +48,7 @@ func NewRequirementOfPolicy(dep_rules map[NodeID]Tree) RequirementOfPolicy {
} }
} }
func (policy *RequirementOfPolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult { func (policy *RequirementOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
sig, ok := signal.(*ReadResultSignal) sig, ok := signal.(*ReadResultSignal)
if ok == false { if ok == false {
return Deny return Deny
@ -68,7 +59,17 @@ func (policy *RequirementOfPolicy) ContinueAllows(current PendingACL, signal Sig
return Deny return Deny
} }
requirements, ok := ext["requirements"].(map[NodeID]string) reqs_ser, ok := ext["requirements"]
if ok == false {
return Deny
}
reqs_if, err := DeserializeValue(ctx, reqs_ser)
if err != nil {
return Deny
}
requirements, ok := reqs_if.(map[NodeID]ReqState)
if ok == false { if ok == false {
return Deny return Deny
} }
@ -96,7 +97,7 @@ func NewMemberOfPolicy(group_rules map[NodeID]Tree) MemberOfPolicy {
} }
} }
func (policy *MemberOfPolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult { func (policy *MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
sig, ok := signal.(*ReadResultSignal) sig, ok := signal.(*ReadResultSignal)
if ok == false { if ok == false {
return Deny return Deny
@ -107,7 +108,17 @@ func (policy *MemberOfPolicy) ContinueAllows(current PendingACL, signal Signal)
return Deny return Deny
} }
members, ok := group_ext_data["members"].(map[NodeID]string) members_ser, ok := group_ext_data["members"]
if ok == false {
return Deny
}
members_if, err := DeserializeValue(ctx, members_ser)
if err != nil {
return Deny
}
members, ok := members_if.(map[NodeID]string)
if ok == false { if ok == false {
return Deny return Deny
} }
@ -122,11 +133,11 @@ func (policy *MemberOfPolicy) ContinueAllows(current PendingACL, signal Signal)
} }
// Send a read signal to Group to check if principal_id is a member of it // Send a read signal to Group to check if principal_id is a member of it
func (policy *MemberOfPolicy) Allows(principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) { func (policy *MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) {
msgs := Messages{} msgs := Messages{}
for id, rule := range(policy.NodeRules) { for id, rule := range(policy.NodeRules) {
if id == node.ID { if id == node.ID {
ext, err := GetExt[*GroupExt](node) ext, err := GetExt[*GroupExt](node, GroupExtType)
if err == nil { if err == nil {
for member, _ := range(ext.Members) { for member, _ := range(ext.Members) {
if member == principal_id { if member == principal_id {
@ -137,7 +148,7 @@ func (policy *MemberOfPolicy) Allows(principal_id NodeID, action Tree, node *Nod
} }
} }
} else { } else {
msgs = msgs.Add(node.ID, node.Key, NewReadSignal(map[ExtType][]string{ msgs = msgs.Add(ctx, node.ID, node.Key, NewReadSignal(map[ExtType][]string{
GroupExtType: []string{"members"}, GroupExtType: []string{"members"},
}), id) }), id)
} }
@ -238,7 +249,7 @@ func (policy *AllNodesPolicy) Copy() Policy {
} }
} }
type Tree map[string]Tree type Tree map[uint64]Tree
func (rule Tree) Allows(action Tree) RuleResult { func (rule Tree) Allows(action Tree) RuleResult {
// If the current rule is nil, it's a wildcard and any action being processed is allowed // If the current rule is nil, it's a wildcard and any action being processed is allowed
@ -285,14 +296,6 @@ func (policy *PerNodePolicy) Type() PolicyType {
return PerNodePolicyType return PerNodePolicyType
} }
func (policy *PerNodePolicy) Serialize() ([]byte, error) {
return json.MarshalIndent(policy, "", " ")
}
func (policy *PerNodePolicy) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, policy)
}
func NewAllNodesPolicy(rules Tree) AllNodesPolicy { func NewAllNodesPolicy(rules Tree) AllNodesPolicy {
return AllNodesPolicy{ return AllNodesPolicy{
Rules: rules, Rules: rules,
@ -307,15 +310,7 @@ func (policy *AllNodesPolicy) Type() PolicyType {
return AllNodesPolicyType return AllNodesPolicyType
} }
func (policy *AllNodesPolicy) Serialize() ([]byte, error) {
return json.MarshalIndent(policy, "", " ")
}
func (policy *AllNodesPolicy) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, policy)
}
var DefaultPolicy = NewAllNodesPolicy(Tree{ var DefaultPolicy = NewAllNodesPolicy(Tree{
ErrorSignalType.String(): nil, uint64(ErrorSignalType): nil,
ReadResultSignalType.String(): nil, uint64(ReadResultSignalType): nil,
}) })

@ -1,49 +1,29 @@
package graphvent package graphvent
import ( import (
"time"
"fmt" "fmt"
"encoding/json" "time"
"encoding/binary"
"crypto" "capnproto.org/go/capnp/v3"
"crypto/ed25519"
"crypto/ecdh"
"crypto/rand"
"crypto/aes"
"crypto/cipher"
"github.com/google/uuid" "github.com/google/uuid"
schema "github.com/mekkanized/graphvent/signal"
) )
type SignalDirection int type SignalDirection int
const ( const (
StopSignalType = SignalType("STOP")
NewSignalType = SignalType("NEW")
StartSignalType = SignalType("START")
ErrorSignalType = SignalType("ERROR")
StatusSignalType = SignalType("STATUS")
LinkSignalType = SignalType("LINK")
LockSignalType = SignalType("LOCK")
ReadSignalType = SignalType("READ")
ReadResultSignalType = SignalType("READ_RESULT")
ECDHSignalType = SignalType("ECDH")
ECDHProxySignalType = SignalType("ECDH_PROXY")
ACLTimeoutSignalType = SignalType("ACL_TIMEOUT")
Up SignalDirection = iota Up SignalDirection = iota
Down Down
Direct Direct
) )
type SignalType string type SignalHeader struct {
func (signal_type SignalType) String() string { return string(signal_type) } Direction SignalDirection
func (signal_type SignalType) Prefix() string { return "SIGNAL: " } ID uuid.UUID
ReqID uuid.UUID
}
type Signal interface { type Signal interface {
Serializable[SignalType] Header() *SignalHeader
String() string
Direction() SignalDirection
ID() uuid.UUID
ReqID() uuid.UUID
Permission() Tree Permission() Tree
} }
@ -59,7 +39,7 @@ func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UU
if signal == nil { if signal == nil {
return nil, fmt.Errorf("LISTENER_CLOSED") return nil, fmt.Errorf("LISTENER_CLOSED")
} }
if signal.ReqID() == req_id { if signal.Header().ReqID == req_id {
return signal, nil return signal, nil
} }
case <-timeout_channel: case <-timeout_channel:
@ -69,7 +49,7 @@ func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UU
return nil, fmt.Errorf("UNREACHABLE") return nil, fmt.Errorf("UNREACHABLE")
} }
func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) { func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, check func(S)bool) (S, error) {
var zero S var zero S
var timeout_channel <- chan time.Time var timeout_channel <- chan time.Time
if timeout > 0 { if timeout > 0 {
@ -79,493 +59,396 @@ func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, signal
select { select {
case signal := <- listener: case signal := <- listener:
if signal == nil { if signal == nil {
return zero, fmt.Errorf("LISTENER_CLOSED: %s", signal_type) return zero, fmt.Errorf("LISTENER_CLOSED")
} }
if signal.Type() == signal_type {
sig, ok := signal.(S) sig, ok := signal.(S)
if ok == true { if ok == true {
if check(sig) == true { if check(sig) == true {
return sig, nil return sig, nil
} }
} }
}
case <-timeout_channel: case <-timeout_channel:
return zero, fmt.Errorf("LISTENER_TIMEOUT: %s", signal_type) return zero, fmt.Errorf("LISTENER_TIMEOUT")
} }
} }
return zero, fmt.Errorf("LOOP_ENDED") return zero, fmt.Errorf("LOOP_ENDED")
} }
type BaseSignal struct { func NewSignalHeader(direction SignalDirection) SignalHeader {
SignalDirection SignalDirection `json:"direction"` id := uuid.New()
SignalType SignalType `json:"type"` header := SignalHeader{
UUID uuid.UUID `json:"id"` ID: id,
ReqUUID uuid.UUID `json:"req_uuid"` ReqID: id,
Direction: direction,
} }
return header
func (signal *BaseSignal) ReqID() uuid.UUID {
return signal.ReqUUID
} }
func (signal *BaseSignal) String() string { func NewRespHeader(req_id uuid.UUID, direction SignalDirection) SignalHeader {
ser, _ := json.Marshal(signal) header := SignalHeader{
return string(ser) ID: uuid.New(),
ReqID: req_id,
Direction: direction,
}
return header
} }
func (signal *BaseSignal) Deserialize(ctx *Context, data []byte) error { func SerializeHeader(header SignalHeader, root schema.SignalHeader) error {
return json.Unmarshal(data, signal) root.SetDirection(uint8(header.Direction))
id_ser, err := header.ID.MarshalBinary()
if err != nil {
return err
} }
root.SetId(id_ser)
func (signal *BaseSignal) ID() uuid.UUID { req_id_ser, err := header.ReqID.MarshalBinary()
return signal.UUID if err != nil {
return err
}
root.SetReqID(req_id_ser)
return nil
} }
func (signal *BaseSignal) Type() SignalType { func DeserializeHeader(header schema.SignalHeader) (SignalHeader, error) {
return signal.SignalType id_ser, err := header.Id()
if err != nil {
return SignalHeader{}, err
}
id, err := uuid.FromBytes(id_ser)
if err != nil {
return SignalHeader{}, err
} }
func (signal *BaseSignal) Permission() Tree { req_id_ser, err := header.ReqID()
return Tree{ if err != nil {
string(signal.Type()): Tree{}, return SignalHeader{}, err
} }
req_id, err := uuid.FromBytes(req_id_ser)
if err != nil {
return SignalHeader{}, err
} }
func (signal *BaseSignal) Direction() SignalDirection { return SignalHeader{
return signal.SignalDirection ID: id,
ReqID: req_id,
Direction: SignalDirection(header.Direction()),
}, nil
} }
func (signal *BaseSignal) Serialize() ([]byte, error) { type CreateSignal struct {
return json.Marshal(signal) SignalHeader
} }
func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal { func (signal *CreateSignal) Header() *SignalHeader {
id := uuid.New() return &signal.SignalHeader
signal := BaseSignal{ }
UUID: id, func (signal *CreateSignal) Permission() Tree {
ReqUUID: id, return Tree{
SignalDirection: direction, uint64(CreateSignalType): nil,
SignalType: signal_type,
} }
return signal
} }
func NewRespSignal(id uuid.UUID, signal_type SignalType, direction SignalDirection) BaseSignal { func NewCreateSignal() *CreateSignal {
signal := BaseSignal{ return &CreateSignal{
UUID: uuid.New(), NewSignalHeader(Direct),
ReqUUID: id,
SignalDirection: direction,
SignalType: signal_type,
} }
return signal
} }
var NewSignal = NewBaseSignal(NewSignalType, Direct) type StartSignal struct {
var StartSignal = NewBaseSignal(StartSignalType, Direct) SignalHeader
var StopSignal = NewBaseSignal(StopSignalType, Direct)
type IDSignal struct {
BaseSignal
NodeID `json:"id"`
} }
func (signal *StartSignal) Header() *SignalHeader {
func (signal *IDSignal) Serialize() ([]byte, error) { return &signal.SignalHeader
return json.Marshal(signal)
} }
func (signal *StartSignal) Permission() Tree {
type StringSignal struct { return Tree{
BaseSignal uint64(StartSignalType): nil,
Str string `json:"state"`
} }
func (signal *StringSignal) String() string {
ser, _ := json.Marshal(signal)
return string(ser)
} }
func NewStartSignal() *StartSignal {
func (signal *StringSignal) Serialize() ([]byte, error) { return &StartSignal{
return json.Marshal(&signal) NewSignalHeader(Direct),
} }
type ErrorSignal struct {
BaseSignal
Error string
} }
func (signal *ErrorSignal) String() string { type StopSignal struct {
ser, _ := json.Marshal(signal) SignalHeader
return string(ser)
} }
func (signal *StopSignal) Header() *SignalHeader {
func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) Signal { return &signal.SignalHeader
return &ErrorSignal{
NewRespSignal(req_id, ErrorSignalType, Direct),
fmt.Sprintf(fmt_string, args...),
} }
func (signal *StopSignal) Permission() Tree {
return Tree{
uint64(StopSignalType): nil,
} }
func NewACLTimeoutSignal(req_id uuid.UUID) Signal {
sig := NewRespSignal(req_id, ACLTimeoutSignalType, Direct)
return &sig
} }
func NewStopSignal() *StopSignal {
type IDStringSignal struct { return &StopSignal{
BaseSignal NewSignalHeader(Direct),
NodeID NodeID `json:"node_id"`
Str string `json:"string"`
} }
func (signal *IDStringSignal) String() string {
ser, _ := json.Marshal(signal)
return string(ser)
} }
func (signal *IDStringSignal) Serialize() ([]byte, error) { type ErrorSignal struct {
return json.Marshal(signal) SignalHeader
Error string
} }
func (signal *ErrorSignal) Header() *SignalHeader {
func NewStatusSignal(status string, source NodeID) Signal { return &signal.SignalHeader
return &IDStringSignal{
BaseSignal: NewBaseSignal(StatusSignalType, Up),
NodeID: source,
Str: status,
} }
func (signal *ErrorSignal) MarshalBinary() ([]byte, error) {
arena := capnp.SingleSegment(nil)
msg, seg, err := capnp.NewMessage(arena)
if err != nil {
return nil, err
} }
func NewLinkSignal(state string, id NodeID) Signal { root, err := schema.NewRootErrorSignal(seg)
return &IDStringSignal{ if err != nil {
BaseSignal: NewBaseSignal(LinkSignalType, Direct), return nil, err
NodeID: id,
Str: state,
}
} }
func NewLockSignal(state string) Signal { root.SetError(signal.Error)
return &StringSignal{
NewBaseSignal(LockSignalType, Direct), header, err := root.NewHeader()
state, if err != nil {
return nil, err
} }
err = SerializeHeader(signal.SignalHeader, header)
if err != nil {
return nil, err
} }
func (signal *StringSignal) Permission() Tree { return msg.Marshal()
return Tree{
string(signal.Type()): Tree{
signal.Str: Tree{},
},
} }
func (signal *ErrorSignal) Deserialize(ctx *Context, data []byte) error {
msg, err := capnp.Unmarshal(data)
if err != nil {
return err
} }
type ReadSignal struct { root, err := schema.ReadRootErrorSignal(msg)
BaseSignal if err != nil {
Extensions map[ExtType][]string `json:"extensions"` return err
} }
func (signal *ReadSignal) Serialize() ([]byte, error) { header, err := root.Header()
return json.Marshal(signal) if err != nil {
return err
} }
func NewReadSignal(exts map[ExtType][]string) *ReadSignal { signal.Error, err = root.Error()
return &ReadSignal{ if err != nil {
NewBaseSignal(ReadSignalType, Direct), return err
exts,
} }
signal.SignalHeader, err = DeserializeHeader(header)
if err != nil {
return err
} }
func (signal *ReadSignal) Permission() Tree { return nil
ret := Tree{}
for ext, fields := range(signal.Extensions) {
field_tree := Tree{}
for _, field := range(fields) {
field_tree[field] = Tree{}
} }
ret[ext.String()] = field_tree func (signal *ErrorSignal) Permission() Tree {
return Tree{
uint64(ErrorSignalType): nil,
} }
return Tree{ReadSignalType.String(): ret}
} }
func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) Signal {
type ReadResultSignal struct { return &ErrorSignal{
BaseSignal NewRespHeader(req_id, Direct),
NodeID NodeID fmt.Sprintf(fmt_string, args...),
NodeType NodeType }
Extensions map[ExtType]map[string]interface{} `json:"extensions"`
} }
func (signal *ReadResultSignal) Permission() Tree { type ACLTimeoutSignal struct {
SignalHeader
}
func (signal *ACLTimeoutSignal) Header() *SignalHeader {
return &signal.SignalHeader
}
func (signal *ACLTimeoutSignal) Permission() Tree {
return Tree{ return Tree{
ReadResultSignalType.String(): Tree{}, uint64(ACLTimeoutSignalType): nil,
} }
} }
func NewACLTimeoutSignal(req_id uuid.UUID) *ACLTimeoutSignal {
func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]interface{}) Signal { sig := &ACLTimeoutSignal{
return &ReadResultSignal{ NewRespHeader(req_id, Direct),
NewRespSignal(req_id, ReadResultSignalType, Direct),
node_id,
node_type,
exts,
} }
return sig
} }
type ECDHSignal struct { type StatusSignal struct {
StringSignal SignalHeader
Time time.Time Source NodeID
EDDSA ed25519.PublicKey Status string
ECDH *ecdh.PublicKey
Signature []byte
} }
func (signal *StatusSignal) Header() *SignalHeader {
type ECDHSignalJSON struct { return &signal.SignalHeader
StringSignal
Time time.Time `json:"time"`
EDDSA []byte `json:"ecdsa_pubkey"`
ECDH []byte `json:"ecdh_pubkey"`
Signature []byte `json:"signature"`
} }
func (signal *StatusSignal) Permission() Tree {
func (signal *ECDHSignal) MarshalJSON() ([]byte, error) { return Tree{
return json.Marshal(&ECDHSignalJSON{ uint64(StatusSignalType): nil,
StringSignal: signal.StringSignal,
Time: signal.Time,
ECDH: signal.ECDH.Bytes(),
EDDSA: signal.ECDH.Bytes(),
Signature: signal.Signature,
})
} }
func (signal *ECDHSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
} }
func NewStatusSignal(source NodeID, status string) *StatusSignal {
func NewECDHReqSignal(node *Node) (Signal, *ecdh.PrivateKey, error) { return &StatusSignal{
ec_key, err := ECDH.GenerateKey(rand.Reader) NewSignalHeader(Up),
if err != nil { source,
return nil, nil, err status,
} }
now := time.Now()
time_bytes, err := now.MarshalJSON()
if err != nil {
return nil, nil, err
} }
sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...) type LinkSignal struct {
SignalHeader
sig, err := node.Key.Sign(rand.Reader, sig_data, crypto.Hash(0)) NodeID
if err != nil { Action string
return nil, nil, err
} }
func (signal *LinkSignal) Header() *SignalHeader {
return &ECDHSignal{ return &signal.SignalHeader
StringSignal: StringSignal{
BaseSignal: NewBaseSignal(ECDHSignalType, Direct),
Str: "req",
},
Time: now,
EDDSA: node.Key.Public().(ed25519.PublicKey),
ECDH: ec_key.PublicKey(),
Signature: sig,
}, ec_key, nil
} }
const DEFAULT_ECDH_WINDOW = time.Second const (
LinkActionBase = "LINK_ACTION"
func NewECDHRespSignal(node *Node, req *ECDHSignal) (ECDHSignal, []byte, error) { LinkActionAdd = "ADD"
now := time.Now() )
err := VerifyECDHSignal(now, req, DEFAULT_ECDH_WINDOW) func (signal *LinkSignal) Permission() Tree {
if err != nil { return Tree{
return ECDHSignal{}, nil, err uint64(LinkSignalType): Tree{
Hash(LinkActionBase, signal.Action): nil,
},
} }
ec_key, err := ECDH.GenerateKey(rand.Reader)
if err != nil {
return ECDHSignal{}, nil, err
} }
func NewLinkSignal(action string, id NodeID) Signal {
shared_secret, err := ec_key.ECDH(req.ECDH) return &LinkSignal{
if err != nil { NewSignalHeader(Direct),
return ECDHSignal{}, nil, err id,
action,
} }
time_bytes, err := now.MarshalJSON()
if err != nil {
return ECDHSignal{}, nil, err
} }
sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...) type LockSignal struct {
SignalHeader
sig, err := node.Key.Sign(rand.Reader, sig_data, crypto.Hash(0)) State string
if err != nil {
return ECDHSignal{}, nil, err
} }
func (signal *LockSignal) Header() *SignalHeader {
return ECDHSignal{ return &signal.SignalHeader
StringSignal: StringSignal{
BaseSignal: NewBaseSignal(ECDHSignalType, Direct),
Str: "resp",
},
Time: now,
EDDSA: node.Key.Public().(ed25519.PublicKey),
ECDH: ec_key.PublicKey(),
Signature: sig,
}, shared_secret, nil
} }
func VerifyECDHSignal(now time.Time, sig *ECDHSignal, window time.Duration) error { const (
earliest := now.Add(-window) LockStateBase = "LOCK_STATE"
latest := now.Add(window) )
if sig.Time.Compare(earliest) == -1 {
return fmt.Errorf("TIME_TOO_LATE: %+v", sig.Time)
} else if sig.Time.Compare(latest) == 1 {
return fmt.Errorf("TIME_TOO_EARLY: %+v", sig.Time)
}
time_bytes, err := sig.Time.MarshalJSON() func (signal *LockSignal) Permission() Tree {
if err != nil { return Tree{
return err uint64(LockSignalType): Tree{
Hash(LockStateBase, signal.State): nil,
},
} }
sig_data := append(sig.ECDH.Bytes(), time_bytes...)
verified := ed25519.Verify(sig.EDDSA, sig_data, sig.Signature)
if verified == false {
return fmt.Errorf("Failed to verify signature")
} }
return nil func NewLockSignal(state string) *LockSignal {
return &LockSignal{
NewSignalHeader(Direct),
state,
} }
type ECDHProxySignal struct {
BaseSignal
Source NodeID
Dest NodeID
IV []byte
Data []byte
} }
func NewECDHProxySignal(source, dest NodeID, signal Signal, shared_secret []byte) (Signal, error) { type ReadSignal struct {
if shared_secret == nil { SignalHeader
return nil, fmt.Errorf("need shared_secret") Extensions map[ExtType][]string `json:"extensions"`
} }
func (signal *ReadSignal) MarshalBinary() ([]byte, error) {
aes_key, err := aes.NewCipher(shared_secret[:32]) arena := capnp.SingleSegment(nil)
msg, seg, err := capnp.NewMessage(arena)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ser, err := SerializeSignal(signal, aes_key.BlockSize()) root, err := schema.NewRootReadSignal(seg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iv := make([]byte, aes_key.BlockSize()) header, err := root.NewHeader()
n, err := rand.Reader.Read(iv)
if err != nil { if err != nil {
return nil, err return nil, err
} else if n != len(iv) {
return nil, fmt.Errorf("Not enough bytes read for IV")
} }
encrypter := cipher.NewCBCEncrypter(aes_key, iv) err = SerializeHeader(signal.SignalHeader, header)
encrypter.CryptBlocks(ser, ser) if err != nil {
return nil, err
return &ECDHProxySignal{
BaseSignal: NewBaseSignal(ECDHProxySignalType, Direct),
Source: source,
Dest: dest,
IV: iv,
Data: ser,
}, nil
} }
type SignalHeader struct { extensions, err := root.NewExtensions(int32(len(signal.Extensions)))
Magic uint32 if err != nil {
TypeHash uint64 return nil, err
Length uint64
} }
const SIGNAL_SER_MAGIC uint32 = 0x753a64de i := 0
const SIGNAL_SER_HEADER_LENGTH = 20 for ext_type, fields := range(signal.Extensions) {
func SerializeSignal(signal Signal, block_size int) ([]byte, error) { extension := extensions.At(i)
signal_ser, err := signal.Serialize() extension.SetType(uint64(ext_type))
f, err := extension.NewFields(int32(len(fields)))
if err != nil { if err != nil {
return nil, err return nil, err
} }
pad_req := 0 for j, field := range(fields) {
if block_size > 0 { err := f.Set(j, field)
pad := block_size - ((SIGNAL_SER_HEADER_LENGTH + len(signal_ser)) % block_size) if err != nil {
if pad != block_size { return nil, err
pad_req = pad
} }
} }
header := SignalHeader{ i += 1
Magic: SIGNAL_SER_MAGIC,
TypeHash: Hash(signal.Type()),
Length: uint64(len(signal_ser) + pad_req),
} }
ser := make([]byte, SIGNAL_SER_HEADER_LENGTH + len(signal_ser) + pad_req) return msg.Marshal()
binary.BigEndian.PutUint32(ser[0:4], header.Magic)
binary.BigEndian.PutUint64(ser[4:12], header.TypeHash)
binary.BigEndian.PutUint64(ser[12:20], header.Length)
copy(ser[SIGNAL_SER_HEADER_LENGTH:], signal_ser)
return ser, nil
} }
func (signal *ReadSignal) Header() *SignalHeader {
func ParseSignal(ctx *Context, data []byte) (Signal, error) { return &signal.SignalHeader
if len(data) < SIGNAL_SER_HEADER_LENGTH {
return nil, fmt.Errorf("data shorter than header length")
} }
header := SignalHeader{ func (signal *ReadSignal) Permission() Tree {
Magic: binary.BigEndian.Uint32(data[0:4]), ret := Tree{}
TypeHash: binary.BigEndian.Uint64(data[4:12]), for ext, fields := range(signal.Extensions) {
Length: binary.BigEndian.Uint64(data[12:20]), field_tree := Tree{}
for _, field := range(fields) {
field_tree[Hash(FieldNameBase, field)] = nil
} }
ret[uint64(ext)] = field_tree
if header.Magic != SIGNAL_SER_MAGIC {
return nil, fmt.Errorf("signal magic mismatch 0x%x", header.Magic)
} }
return Tree{uint64(ReadSignalType): ret}
left := len(data) - SIGNAL_SER_HEADER_LENGTH
if int(header.Length) != left {
return nil, fmt.Errorf("signal length mismatch %d/%d", header.Length, left)
} }
func NewReadSignal(exts map[ExtType][]string) *ReadSignal {
signal_def, exists := ctx.Signals[header.TypeHash] return &ReadSignal{
if exists == false { NewSignalHeader(Direct),
return nil, fmt.Errorf("0x%x is not a known signal type", header.TypeHash) exts,
} }
signal, err := signal_def.Load(ctx, data[SIGNAL_SER_HEADER_LENGTH:])
if err != nil {
return nil, err
} }
return signal, nil type ReadResultSignal struct {
SignalHeader
NodeID NodeID
NodeType NodeType
Extensions map[ExtType]map[string]SerializedValue
} }
func (signal *ReadResultSignal) Header() *SignalHeader {
func ParseECDHProxySignal(ctx *Context, signal *ECDHProxySignal, shared_secret []byte) (Signal, error) { return &signal.SignalHeader
if shared_secret == nil {
return nil, fmt.Errorf("need shared_secret")
} }
func (signal *ReadResultSignal) Permission() Tree {
aes_key, err := aes.NewCipher(shared_secret[:32]) return Tree{
if err != nil { uint64(ReadResultSignalType): nil,
return nil, err
} }
decrypter := cipher.NewCBCDecrypter(aes_key, signal.IV)
decrypted := make([]byte, len(signal.Data))
decrypter.CryptBlocks(decrypted, signal.Data)
wrapped_signal, err := ParseSignal(ctx, decrypted)
if err != nil {
return nil, err
} }
func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]SerializedValue) *ReadResultSignal {
return wrapped_signal, nil return &ReadResultSignal{
NewRespHeader(req_id, Direct),
node_id,
node_type,
exts,
}
} }

@ -0,0 +1,10 @@
module github.com/mekkanized/graphvent/signal
go 1.21.0
require capnproto.org/go/capnp/v3 v3.0.0-alpha-29
require (
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 // indirect
zenhack.net/go/util v0.0.0-20230414204917-531d38494cf5 // indirect
)

@ -0,0 +1,35 @@
using Go = import "/go.capnp";
@0x83a311a90429ef94;
$Go.package("signal");
$Go.import("github.com/mekkanized/graphvent/signal");
struct SignalHeader {
direction @0 :UInt8;
id @1 :Data;
reqID @2 :Data;
}
struct ErrorSignal {
header @0 :SignalHeader;
error @1 :Text;
}
struct LinkSignal {
header @0 :SignalHeader;
action @1 :Text;
id @2 :Data;
}
struct LockSignal {
header @0 :SignalHeader;
state @1 :Text;
}
struct ReadSignal {
header @0 :SignalHeader;
extensions @1 :List(Extension);
struct Extension {
type @0 :UInt64;
fields @1 :List(Text);
}
}