Cleanup and move away from capnp to custom TLV serialization

gql_cataclysm
noah metz 2023-08-31 19:50:32 -06:00
parent 7bed89701d
commit 4daec4d601
24 changed files with 1192 additions and 1606 deletions

3
.gitignore vendored

@ -1,12 +1,15 @@
# Ignore everything
*
!/go-capnp
# But not these files...
!/.gitignore
!*.go
*.capnp.go
!go.sum
!go.mod
!*.capnp
!README.md
!LICENSE

@ -0,0 +1,9 @@
.PHONY: schema test
test: schema
clear && go test
schema: signal/signal.capnp.go
%.capnp.go: %.capnp
capnp compile -I ./go-capnp/std -ogo $<

@ -1,124 +1,127 @@
package graphvent
import (
badger "github.com/dgraph-io/badger/v3"
"fmt"
"sync"
"errors"
"runtime"
"crypto/sha512"
"crypto/ecdh"
"crypto/sha512"
"encoding/binary"
)
"errors"
"fmt"
"reflect"
"runtime"
"sync"
// A Type can be Hashed by Hash
type TypeName interface {
String() string
Prefix() string
}
badger "github.com/dgraph-io/badger/v3"
)
// Hashed a Type to a uint64
func Hash(t TypeName) uint64 {
hash := sha512.Sum512([]byte(fmt.Sprintf("%s%s", t.Prefix(), t.String())))
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
func Hash(base string, name string) uint64 {
digest := append([]byte(base), 0x00)
digest = append(digest, []byte(name)...)
hash := sha512.Sum512(digest)
return binary.BigEndian.Uint64(hash[0:8])
}
// NodeType identifies the 'class' of a node
type NodeType string
func (node NodeType) Prefix() string { return "NODE: " }
func (node NodeType) String() string { return string(node) }
// ExtType identifies an extension on a node
type ExtType string
func (ext ExtType) Prefix() string { return "EXTENSION: " }
func (ext ExtType) String() string { return string(ext) }
type ExtType uint64
type NodeType uint64
type SignalType uint64
type PolicyType uint64
type SerializedType uint64
//Function to load an extension from bytes
type ExtensionLoadFunc func(*Context,[]byte) (Extension, error)
func LoadExtension[T any, E interface {
*T
Extension
}](ctx *Context, data []byte) (Extension, error) {
e := E(new(T))
err := e.Deserialize(ctx, data)
if err != nil {
return nil, err
func NewExtType(name string) ExtType {
return ExtType(Hash(ExtTypeBase, name))
}
return e, nil
func NewNodeType(name string) NodeType {
return NodeType(Hash(NodeTypeBase, name))
}
type PolicyType string
func (policy PolicyType) Prefix() string { return "POLICY: " }
func (policy PolicyType) String() string { return string(policy) }
type PolicyLoadFunc func(*Context,[]byte) (Policy, error)
func LoadPolicy[T any, P interface {
*T
Policy
}](ctx *Context, data []byte) (Policy, error) {
p := P(new(T))
err := p.Deserialize(ctx, data)
if err != nil {
return nil, err
func NewSignalType(name string) SignalType {
return SignalType(Hash(SignalTypeBase, name))
}
return p, nil
func NewPolicyType(name string) PolicyType {
return PolicyType(Hash(PolicyTypeBase, name))
}
type PolicyInfo struct {
Load PolicyLoadFunc
Type PolicyType
func NewSerializedType(name string) SerializedType {
return SerializedType(Hash(SerializedTypeBase, name))
}
// ExtType and NodeType constants
const (
ListenerExtType = ExtType("LISTENER")
LockableExtType = ExtType("LOCKABLE")
GQLExtType = ExtType("GQL")
GroupExtType = ExtType("GROUP")
ECDHExtType = ExtType("ECDH")
GQLNodeType = NodeType("GQL")
ExtTypeBase = "ExtType"
NodeTypeBase = "NodeType"
SignalTypeBase = "SignalType"
PolicyTypeBase = "PolicyType"
SerializedTypeBase = "SerializedType"
FieldNameBase = "FieldName"
)
var (
ListenerExtType = NewExtType("LISTENER")
LockableExtType = NewExtType("LOCKABLE")
GQLExtType = NewExtType("GQL")
GroupExtType = NewExtType("GROUP")
ECDHExtType = NewExtType("ECDH")
GQLNodeType = NewNodeType("GQL")
StopSignalType = NewSignalType("STOP")
CreateSignalType = NewSignalType("CREATE")
StartSignalType = NewSignalType("START")
ErrorSignalType = NewSignalType("ERROR")
StatusSignalType = NewSignalType("STATUS")
LinkSignalType = NewSignalType("LINK")
LockSignalType = NewSignalType("LOCK")
ReadSignalType = NewSignalType("READ")
ReadResultSignalType = NewSignalType("READ_RESULT")
ACLTimeoutSignalType = NewSignalType("ACL_TIMEOUT")
MemberOfPolicyType = NewPolicyType("USER_OF")
RequirementOfPolicyType = NewPolicyType("REQUIEMENT_OF")
PerNodePolicyType = NewPolicyType("PER_NODE")
AllNodesPolicyType = NewPolicyType("ALL_NODES")
StructType = NewSerializedType("struct")
SliceType = NewSerializedType("slice")
ArrayType = NewSerializedType("array")
PointerType = NewSerializedType("pointer")
MapType = NewSerializedType("map")
ErrorType = NewSerializedType("error")
ExtensionType = NewSerializedType("extension")
StringType = NewSerializedType("string")
NodeKeyType = NewSerializedType("node_key")
NodeNotFoundError = errors.New("Node not found in DB")
ECDH = ecdh.X25519()
)
type SignalLoadFunc func(*Context,[]byte) (Signal, error)
func LoadSignal[T any, S interface{
*T
Signal
}](ctx *Context, data []byte) (Signal, error) {
s := S(new(T))
err := s.Deserialize(ctx, data)
if err != nil {
return nil, err
type ExtensionInfo struct {
Type reflect.Type
Data interface{}
}
return s, nil
type NodeInfo struct {
Extensions []ExtType
}
type SignalInfo struct {
Load SignalLoadFunc
Type SignalType
type TypeSerialize func(*Context,interface{}) ([]byte, error)
type TypeDeserialize func(*Context,[]byte) (interface{}, error)
type TypeInfo struct {
Type reflect.Type
Serialize TypeSerialize
Deserialize TypeDeserialize
}
// Information about a registered extension
type ExtensionInfo struct {
// Function used to load extensions of this type from the database
Load ExtensionLoadFunc
Type ExtType
// Extra context data shared between nodes of this class
Data interface{}
type Int int
func (i Int) MarshalBinary() ([]byte, error) {
ret := make([]byte, 8)
binary.BigEndian.PutUint64(ret, uint64(i))
return ret, nil
}
// Information about a registered node type
type NodeInfo struct {
Type NodeType
// Required extensions to be a valid node of this class
Extensions []ExtType
type String string
func (str String) MarshalBinary() ([]byte, error) {
return []byte(str), nil
}
// A Context stores all the data to run a graphvent process
@ -128,101 +131,132 @@ type Context struct {
// Logging interface
Log Logger
// Map between database extension hashes and the registered info
Extensions map[uint64]ExtensionInfo
Extensions map[ExtType]ExtensionInfo
ExtensionTypes map[reflect.Type]ExtType
// Map between databse policy hashes and the registered info
Policies map[uint64]PolicyInfo
Policies map[PolicyType]reflect.Type
PolicyTypes map[reflect.Type]PolicyType
// Map between serialized signal hashes and the registered info
Signals map[uint64]SignalInfo
Signals map[SignalType]reflect.Type
SignalTypes map[reflect.Type]SignalType
// Map between database type hashes and the registered info
Types map[uint64]*NodeInfo
Nodes map[NodeType]NodeInfo
// Map between go types and registered info
Types map[SerializedType]TypeInfo
TypeReflects map[reflect.Type]SerializedType
// Routing map to all the nodes local to this context
NodesLock sync.RWMutex
Nodes map[NodeID]*Node
nodeMapLock sync.RWMutex
nodeMap map[NodeID]*Node
}
// Register a NodeType to the context, with the list of extensions it requires
func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) error {
type_hash := Hash(node_type)
_, exists := ctx.Types[type_hash]
_, exists := ctx.Nodes[node_type]
if exists == true {
return fmt.Errorf("Cannot register node type %s, type already exists in context", node_type)
return fmt.Errorf("Cannot register node type %+v, type already exists in context", node_type)
}
ext_found := map[ExtType]bool{}
for _, extension := range(extensions) {
_, in_ctx := ctx.Extensions[Hash(extension)]
_, in_ctx := ctx.Extensions[extension]
if in_ctx == false {
return fmt.Errorf("Cannot register node type %s, required extension %s not in context", node_type, extension)
return fmt.Errorf("Cannot register node type %+v, required extension %+v not in context", node_type, extension)
}
_, duplicate := ext_found[extension]
if duplicate == true {
return fmt.Errorf("Duplicate extension %s found in extension list", extension)
return fmt.Errorf("Duplicate extension %+v found in extension list", extension)
}
ext_found[extension] = true
}
ctx.Types[type_hash] = &NodeInfo{
Type: node_type,
ctx.Nodes[node_type] = NodeInfo{
Extensions: extensions,
}
return nil
}
func RegisterSignal[T any, S interface {
*T
Signal
}](ctx *Context, signal_type SignalType) error {
type_hash := Hash(signal_type)
_, exists := ctx.Signals[type_hash]
func (ctx *Context) RegisterPolicy(reflect_type reflect.Type, policy_type PolicyType) error {
_, exists := ctx.Policies[policy_type]
if exists == true {
return fmt.Errorf("Cannot register signal of type %s, type already exists in context", signal_type)
return fmt.Errorf("Cannot register policy of type %+v, type already exists in context", policy_type)
}
ctx.Signals[type_hash] = SignalInfo{
Load: LoadSignal[T, S],
Type: signal_type,
ctx.Policies[policy_type] = reflect_type
ctx.PolicyTypes[reflect_type] = policy_type
return nil
}
func (ctx *Context)RegisterSignal(reflect_type reflect.Type, signal_type SignalType) error {
_, exists := ctx.Signals[signal_type]
if exists == true {
return fmt.Errorf("Cannot register signal of type %+v, type already exists in context", signal_type)
}
ctx.Signals[signal_type] = reflect_type
ctx.SignalTypes[reflect_type] = signal_type
return nil
}
// Add a node to a context, returns an error if the def is invalid or already exists in the context
func RegisterExtension[T any, E interface{
*T
Extension
}](ctx *Context, data interface{}) error {
var zero E
ext_type := zero.Type()
type_hash := Hash(ext_type)
_, exists := ctx.Extensions[type_hash]
func (ctx *Context)RegisterExtension(reflect_type reflect.Type, ext_type ExtType, data interface{}) error {
_, exists := ctx.Extensions[ext_type]
if exists == true {
return fmt.Errorf("Cannot register extension of type %s, type already exists in context", ext_type)
return fmt.Errorf("Cannot register extension of type %+v, type already exists in context", ext_type)
}
ctx.Extensions[type_hash] = ExtensionInfo{
Load: LoadExtension[T,E],
Type: ext_type,
ctx.Extensions[ext_type] = ExtensionInfo{
Type: reflect_type,
Data: data,
}
ctx.ExtensionTypes[reflect_type] = ext_type
return nil
}
func (ctx *Context)RegisterType(reflect_type reflect.Type, ctx_type SerializedType, serialize TypeSerialize, deserialize TypeDeserialize) error {
_, exists := ctx.Types[ctx_type]
if exists == true {
return fmt.Errorf("Cannot register field of type %+v, type already exists in context", ctx_type)
}
_, exists = ctx.TypeReflects[reflect_type]
if exists == true {
return fmt.Errorf("Cannot register field with type %+v, type already registered in context", reflect_type)
}
if deserialize == nil {
return fmt.Errorf("Cannot register field without deserialize function")
}
if serialize == nil {
return fmt.Errorf("Cannot register field without serialize function")
}
ctx.Types[ctx_type] = TypeInfo{
Type: reflect_type,
Serialize: serialize,
Deserialize: deserialize,
}
ctx.TypeReflects[reflect_type] = ctx_type
return nil
}
func (ctx *Context) AddNode(id NodeID, node *Node) {
ctx.NodesLock.Lock()
ctx.Nodes[id] = node
ctx.NodesLock.Unlock()
ctx.nodeMapLock.Lock()
ctx.nodeMap[id] = node
ctx.nodeMapLock.Unlock()
}
func (ctx *Context) Node(id NodeID) (*Node, bool) {
ctx.NodesLock.RLock()
node, exists := ctx.Nodes[id]
ctx.NodesLock.RUnlock()
ctx.nodeMapLock.RLock()
node, exists := ctx.nodeMap[id]
ctx.nodeMapLock.RUnlock()
return node, exists
}
// Get a node from the context, or load from the database if not loaded
func (ctx *Context) GetNode(id NodeID) (*Node, error) {
func (ctx *Context) getNode(id NodeID) (*Node, error) {
target, exists := ctx.Node(id)
if exists == false {
@ -241,7 +275,7 @@ func (ctx *Context) Send(messages Messages) error {
if msg.Dest == ZeroID {
panic("Can't send to null ID")
}
target, err := ctx.GetNode(msg.Dest)
target, err := ctx.getNode(msg.Dest)
if err == nil {
select {
case target.MsgChan <- msg:
@ -262,55 +296,311 @@ func (ctx *Context) Send(messages Messages) error {
return nil
}
type defaultKind struct {
Type SerializedType
Serialize func(interface{})([]byte, error)
Deserialize func([]byte)(interface{}, error)
}
var defaultKinds = map[reflect.Kind]defaultKind{
reflect.Int: {
Deserialize: func(data []byte)(interface{}, error){
if len(data) != 8 {
return nil, fmt.Errorf("invalid length: %d/8", len(data))
}
return int(binary.BigEndian.Uint64(data)), nil
},
Serialize: func(val interface{})([]byte, error){
i, ok := val.(int)
if ok == false {
return nil, fmt.Errorf("invalid type %+v", reflect.TypeOf(val))
} else {
bytes := make([]byte, 8)
binary.BigEndian.PutUint64(bytes, uint64(i))
return bytes, nil
}
},
},
}
type SerializedValue struct {
TypeStack []uint64
Data []byte
}
func (field SerializedValue) MarshalBinary() ([]byte, error) {
data := []byte{}
for _, t := range(field.TypeStack) {
t_ser := make([]byte, 8)
binary.BigEndian.PutUint64(t_ser, uint64(t))
data = append(data, t_ser...)
}
data = append(data, field.Data...)
return data, nil
}
func RecurseTypes(ctx *Context, t reflect.Type) ([]uint64, []reflect.Kind, error) {
var ctx_type uint64 = 0x00
ctype, exists := ctx.TypeReflects[t]
if exists == true {
ctx_type = uint64(ctype)
}
var new_types []uint64
var new_kinds []reflect.Kind
kind := t.Kind()
switch kind {
case reflect.Array:
if ctx_type == 0x00 {
ctx_type = uint64(ArrayType)
}
elem_types, elem_kinds, err := RecurseTypes(ctx, t.Elem())
if err != nil {
return nil, nil, err
}
new_types = append(new_types, ctx_type)
new_types = append(new_types, elem_types...)
new_kinds = append(new_kinds, reflect.Array)
new_kinds = append(new_kinds, elem_kinds...)
case reflect.Map:
if ctx_type == 0x00 {
ctx_type = uint64(MapType)
}
key_types, key_kinds, err := RecurseTypes(ctx, t.Key())
if err != nil {
return nil, nil, err
}
elem_types, elem_kinds, err := RecurseTypes(ctx, t.Elem())
if err != nil {
return nil, nil, err
}
new_types = append(new_types, ctx_type)
new_types = append(new_types, key_types...)
new_types = append(new_types, elem_types...)
new_kinds = append(new_kinds, reflect.Map)
new_kinds = append(new_kinds, key_kinds...)
new_kinds = append(new_kinds, elem_kinds...)
case reflect.Slice:
if ctx_type == 0x00 {
ctx_type = uint64(SliceType)
}
elem_types, elem_kinds, err := RecurseTypes(ctx, t.Elem())
if err != nil {
return nil, nil, err
}
new_types = append(new_types, ctx_type)
new_types = append(new_types, elem_types...)
new_kinds = append(new_kinds, reflect.Slice)
new_kinds = append(new_kinds, elem_kinds...)
case reflect.Pointer:
if ctx_type == 0x00 {
ctx_type = uint64(PointerType)
}
elem_types, elem_kinds, err := RecurseTypes(ctx, t.Elem())
if err != nil {
return nil, nil, err
}
new_types = append(new_types, ctx_type)
new_types = append(new_types, elem_types...)
new_kinds = append(new_kinds, reflect.Pointer)
new_kinds = append(new_kinds, elem_kinds...)
case reflect.String:
if ctx_type == 0x00 {
ctx_type = uint64(StringType)
}
new_types = append(new_types, ctx_type)
new_kinds = append(new_kinds, reflect.String)
default:
return nil, nil, fmt.Errorf("unhandled kind: %+v - %+v", kind, t)
}
return new_types, new_kinds, nil
}
func serializeValue(ctx *Context, kind_stack []reflect.Kind, value reflect.Value) ([]byte, error) {
kind := kind_stack[len(kind_stack) - 1]
switch kind {
default:
return nil, fmt.Errorf("unhandled kind: %+v", kind)
}
}
func SerializeValue(ctx *Context, value reflect.Value) (SerializedValue, error) {
if value.IsValid() == false {
return SerializedValue{}, fmt.Errorf("Cannot serialize invalid value: %+v", value)
}
type_stack, kind_stack, err := RecurseTypes(ctx, value.Type())
if err != nil {
return SerializedValue{}, err
}
bytes, err := serializeValue(ctx, kind_stack, value)
if err != nil {
return SerializedValue{}, err
}
return SerializedValue{
type_stack,
bytes,
}, nil
}
/*
default:
kind_def, handled := defaultKinds[kind]
if handled == false {
ctx_type, handled := ctx.TypeReflects[value.Type()]
if handled == false {
err = fmt.Errorf("%+v is not a handled reflect type", value.Type())
break
}
type_info, handled := ctx.Types[ctx_type]
if handled == false {
err = fmt.Errorf("%+v is not a handled reflect type(INTERNAL_ERROR)", value.Type())
break
}
field_ser, err := type_info.Serialize(ctx, value.Interface())
if err != nil {
err = fmt.Errorf(err.Error())
break
}
ret = SerializedValue{
[]uint64{uint64(ctx_type)},
field_ser,
}
}
field_ser, err := kind_def.Serialize(value.Interface())
if err != nil {
err = fmt.Errorf(err.Error())
} else {
ret = SerializedValue{
[]uint64{uint64(kind_def.Type)},
field_ser,
}
}
*/
func SerializeField(ctx *Context, ext Extension, field_name string) (SerializedValue, error) {
if ext == nil {
return SerializedValue{}, fmt.Errorf("Cannot get fields on nil Extension")
}
ext_value := reflect.ValueOf(ext).Elem()
field := ext_value.FieldByName(field_name)
if field.IsValid() == false {
return SerializedValue{}, fmt.Errorf("%s is not a field in %+v", field_name, ext)
} else {
return SerializeValue(ctx, field)
}
}
func SerializeSignal(ctx *Context, signal Signal, ctx_type SignalType) (SerializedValue, error) {
return SerializedValue{}, nil
}
func SerializeExtension(ctx *Context, ext Extension, ctx_type ExtType) (SerializedValue, error) {
if ext == nil {
return SerializedValue{}, fmt.Errorf("Cannot serialize nil Extension ")
}
ext_type := reflect.TypeOf(ext).Elem()
ext_value := reflect.ValueOf(ext).Elem()
m := map[string]SerializedValue{}
for _, field := range(reflect.VisibleFields(ext_type)) {
ext_tag, tagged_ext := field.Tag.Lookup("ext")
if tagged_ext == false {
continue
} else {
field_value := ext_value.FieldByIndex(field.Index)
var err error
m[ext_tag], err = SerializeValue(ctx, field_value)
if err != nil {
return SerializedValue{}, err
}
}
}
map_value := reflect.ValueOf(m)
map_ser, err := SerializeValue(ctx, map_value)
if err != nil {
return SerializedValue{}, err
}
return SerializedValue{
append([]uint64{uint64(ctx_type)}, map_ser.TypeStack...),
map_ser.Data,
}, nil
}
func DeserializeValue(ctx *Context, value SerializedValue) (interface{}, error) {
// TODO: do the opposite of SerializeValue.
// 1) Check the type to handle special types(array, list, map, pointer)
// 2) Check if the type is registered in the context, handle if so
// 3) Check if the type is a default type, handle if so
// 4) Return error if we don't know how to deserialize the type
return nil, fmt.Errorf("Undefined")
}
// Create a new Context with the base library content added
func NewContext(db * badger.DB, log Logger) (*Context, error) {
ctx := &Context{
DB: db,
Log: log,
Extensions: map[uint64]ExtensionInfo{},
Types: map[uint64]*NodeInfo{},
Signals: map[uint64]SignalInfo{},
Nodes: map[NodeID]*Node{},
Policies: map[PolicyType]reflect.Type{},
PolicyTypes: map[reflect.Type]PolicyType{},
Extensions: map[ExtType]ExtensionInfo{},
ExtensionTypes: map[reflect.Type]ExtType{},
Signals: map[SignalType]reflect.Type{},
SignalTypes: map[reflect.Type]SignalType{},
Nodes: map[NodeType]NodeInfo{},
nodeMap: map[NodeID]*Node{},
Types: map[SerializedType]TypeInfo{},
TypeReflects: map[reflect.Type]SerializedType{},
}
var err error
err = RegisterExtension[LockableExt,*LockableExt](ctx, nil)
err = ctx.RegisterExtension(reflect.TypeOf((*LockableExt)(nil)), LockableExtType, nil)
if err != nil {
return nil, err
}
err = RegisterExtension[ListenerExt,*ListenerExt](ctx, nil)
err = ctx.RegisterExtension(reflect.TypeOf((*ListenerExt)(nil)), ListenerExtType, nil)
if err != nil {
return nil, err
}
err = RegisterExtension[ECDHExt,*ECDHExt](ctx, nil)
err = ctx.RegisterExtension(reflect.TypeOf((*GroupExt)(nil)), GroupExtType, nil)
if err != nil {
return nil, err
}
err = RegisterExtension[GroupExt,*GroupExt](ctx, nil)
gql_ctx := NewGQLExtContext()
err = ctx.RegisterExtension(reflect.TypeOf((*GQLExt)(nil)), GQLExtType, gql_ctx)
if err != nil {
return nil, err
}
gql_ctx := NewGQLExtContext()
err = RegisterExtension[GQLExt,*GQLExt](ctx, gql_ctx)
err = ctx.RegisterSignal(reflect.TypeOf((*StopSignal)(nil)), StopSignalType)
if err != nil {
return nil, err
}
err = ctx.RegisterSignal(reflect.TypeOf((*CreateSignal)(nil)), CreateSignalType)
if err != nil {
return nil, err
}
err = RegisterSignal[BaseSignal, *BaseSignal](ctx, StopSignalType)
err = ctx.RegisterSignal(reflect.TypeOf((*StartSignal)(nil)), StartSignalType)
if err != nil {
return nil, err
}
err = RegisterSignal[BaseSignal, *BaseSignal](ctx, NewSignalType)
err = ctx.RegisterSignal(reflect.TypeOf((*ReadSignal)(nil)), ReadSignalType)
if err != nil {
return nil, err
}
err = RegisterSignal[BaseSignal, *BaseSignal](ctx, StartSignalType)
err = ctx.RegisterSignal(reflect.TypeOf((*ReadResultSignal)(nil)), ReadResultSignalType)
if err != nil {
return nil, err
}

@ -1,167 +0,0 @@
package graphvent
import (
"fmt"
"time"
"encoding/json"
"crypto/ecdsa"
"crypto/x509"
"crypto/ecdh"
)
type ECDHState struct {
ECKey *ecdh.PrivateKey
SharedSecret []byte
}
type ECDHStateJSON struct {
ECKey []byte `json:"ec_key"`
SharedSecret []byte `json:"shared_secret"`
}
func (state *ECDHState) MarshalJSON() ([]byte, error) {
var key_bytes []byte
var err error
if state.ECKey != nil {
key_bytes, err = x509.MarshalPKCS8PrivateKey(state.ECKey)
if err != nil {
return nil, err
}
}
return json.Marshal(&ECDHStateJSON{
ECKey: key_bytes,
SharedSecret: state.SharedSecret,
})
}
func (state *ECDHState) UnmarshalJSON(data []byte) error {
var j ECDHStateJSON
err := json.Unmarshal(data, &j)
if err != nil {
return err
}
state.SharedSecret = j.SharedSecret
if len(j.ECKey) == 0 {
state.ECKey = nil
} else {
tmp_key, err := x509.ParsePKCS8PrivateKey(j.ECKey)
if err != nil {
return err
}
ecdsa_key, ok := tmp_key.(*ecdsa.PrivateKey)
if ok == false {
return fmt.Errorf("Parsed wrong key type from DB for ECDHState")
}
state.ECKey, err = ecdsa_key.ECDH()
if err != nil {
return err
}
}
return nil
}
type ECDHMap map[NodeID]ECDHState
func (m ECDHMap) MarshalJSON() ([]byte, error) {
tmp := map[string]ECDHState{}
for id, state := range(m) {
tmp[id.String()] = state
}
return json.Marshal(tmp)
}
type ECDHExt struct {
ECDHStates ECDHMap
}
func NewECDHExt() *ECDHExt {
return &ECDHExt{
ECDHStates: ECDHMap{},
}
}
func ResolveFields[T Extension](t T, name string, field_funcs map[string]func(T)interface{})interface{} {
var zero T
field_func, ok := field_funcs[name]
if ok == false {
return fmt.Errorf("%s is not a field of %s", name, zero.Type())
}
return field_func(t)
}
func (ext *ECDHExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*ECDHExt)interface{}{
"ecdh_states": func(ext *ECDHExt) interface{} {
return ext.ECDHStates
},
})
}
func (ext *ECDHExt) HandleECDHSignal(log Logger, node *Node, signal *ECDHSignal) Messages {
source := KeyID(signal.EDDSA)
messages := Messages{}
switch signal.Str {
case "req":
state, exists := ext.ECDHStates[source]
if exists == false {
state = ECDHState{nil, nil}
}
resp, shared_secret, err := NewECDHRespSignal(node, signal)
if err == nil {
state.SharedSecret = shared_secret
ext.ECDHStates[source] = state
log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
messages = messages.Add(node.ID, node.Key, &resp, source)
} else {
log.Logf("ecdh", "ECDH_REQ_ERR: %s", err)
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), err.Error()), source)
}
case "resp":
state, exists := ext.ECDHStates[source]
if exists == false || state.ECKey == nil {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "no_req"), source)
} else {
err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW)
if err == nil {
shared_secret, err := state.ECKey.ECDH(signal.ECDH)
if err == nil {
state.SharedSecret = shared_secret
state.ECKey = nil
ext.ECDHStates[source] = state
log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
}
}
}
default:
log.Logf("ecdh", "unknown echd state %s", signal.Str)
}
return messages
}
func (ext *ECDHExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages {
switch signal.Type() {
case ECDHSignalType:
sig := signal.(*ECDHSignal)
return ext.HandleECDHSignal(ctx.Log, node, sig)
}
return nil
}
func (ext *ECDHExt) Type() ExtType {
return ECDHExtType
}
func (ext *ECDHExt) Serialize() ([]byte, error) {
return json.Marshal(ext)
}
func (ext *ECDHExt) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, &ext)
}

@ -1,6 +1,8 @@
module github.com/mekkanized/graphvent
go 1.20
go 1.21.0
replace github.com/mekkanized/graphvent/signal v0.0.0 => ./signal
require (
github.com/dgraph-io/badger/v3 v3.2103.5
@ -11,6 +13,7 @@ require (
)
require (
capnproto.org/go/capnp/v3 v3.0.0-alpha-29 // indirect
github.com/cespare/xxhash v1.1.0 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/dgraph-io/badger/v4 v4.1.0 // indirect
@ -28,8 +31,11 @@ require (
github.com/klauspost/compress v1.12.3 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/mekkanized/graphvent/signal v0.0.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
go.opencensus.io v0.22.5 // indirect
golang.org/x/net v0.7.0 // indirect
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 // indirect
golang.org/x/sys v0.6.0 // indirect
zenhack.net/go/util v0.0.0-20230414204917-531d38494cf5 // indirect
)

@ -1,3 +1,5 @@
capnproto.org/go/capnp/v3 v3.0.0-alpha-29 h1:ICLhiy4Jmp0d7hLQO+HzFAVIft/oxpPAUPV8tqx+eUE=
capnproto.org/go/capnp/v3 v3.0.0-alpha-29/go.mod h1:+ysMHvOh1EWNOyorxJWs1omhRFiDoKxKkWQACp54jKM=
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
@ -134,6 +136,7 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -171,3 +174,5 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
zenhack.net/go/util v0.0.0-20230414204917-531d38494cf5 h1:yksDCGMVzyn3vlyf0GZ3huiF5FFaMGQpQ3UJvR0EoGA=
zenhack.net/go/util v0.0.0-20230414204917-531d38494cf5/go.mod h1:1LtNdPAs8WH+BTcQiZAOo2MIKD/5jyK/u7sZ9ZPe5SE=

@ -205,7 +205,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ
if err != nil {
return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from id_bytes %+v", id_bytes)
}
auth_id := NodeID{auth_uuid}
auth_id := NodeID(auth_uuid)
key_bytes, err := base64.StdEncoding.DecodeString(key_b64)
if err != nil {
@ -234,7 +234,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ
Ext: gql_ext,
Chans: map[uuid.UUID]chan Signal{},
Context: ctx,
GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext),
GQLContext: ctx.Extensions[GQLExtType].Data.(*GQLExtContext),
Server: server,
User: key_id,
Key: key,
@ -270,7 +270,7 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW
query := GQLPayload{}
json.Unmarshal(str, &query)
gql_context := ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext)
gql_context := ctx.Extensions[GQLExtType].Data.(*GQLExtContext)
params := graphql.Params{
Schema: gql_context.Schema,
@ -401,7 +401,7 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon
}
} else if msg.Type == "subscribe" {
ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload)
gql_context := ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext)
gql_context := ctx.Extensions[GQLExtType].Data.(*GQLExtContext)
params := graphql.Params{
Schema: gql_context.Schema,
Context: req_ctx,
@ -543,7 +543,7 @@ func BuildSchema(ctx *GQLExtContext) (graphql.Schema, error) {
return graphql.NewSchema(schemaConfig)
}
func RegisterField[T any](ctx *GQLExtContext, gql_type graphql.Type, gql_name string, ext_type ExtType, acl_name string, resolve_fn func(graphql.ResolveParams, T)(interface{}, error)) error {
func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string, ext_type ExtType, acl_name string, resolve_fn func(graphql.ResolveParams, SerializedValue)(interface{}, error)) error {
if ctx == nil {
return fmt.Errorf("ctx is nil")
}
@ -561,21 +561,19 @@ func RegisterField[T any](ctx *GQLExtContext, gql_type graphql.Type, gql_name st
return ResolveNodeResult(p, func(p graphql.ResolveParams, result NodeResult) (interface{}, error) {
ext, exists := result.Result.Extensions[ext_type]
if exists == false {
return nil, fmt.Errorf("%s is not in the extensions of the result", ext_type)
return nil, fmt.Errorf("%+v is not in the extensions of the result", ext_type)
}
val_if, exists := ext[acl_name]
val_ser, exists := ext[acl_name]
if exists == false {
return nil, fmt.Errorf("%s is not in the fields of %s in the result", acl_name, ext_type)
return nil, fmt.Errorf("%s is not in the fields of %+v in the result", acl_name, ext_type)
}
var zero T
val, ok := val_if.(T)
if ok == false {
return nil, fmt.Errorf("%s.%s is not %s", ext_type, acl_name, reflect.TypeOf(zero))
if val_ser.TypeStack[0] == uint64(ErrorType) {
return nil, fmt.Errorf(string(val_ser.Data))
}
return resolve_fn(p, val)
return resolve_fn(p, val_ser)
})
}
@ -681,8 +679,8 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in
for field_name, field := range(self_fields) {
self_field := field
err := RegisterField(ctx, ctx_interface.Interface, field_name, self_field.Extension, self_field.ACLName,
func(p graphql.ResolveParams, val interface{})(interface{}, error) {
err := ctx.RegisterField(ctx_interface.Interface, field_name, self_field.Extension, self_field.ACLName,
func(p graphql.ResolveParams, val SerializedValue)(interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
@ -715,7 +713,7 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in
for field_name, field := range(list_fields) {
list_field := field
resolve_fn := func(p graphql.ResolveParams, val interface{})(interface{}, error) {
resolve_fn := func(p graphql.ResolveParams, val SerializedValue)(interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
@ -736,7 +734,7 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in
return nodes, nil
}
err := RegisterField(ctx, ctx_interface.List, field_name, list_field.Extension, list_field.ACLName, resolve_fn)
err := ctx.RegisterField(ctx_interface.List, field_name, list_field.Extension, list_field.ACLName, resolve_fn)
if err != nil {
return err
}
@ -764,7 +762,7 @@ func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, name string, inte
_, exists := ctx.NodeTypes[node_type]
if exists == true {
return fmt.Errorf("%s already in GQLExtContext.NodeTypes", node_type)
return fmt.Errorf("%+v already in GQLExtContext.NodeTypes", node_type)
}
node_interfaces, err := GQLInterfaces(ctx, interface_names)
@ -830,19 +828,20 @@ func NewGQLExtContext() *GQLExtContext {
panic(err)
}
err = RegisterField(&context, context.Interfaces["Node"].List, "Members", GroupExtType, "members",
func(p graphql.ResolveParams, val map[NodeID]string)(interface{}, error) {
err = context.RegisterField(context.Interfaces["Node"].List, "Members", GroupExtType, "members",
func(p graphql.ResolveParams, val SerializedValue)(interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
node_list := make([]NodeID, len(val))
/*node_list := make([]NodeID, len(val))
i := 0
for id, _ := range(val) {
node_list[i] = id
i += 1
}
*/
node_list := []NodeID{}
nodes, err := ResolveNodes(ctx, p, node_list)
if err != nil {
return nil, err
@ -895,7 +894,7 @@ func NewGQLExtContext() *GQLExtContext {
panic(err)
}
err = RegisterField(&context, graphql.String, "Listen", GQLExtType, "listen", func(p graphql.ResolveParams, listen string) (interface{}, error) {
err = context.RegisterField(graphql.String, "Listen", GQLExtType, "listen", func(p graphql.ResolveParams, listen SerializedValue) (interface{}, error) {
return listen, nil
})
if err != nil {
@ -1000,14 +999,6 @@ type GQLExt struct {
Listen string `json:"listen"`
}
func (ext *GQLExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*GQLExt)interface{}{
"listen": func(ext *GQLExt) interface{} {
return ext.Listen
},
})
}
func (ext *GQLExt) FindResponseChannel(req_id uuid.UUID) chan Signal {
ext.resolver_response_lock.RLock()
response_chan, _ := ext.resolver_response[req_id]
@ -1036,11 +1027,10 @@ func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal {
func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages {
// Process ReadResultSignalType by forwarding it to the waiting resolver
messages := Messages{}
if signal.Type() == ErrorSignalType {
switch sig := signal.(type) {
case *ErrorSignal:
// TODO: Forward to resolver if waiting for it
sig := signal.(*ErrorSignal)
response_chan := ext.FreeResponseChannel(sig.ReqID())
response_chan := ext.FreeResponseChannel(sig.Header().ReqID)
if response_chan != nil {
select {
case response_chan <- sig:
@ -1052,9 +1042,8 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
} else {
ctx.Log.Logf("gql", "received error signal response %+v with no mapped resolver", sig)
}
} else if signal.Type() == ReadResultSignalType {
sig := signal.(*ReadResultSignal)
response_chan := ext.FindResponseChannel(sig.ReqID())
case *ReadResultSignal:
response_chan := ext.FindResponseChannel(sig.ReqID)
if response_chan != nil {
select {
case response_chan <- sig:
@ -1065,23 +1054,23 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
} else {
ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig)
}
} else if signal.Type() == StartSignalType {
case *StartSignal:
ctx.Log.Logf("gql", "starting gql server %s", node.ID)
err := ext.StartGQLServer(ctx, node)
if err == nil {
node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID))
node.QueueSignal(time.Now(), NewStatusSignal(node.ID, "server_started"))
} else {
ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err)
}
}
return messages
return nil
}
func (ext *GQLExt) Type() ExtType {
return GQLExtType
}
func (ext *GQLExt) Serialize() ([]byte, error) {
func (ext *GQLExt) MarshalBinary() ([]byte, error) {
return json.Marshal(ext)
}

@ -36,7 +36,7 @@ func NodeInterfaceDefaultIsType(required_extensions []ExtType) func(graphql.IsTy
return false
}
node_type_def, exists := ctx.Context.Types[Hash(node.Result.NodeType)]
node_type_def, exists := ctx.Context.Nodes[node.Result.NodeType]
if exists == false {
return false
} else {
@ -73,7 +73,7 @@ func NodeInterfaceResolveType(required_extensions []ExtType, default_type **grap
gql_type, exists := ctx.GQLContext.NodeTypes[node.Result.NodeType]
ctx.Context.Log.Logf("gql", "GQL_INTERFACE_RESOLVE_TYPE(%+v): %+v - %t - %+v - %+v", node, gql_type, exists, required_extensions, *default_type)
if exists == false {
node_type_def, exists := ctx.Context.Types[Hash(node.Result.NodeType)]
node_type_def, exists := ctx.Context.Nodes[node.Result.NodeType]
if exists == false {
return nil
} else {

@ -51,16 +51,16 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
// Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error
read_signal := NewReadSignal(ext_fields)
msgs := Messages{}
msgs = msgs.Add(ctx.Server.ID, ctx.Key, read_signal, id)
msgs = msgs.Add(ctx.Context, ctx.Server.ID, ctx.Key, read_signal, id)
response_chan := ctx.Ext.GetResponseChannel(read_signal.ID())
resp_channels[read_signal.ID()] = response_chan
node_ids[read_signal.ID()] = id
response_chan := ctx.Ext.GetResponseChannel(read_signal.ID)
resp_channels[read_signal.ID] = response_chan
node_ids[read_signal.ID] = id
// TODO: Send all at once instead of createing Messages for each
err = ctx.Context.Send(msgs)
if err != nil {
ctx.Ext.FreeResponseChannel(read_signal.ID())
ctx.Ext.FreeResponseChannel(read_signal.ID)
return nil, err
}
}
@ -68,8 +68,8 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
responses := []NodeResult{}
for sig_id, response_chan := range(resp_channels) {
// Wait for the response, returning an error on timeout
response, err := WaitForSignal(response_chan, time.Millisecond*100, ReadResultSignalType, func(sig *ReadResultSignal)bool{
return sig.ReqID() == sig_id
response, err := WaitForSignal(response_chan, time.Millisecond*100, func(sig *ReadResultSignal)bool{
return sig.ReqID == sig_id
})
if err != nil {
return nil, err

@ -81,6 +81,6 @@ func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) {
func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) {
return ResolveNodeResult(p, func(p graphql.ResolveParams, node NodeResult) (interface{}, error) {
return Hash(node.Result.NodeType), nil
return uint64(node.Result.NodeType), nil
})
}

@ -21,7 +21,7 @@ import (
func TestGQLServer(t *testing.T) {
ctx := logTestContext(t, []string{"test", "gql", "policy", "pending"})
TestNodeType := NodeType("TEST")
TestNodeType := NewNodeType("TEST")
err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType})
fatalErr(t, err)
@ -30,33 +30,33 @@ func TestGQLServer(t *testing.T) {
gql_id := KeyID(pub)
group_policy_1 := NewAllNodesPolicy(Tree{
ReadSignalType.String(): Tree{
GroupExtType.String(): Tree{
"members": Tree{},
uint64(ReadSignalType): Tree{
uint64(GroupExtType): Tree{
Hash(FieldNameBase, "members"): Tree{},
},
},
ReadResultSignalType.String(): nil,
ErrorSignalType.String(): nil,
uint64(ReadResultSignalType): nil,
uint64(ErrorSignalType): nil,
})
group_policy_2 := NewMemberOfPolicy(map[NodeID]Tree{
gql_id: Tree{
LinkSignalType.String(): nil,
LockSignalType.String(): nil,
StatusSignalType.String(): nil,
ReadSignalType.String(): nil,
uint64(LinkSignalType): nil,
uint64(LockSignalType): nil,
uint64(StatusSignalType): nil,
uint64(ReadSignalType): nil,
},
})
user_policy_1 := NewAllNodesPolicy(Tree{
ReadResultSignalType.String(): nil,
ErrorSignalType.String(): nil,
uint64(ReadResultSignalType): nil,
uint64(ErrorSignalType): nil,
})
user_policy_2 := NewMemberOfPolicy(map[NodeID]Tree{
gql_id: Tree{
LinkSignalType.String(): nil,
ReadSignalType.String(): nil,
uint64(LinkSignalType): nil,
uint64(ReadSignalType): nil,
},
})
@ -80,8 +80,8 @@ func TestGQLServer(t *testing.T) {
ctx.Log.Logf("test", "GQL: %s", gql.ID)
ctx.Log.Logf("test", "NODE: %s", n1.ID)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "server_started"
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StatusSignal) bool {
return sig.Status == "server_started"
})
fatalErr(t, err)
@ -107,7 +107,9 @@ func TestGQLServer(t *testing.T) {
},
}
auth_username := base64.StdEncoding.EncodeToString(n1.ID.Serialize())
n1_id_bytes, err := n1.ID.MarshalBinary()
fatalErr(t, err)
auth_username := base64.StdEncoding.EncodeToString(n1_id_bytes)
key_bytes, err := x509.MarshalPKCS8PrivateKey(n1.Key)
fatalErr(t, err)
auth_password := base64.StdEncoding.EncodeToString(key_bytes)
@ -196,11 +198,11 @@ func TestGQLServer(t *testing.T) {
SubGQL(sub_1)
msgs := Messages{}
msgs = msgs.Add(gql.ID, gql.Key, &StopSignal, gql.ID)
msgs = msgs.Add(ctx, gql.ID, gql.Key, NewStopSignal(), gql.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "stopped"
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StatusSignal) bool {
return sig.Status == "stopped"
})
fatalErr(t, err)
}
@ -208,7 +210,7 @@ func TestGQLServer(t *testing.T) {
func TestGQLDB(t *testing.T) {
ctx := logTestContext(t, []string{"test"})
TestUserNodeType := NodeType("TEST_USER")
TestUserNodeType := NewNodeType("TEST_USER")
err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{})
fatalErr(t, err)
u1 := NewNode(ctx, nil, TestUserNodeType, 10, nil)
@ -225,33 +227,31 @@ func TestGQLDB(t *testing.T) {
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
msgs := Messages{}
msgs = msgs.Add(gql.ID, gql.Key, &StopSignal, gql.ID)
msgs = msgs.Add(ctx, gql.ID, gql.Key, NewStopSignal(), gql.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "stopped" && sig.NodeID == gql.ID
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StatusSignal) bool {
return sig.Status == "stopped" && sig.Source == gql.ID
})
fatalErr(t, err)
ser1, err := gql.Serialize()
ser2, err := u1.Serialize()
ser3, err := StopSignal.Serialize()
ser1, err := gql.Serialize(ctx)
ser2, err := u1.Serialize(ctx)
ctx.Log.Logf("test", "SER_1: \n%s\n\n", ser1)
ctx.Log.Logf("test", "SER_2: \n%s\n\n", ser2)
ctx.Log.Logf("test", "SER_3: \n%s\n\n", ser3)
// Clear all loaded nodes from the context so it loads them from the database
ctx.Nodes = map[NodeID]*Node{}
ctx.nodeMap = map[NodeID]*Node{}
gql_loaded, err := LoadNode(ctx, gql.ID)
fatalErr(t, err)
listener_ext, err = GetExt[*ListenerExt](gql_loaded)
listener_ext, err = GetExt[*ListenerExt](gql_loaded, GQLExtType)
fatalErr(t, err)
msgs = Messages{}
msgs = msgs.Add(gql_loaded.ID, gql_loaded.Key, &StopSignal, gql_loaded.ID)
msgs = msgs.Add(ctx, gql_loaded.ID, gql_loaded.Key, NewStopSignal(), gql_loaded.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StatusSignal) bool {
return sig.Status == "stopped" && sig.Source == gql_loaded.ID
})
fatalErr(t, err)
}

@ -16,10 +16,10 @@ func AddNodeFields(object *graphql.Object) {
})
}
func NewGQLNodeType(node_type NodeType, interfaces []*graphql.Interface, init func(*Type)) *Type {
func NewGQLNodeType(gql_name string, node_type NodeType, interfaces []*graphql.Interface, init func(*Type)) *Type {
var gql Type
gql.Type = graphql.NewObject(graphql.ObjectConfig{
Name: string(node_type),
Name: gql_name,
Interfaces: interfaces,
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
node, ok := p.Value.(NodeResult)

@ -6,7 +6,7 @@ import (
badger "github.com/dgraph-io/badger/v3"
)
const SimpleListenerNodeType = NodeType("SIMPLE_LISTENER")
var SimpleListenerNodeType = NewNodeType("SIMPLE_LISTENER")
func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) {
listener_extension := NewListenerExt(buffer)

@ -12,18 +12,10 @@ func (ext *GroupExt) Type() ExtType {
return GroupExtType
}
func (ext *GroupExt) Serialize() ([]byte, error) {
func (ext *GroupExt) MarshalBinary() ([]byte, error) {
return json.Marshal(ext)
}
func (ext *GroupExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*GroupExt)interface{}{
"members": func(ext *GroupExt) interface{} {
return ext.Members
},
})
}
func NewGroupExt(members map[NodeID]string) *GroupExt {
if members == nil {
members = map[NodeID]string{}

@ -18,19 +18,8 @@ func NewListenerExt(buffer int) *ListenerExt {
}
}
func (ext *ListenerExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*ListenerExt)interface{}{
"buffer": func(ext *ListenerExt) interface{} {
return ext.Buffer
},
"chan": func(ext *ListenerExt) interface{} {
return ext.Chan
},
})
}
// Simple load function, unmarshal the buffer int from json
func (ext *ListenerExt) Deserialize(ctx *Context, data []byte) error {
func (ext *ListenerExt) DeserializeListenerExt(ctx *Context, data []byte) error {
err := json.Unmarshal(data, &ext.Buffer)
ext.Chan = make(chan Signal, ext.Buffer)
return err
@ -51,6 +40,6 @@ func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal
return nil
}
func (ext *ListenerExt) Serialize() ([]byte, error) {
func (ext *ListenerExt) MarshalBinary() ([]byte, error) {
return json.Marshal(ext.Buffer)
}

@ -1,7 +1,6 @@
package graphvent
import (
"encoding/binary"
"github.com/google/uuid"
)
@ -15,119 +14,17 @@ const (
)
type LockableExt struct{
State ReqState
ReqID uuid.UUID
Owner *NodeID
PendingOwner *NodeID
Requirements map[NodeID]ReqState
}
func (ext *LockableExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*LockableExt)interface{}{
"owner": func(ext *LockableExt) interface{} {
return ext.Owner
},
"pending_owner": func(ext *LockableExt) interface{} {
return ext.PendingOwner
},
"requirements": func(ext *LockableExt) interface{} {
return ext.Requirements
},
})
State ReqState `ext:""`
ReqID *uuid.UUID `ext:""`
Owner *NodeID `ext:""`
PendingOwner *NodeID `ext:""`
Requirements map[NodeID]ReqState `ext:""`
}
func (ext *LockableExt) Type() ExtType {
return LockableExtType
}
func (ext *LockableExt) Serialize() ([]byte, error) {
ret := make([]byte, 9 + (16 * 2) + (17 * len(ext.Requirements)))
if ext.Owner != nil {
bytes, err := ext.Owner.MarshalBinary()
if err != nil {
return nil, err
}
copy(ret[0:16], bytes)
}
if ext.PendingOwner != nil {
bytes, err := ext.PendingOwner.MarshalBinary()
if err != nil {
return nil, err
}
copy(ret[16:32], bytes)
}
binary.BigEndian.PutUint64(ret[32:40], uint64(len(ext.Requirements)))
ret[40] = byte(ext.State)
cur := 41
for req, state := range(ext.Requirements) {
bytes, err := req.MarshalBinary()
if err != nil {
return nil, err
}
copy(ret[cur:cur+16], bytes)
ret[cur+16] = byte(state)
cur += 17
}
return ret, nil
}
func (ext *LockableExt) Deserialize(ctx *Context, data []byte) error {
cur := 0
all_zero := true
for _, b := range(data[cur:cur+16]) {
if all_zero == true && b != 0x00 {
all_zero = false
}
}
if all_zero == false {
tmp, err := IDFromBytes(data[cur:cur+16])
if err != nil {
return err
}
ext.Owner = &tmp
}
cur += 16
all_zero = true
for _, b := range(data[cur:cur+16]) {
if all_zero == true && b != 0x00 {
all_zero = false
}
}
if all_zero == false {
tmp, err := IDFromBytes(data[cur:cur+16])
if err != nil {
return err
}
ext.PendingOwner = &tmp
}
cur += 16
num_requirements := int(binary.BigEndian.Uint64(data[cur:cur+8]))
cur += 8
ext.State = ReqState(data[cur])
cur += 1
if num_requirements != 0 {
ext.Requirements = map[NodeID]ReqState{}
}
for i := 0; i < num_requirements; i++ {
id, err := IDFromBytes(data[cur:cur+16])
if err != nil {
return err
}
cur += 16
state := ReqState(data[cur])
cur += 1
ext.Requirements[id] = state
}
return nil
}
func NewLockableExt(requirements []NodeID) *LockableExt {
var reqs map[NodeID]ReqState = nil
if requirements != nil {
@ -148,21 +45,21 @@ func NewLockableExt(requirements []NodeID) *LockableExt {
func UnlockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) {
msgs := Messages{}
signal := NewLockSignal("unlock")
msgs = msgs.Add(owner.ID, owner.Key, signal, target)
return signal.ID(), ctx.Send(msgs)
msgs = msgs.Add(ctx, owner.ID, owner.Key, signal, target)
return signal.Header().ID, ctx.Send(msgs)
}
// Send the signal to lock a node from itself
func LockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) {
msgs := Messages{}
signal := NewLockSignal("lock")
msgs = msgs.Add(owner.ID, owner.Key, signal, target)
return signal.ID(), ctx.Send(msgs)
msgs = msgs.Add(ctx, owner.ID, owner.Key, signal, target)
return signal.Header().ID, ctx.Send(msgs)
}
func (ext *LockableExt) HandleErrorSignal(log Logger, node *Node, source NodeID, signal *ErrorSignal) Messages {
func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) Messages {
str := signal.Error
log.Logf("lockable", "ERROR_SIGNAL: %s->%s %+v", source, node.ID, str)
ctx.Log.Logf("lockable", "ERROR_SIGNAL: %s->%s %+v", source, node.ID, str)
msgs := Messages {}
switch str {
@ -173,7 +70,7 @@ func (ext *LockableExt) HandleErrorSignal(log Logger, node *Node, source NodeID,
for id, state := range(ext.Requirements) {
if state == Locked {
ext.Requirements[id] = Unlocking
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), id)
msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlock"), id)
}
}
}
@ -185,51 +82,48 @@ func (ext *LockableExt) HandleErrorSignal(log Logger, node *Node, source NodeID,
return msgs
}
func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID, signal *IDStringSignal) Messages {
id := signal.NodeID
action := signal.Str
func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID, signal *LinkSignal) Messages {
msgs := Messages {}
if ext.State == Unlocked {
switch action {
switch signal.Action {
case "add":
_, exists := ext.Requirements[id]
_, exists := ext.Requirements[signal.NodeID]
if exists == true {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_requirement"), source)
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "already_requirement"), source)
} else {
if ext.Requirements == nil {
ext.Requirements = map[NodeID]ReqState{}
}
ext.Requirements[id] = Unlocked
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "req_added"), source)
ext.Requirements[signal.NodeID] = Unlocked
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "req_added"), source)
}
case "remove":
_, exists := ext.Requirements[id]
_, exists := ext.Requirements[signal.NodeID]
if exists == false {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_requirement"), source)
} else {
delete(ext.Requirements, id)
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "req_removed"), source)
delete(ext.Requirements, signal.NodeID)
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "req_removed"), source)
}
default:
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "unknown_action"), source)
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "unknown_action"), source)
}
} else {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocked"), source)
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_unlocked"), source)
}
return msgs
}
// Handle a LockSignal and update the extensions owner/requirement states
func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) Messages {
state := signal.Str
log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, state)
func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) Messages {
ctx.Log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal.State)
msgs := Messages{}
switch state {
switch signal.State {
case "locked":
state, found := ext.Requirements[source]
if found == false {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_requirement"), source)
} else if state == Locking {
if ext.State == Locking {
ext.Requirements[source] = Locked
@ -245,19 +139,19 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
if locked == reqs {
ext.State = Locked
ext.Owner = ext.PendingOwner
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("locked"), *ext.Owner)
msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("locked"), *ext.Owner)
} else {
log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, reqs)
ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, reqs)
}
} else if ext.State == AbortingLock {
ext.Requirements[source] = Unlocking
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), source)
msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlock"), source)
}
}
case "unlocked":
state, found := ext.Requirements[source]
if found == false {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_requirement"), source)
} else if state == Unlocking {
ext.Requirements[source] = Unlocked
reqs := 0
@ -274,13 +168,14 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
ext.State = Unlocked
if old_state == Unlocking {
ext.Owner = ext.PendingOwner
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlocked"), *ext.Owner)
ext.ReqID = nil
msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlocked"), *ext.Owner)
} else if old_state == AbortingLock {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(ext.ReqID, "not_unlocked"), *ext.PendingOwner)
msgs = msgs.Add(ctx ,node.ID, node.Key, NewErrorSignal(*ext.ReqID, "not_unlocked"), *ext.PendingOwner)
ext.PendingOwner = ext.Owner
}
} else {
log.Logf("lockable", "PARTIAL UNLOCK: %s - %d/%d", node.ID, unlocked, reqs)
ctx.Log.Logf("lockable", "PARTIAL UNLOCK: %s - %d/%d", node.ID, unlocked, reqs)
}
}
case "lock":
@ -290,23 +185,24 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
new_owner := source
ext.PendingOwner = &new_owner
ext.Owner = &new_owner
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("locked"), new_owner)
msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("locked"), new_owner)
} else {
ext.State = Locking
ext.ReqID = signal.ID()
id := signal.ID
ext.ReqID = &id
new_owner := source
ext.PendingOwner = &new_owner
for id, state := range(ext.Requirements) {
if state != Unlocked {
log.Logf("lockable", "REQ_NOT_UNLOCKED_WHEN_LOCKING")
ctx.Log.Logf("lockable", "REQ_NOT_UNLOCKED_WHEN_LOCKING")
}
ext.Requirements[id] = Locking
lock_signal := NewLockSignal("lock")
msgs = msgs.Add(node.ID, node.Key, lock_signal, id)
msgs = msgs.Add(ctx, node.ID, node.Key, lock_signal, id)
}
}
} else {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocked"), source)
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_unlocked"), source)
}
case "unlock":
if ext.State == Locked {
@ -315,25 +211,26 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
new_owner := source
ext.PendingOwner = nil
ext.Owner = nil
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlocked"), new_owner)
msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlocked"), new_owner)
} else if source == *ext.Owner {
ext.State = Unlocking
ext.ReqID = signal.ID()
id := signal.ID
ext.ReqID = &id
ext.PendingOwner = nil
for id, state := range(ext.Requirements) {
if state != Locked {
log.Logf("lockable", "REQ_NOT_LOCKED_WHEN_UNLOCKING")
ctx.Log.Logf("lockable", "REQ_NOT_LOCKED_WHEN_UNLOCKING")
}
ext.Requirements[id] = Unlocking
lock_signal := NewLockSignal("unlock")
msgs = msgs.Add(node.ID, node.Key, lock_signal, id)
msgs = msgs.Add(ctx, node.ID, node.Key, lock_signal, id)
}
}
} else {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locked"), source)
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_locked"), source)
}
default:
log.Logf("lockable", "LOCK_ERR: unkown state %s", state)
ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", signal.State)
}
return msgs
}
@ -342,25 +239,25 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
// LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state
func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages {
messages := Messages{}
switch signal.Direction() {
switch signal.Header().Direction {
case Up:
if ext.Owner != nil {
if *ext.Owner != node.ID {
messages = messages.Add(node.ID, node.Key, signal, *ext.Owner)
messages = messages.Add(ctx, node.ID, node.Key, signal, *ext.Owner)
}
}
case Down:
for requirement, _ := range(ext.Requirements) {
messages = messages.Add(node.ID, node.Key, signal, requirement)
messages = messages.Add(ctx, node.ID, node.Key, signal, requirement)
}
case Direct:
switch signal.Type() {
case LinkSignalType:
messages = ext.HandleLinkSignal(ctx.Log, node, source, signal.(*IDStringSignal))
case LockSignalType:
messages = ext.HandleLockSignal(ctx.Log, node, source, signal.(*StringSignal))
case ErrorSignalType:
messages = ext.HandleErrorSignal(ctx.Log, node, source, signal.(*ErrorSignal))
switch sig := signal.(type) {
case *LinkSignal:
messages = ext.HandleLinkSignal(ctx, node, source, sig)
case *LockSignal:
messages = ext.HandleLockSignal(ctx, node, source, sig)
case *ErrorSignal:
messages = ext.HandleErrorSignal(ctx, node, source, sig)
default:
}
default:

@ -7,7 +7,7 @@ import (
"crypto/rand"
)
const TestLockableType = NodeType("TEST_LOCKABLE")
var TestLockableType = NewNodeType("TEST_LOCKABLE")
func lockableTestContext(t *testing.T, logs []string) *Context {
ctx := logTestContext(t, logs)
@ -43,57 +43,24 @@ func TestLink(t *testing.T) {
)
msgs := Messages{}
msgs = msgs.Add(l1.ID, l1.Key, NewLinkSignal("add", l2.ID), l1.ID)
msgs = msgs.Add(ctx, l1.ID, l1.Key, NewLinkSignal("add", l2.ID), l1.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, ErrorSignalType, func(sig *ErrorSignal) bool {
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, func(sig *ErrorSignal) bool {
return sig.Error == "req_added"
})
fatalErr(t, err)
msgs = Messages{}
s := NewBaseSignal("TEST", Down)
msgs = msgs.Add(l1.ID, l1.Key, &s, l1.ID)
msgs = msgs.Add(ctx, l1.ID, l1.Key, NewLinkSignal("remove", l2.ID), l1.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, "TEST", func(sig *BaseSignal) bool {
return sig.ID() == s.ID()
})
fatalErr(t, err)
_, err = WaitForSignal(l2_listener.Chan, time.Millisecond*10, "TEST", func(sig *BaseSignal) bool {
return sig.ID() == s.ID()
})
fatalErr(t, err)
msgs = Messages{}
msgs = msgs.Add(l1.ID, l1.Key, NewLinkSignal("remove", l2.ID), l1.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, ErrorSignalType, func(sig *ErrorSignal) bool {
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, func(sig *ErrorSignal) bool {
return sig.Error == "req_removed"
})
fatalErr(t, err)
msgs = Messages{}
s = NewBaseSignal("TEST", Down)
msgs = msgs.Add(l1.ID, l1.Key, &s, l1.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, "TEST", func(sig *BaseSignal) bool {
return sig.ID() == s.ID()
})
fatalErr(t, err)
select {
case <- l2_listener.Chan:
t.Fatal("Recevied message on l2 after removing link")
default:
}
}
func Test10KLink(t *testing.T) {
@ -104,7 +71,7 @@ func Test10KLink(t *testing.T) {
listener_id := KeyID(l_pub)
child_policy := NewPerNodePolicy(map[NodeID]Tree{
listener_id: Tree{
LockSignalType.String(): nil,
uint64(LockSignalType): nil,
},
})
NewLockable := func()(*Node) {
@ -125,7 +92,7 @@ func Test10KLink(t *testing.T) {
ctx.Log.Logf("test", "CREATED_10K")
l_policy := NewAllNodesPolicy(Tree{
LockSignalType.String(): nil,
uint64(LockSignalType): nil,
})
listener := NewListenerExt(100000)
node := NewNode(ctx, listener_key, TestLockableType, 10000,
@ -140,14 +107,14 @@ func Test10KLink(t *testing.T) {
_, err = LockLockable(ctx, node, node.ID)
fatalErr(t, err)
_, err = WaitForSignal(listener.Chan, time.Millisecond*1000, LockSignalType, func(sig *StringSignal) bool {
return sig.Str == "locked"
_, err = WaitForSignal(listener.Chan, time.Millisecond*1000, func(sig *LockSignal) bool {
return sig.State == "locked"
})
fatalErr(t, err)
for _, _ = range(reqs) {
_, err := WaitForSignal(listener.Chan, time.Millisecond*100, LockSignalType, func(sig *StringSignal) bool {
return sig.Str == "locked"
_, err := WaitForSignal(listener.Chan, time.Millisecond*100, func(sig *LockSignal) bool {
return sig.State == "locked"
})
fatalErr(t, err)
}
@ -178,36 +145,36 @@ func TestLock(t *testing.T) {
l0, l0_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID})
l1, l1_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID})
locked := func(sig *StringSignal) bool {
return sig.Str == "locked"
locked := func(sig *LockSignal) bool {
return sig.State == "locked"
}
unlocked := func(sig *StringSignal) bool {
return sig.Str == "unlocked"
unlocked := func(sig *LockSignal) bool {
return sig.State == "unlocked"
}
_, err := LockLockable(ctx, l0, l5.ID)
fatalErr(t, err)
_, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, LockSignalType, locked)
_, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, locked)
fatalErr(t, err)
id, err := LockLockable(ctx, l1, l1.ID)
fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, ErrorSignalType, func(sig *ErrorSignal) bool {
return sig.Error == "not_unlocked" && sig.ReqID() == id
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, func(sig *ErrorSignal) bool {
return sig.Error == "not_unlocked" && sig.Header().ReqID == id
})
fatalErr(t, err)
_, err = UnlockLockable(ctx, l0, l5.ID)
fatalErr(t, err)
_, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, LockSignalType, unlocked)
_, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, unlocked)
fatalErr(t, err)
_, err = LockLockable(ctx, l1, l1.ID)
fatalErr(t, err)
for i := 0; i < 4; i++ {
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, func(sig *StringSignal) bool {
return sig.Str == "locked"
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, func(sig *LockSignal) bool {
return sig.State == "locked"
})
fatalErr(t, err)
}

@ -29,22 +29,20 @@ const (
var (
// Base NodeID, used as a special value
ZeroUUID = uuid.UUID{}
ZeroID = NodeID{ZeroUUID}
ZeroID = NodeID(ZeroUUID)
)
// A NodeID uniquely identifies a Node
type NodeID struct {
uuid.UUID
type NodeID uuid.UUID
func (id NodeID) MarshalBinary() ([]byte, error) {
return (uuid.UUID)(id).MarshalBinary()
}
func (id NodeID) Serialize() []byte {
ser, _ := id.MarshalBinary()
return ser
func (id NodeID) String() string {
return (uuid.UUID)(id).String()
}
func IDFromBytes(bytes []byte) (NodeID, error) {
id, err := uuid.FromBytes(bytes[:])
return NodeID{id}, err
id, err := uuid.FromBytes(bytes)
return NodeID(id), err
}
// Parse an ID from a string
@ -53,26 +51,17 @@ func ParseID(str string) (NodeID, error) {
if err != nil {
return NodeID{}, err
}
return NodeID{id_uuid}, nil
return NodeID(id_uuid), nil
}
// Generate a random NodeID
func RandID() NodeID {
return NodeID{uuid.New()}
}
// A Serializable has a type that can be used to map to it, and a function to serialize` the current state
type Serializable[I comparable] interface {
Serialize()([]byte,error)
Deserialize(*Context,[]byte)error
Type() I
return NodeID(uuid.New())
}
// Extensions are data attached to nodes that process signals
type Extension interface {
Serializable[ExtType]
Field(string)interface{}
Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages
Process(*Context, *Node, NodeID, Signal) Messages
}
// A QueuedSignal is a Signal that has been Queued to trigger at a set time
@ -130,10 +119,10 @@ const (
Pending
)
func (node *Node) Allows(principal_id NodeID, action Tree)(map[PolicyType]Messages, RuleResult) {
func (node *Node) Allows(ctx *Context, principal_id NodeID, action Tree)(map[PolicyType]Messages, RuleResult) {
pends := map[PolicyType]Messages{}
for policy_type, policy := range(node.Policies) {
msgs, resp := policy.Allows(principal_id, action, node)
msgs, resp := policy.Allows(ctx, principal_id, action, node)
if resp == Allow {
return nil, Allow
} else if resp == Pending {
@ -154,7 +143,7 @@ func (node *Node) QueueSignal(time time.Time, signal Signal) {
func (node *Node) DequeueSignal(id uuid.UUID) error {
idx := -1
for i, q := range(node.SignalQueue) {
if q.Signal.ID() == id {
if q.Signal.Header().ID == id {
idx = i
break
}
@ -202,16 +191,43 @@ func runNode(ctx *Context, node *Node) {
ctx.Log.Logf("node", "RUN_STOP: %s", node.ID)
}
func (node *Node) ReadFields(reqs map[ExtType][]string)map[ExtType]map[string]interface{} {
exts := map[ExtType]map[string]interface{}{}
type StringError string
func (err StringError) String() string {
return string(err)
}
func (err StringError) Error() string {
return err.String()
}
func (err StringError) MarshalBinary() ([]byte, error) {
return []byte(string(err)), nil
}
func NewErrorField(fstring string, args ...interface{}) SerializedValue {
str := StringError(fmt.Sprintf(fstring, args...))
str_ser, err := str.MarshalBinary()
if err != nil {
panic(err)
}
return SerializedValue{
TypeStack: []uint64{uint64(ErrorType)},
Data: str_ser,
}
}
func (node *Node) ReadFields(ctx *Context, reqs map[ExtType][]string)map[ExtType]map[string]SerializedValue {
exts := map[ExtType]map[string]SerializedValue{}
for ext_type, field_reqs := range(reqs) {
fields := map[string]interface{}{}
fields := map[string]SerializedValue{}
for _, req := range(field_reqs) {
ext, exists := node.Extensions[ext_type]
if exists == false {
fields[req] = fmt.Errorf("%s does not have %s extension", node.ID, ext_type)
fields[req] = NewErrorField("%+v does not have %+v extension", node.ID, ext_type)
} else {
fields[req] = ext.Field(req)
f, err := SerializeField(ctx, ext, req)
if err != nil {
fields[req] = NewErrorField(err.Error())
} else {
fields[req] = f
}
}
}
exts[ext_type] = fields
@ -227,21 +243,40 @@ func nodeLoop(ctx *Context, node *Node) error {
}
// Perform startup actions
node.Process(ctx, ZeroID, &StartSignal)
for true {
node.Process(ctx, ZeroID, NewStartSignal())
run := true
for run == true {
var signal Signal
var source NodeID
select {
case msg := <- node.MsgChan:
ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal)
ser, err := msg.Signal.Serialize()
signal_type, exists := ctx.SignalTypes[reflect.TypeOf(msg.Signal).Elem()]
if exists == false {
ctx.Log.Logf("signal", "SIGNAL_NOT_REGISTERED: %+v", reflect.TypeOf(msg.Signal).Elem())
}
signal_ser, err := SerializeSignal(ctx, signal, signal_type)
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", node.ID, msg.Signal)
ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", err, msg.Signal)
}
ser, err := signal_ser.MarshalBinary()
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", err, signal_ser)
continue
}
sig_data := append(msg.Dest.Serialize(), msg.Source.Serialize()...)
dst_id_ser, err := msg.Dest.MarshalBinary()
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_DEST_ID_SER_ERR: %e", err)
continue
}
src_id_ser, err := msg.Source.MarshalBinary()
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_SRC_ID_SER_ERR: %e", err)
continue
}
sig_data := append(dst_id_ser, src_id_ser...)
sig_data = append(sig_data, ser...)
validated := ed25519.Verify(msg.Principal, sig_data, msg.Signature)
if validated == false {
@ -251,26 +286,26 @@ func nodeLoop(ctx *Context, node *Node) error {
princ_id := KeyID(msg.Principal)
if princ_id != node.ID {
pends, resp := node.Allows(princ_id, msg.Signal.Permission())
pends, resp := node.Allows(ctx, princ_id, msg.Signal.Permission())
if resp == Deny {
ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %s", princ_id, node.ID, msg.Signal.Permission())
ctx.Log.Logf("policy", "SIGNAL_POLICY_SOURCE: %s", msg.Source)
msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(msg.Signal.ID(), "acl denied"), msg.Source)
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(msg.Signal.Header().ID, "acl denied"), msg.Source)
ctx.Send(msgs)
continue
} else if resp == Pending {
ctx.Log.Logf("policy", "SIGNAL_POLICY_PENDING: %s->%s - %s - %+v", princ_id, node.ID, msg.Signal.Permission(), pends)
timeout_signal := NewACLTimeoutSignal(msg.Signal.ID())
timeout_signal := NewACLTimeoutSignal(msg.Signal.Header().ID)
node.QueueSignal(time.Now().Add(100*time.Millisecond), timeout_signal)
msgs := Messages{}
for policy_type, sigs := range(pends) {
for _, m := range(sigs) {
msgs = append(msgs, m)
node.PendingSignals[m.Signal.ID()] = PendingSignal{policy_type, false, msg.Signal.ID()}
node.PendingSignals[m.Signal.Header().ID] = PendingSignal{policy_type, false, msg.Signal.Header().ID}
}
}
node.PendingACLs[msg.Signal.ID()] = PendingACL{len(msgs), timeout_signal.ID(), msg.Signal.Permission(), princ_id, msgs, []Signal{}, msg.Signal, msg.Source}
node.PendingACLs[msg.Signal.Header().ID] = PendingACL{len(msgs), timeout_signal.ID, msg.Signal.Permission(), princ_id, msgs, []Signal{}, msg.Signal, msg.Source}
ctx.Send(msgs)
continue
} else if resp == Allow {
@ -290,7 +325,7 @@ func nodeLoop(ctx *Context, node *Node) error {
t := node.NextSignal.Time
i := -1
for j, queued := range(node.SignalQueue) {
if queued.Signal.ID() == node.NextSignal.Signal.ID() {
if queued.Signal.Header().ID == node.NextSignal.Signal.Header().ID {
i = j
break
}
@ -304,26 +339,26 @@ func nodeLoop(ctx *Context, node *Node) error {
node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue)
if node.NextSignal == nil {
ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL nil@%+v", node.ID, signal.Type(), t, node.TimeoutChan)
ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL nil@%+v", node.ID, signal, t, node.TimeoutChan)
} else {
ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL: %s@%s", node.ID, signal.Type(), t, node.NextSignal, node.NextSignal.Time)
ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL: %s@%s", node.ID, signal, t, node.NextSignal, node.NextSignal.Time)
}
}
ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue)
info, waiting := node.PendingSignals[signal.ReqID()]
info, waiting := node.PendingSignals[signal.Header().ReqID]
if waiting == true {
if info.Found == false {
info.Found = true
node.PendingSignals[signal.ReqID()] = info
node.PendingSignals[signal.Header().ReqID] = info
ctx.Log.Logf("pending", "FOUND_PENDING_SIGNAL: %s - %s", node.ID, signal)
req_info, exists := node.PendingACLs[info.ID]
if exists == true {
req_info.Counter -= 1
req_info.Responses = append(req_info.Responses, signal)
allowed := node.Policies[info.Policy].ContinueAllows(req_info, signal)
allowed := node.Policies[info.Policy].ContinueAllows(ctx, req_info, signal)
if allowed == Allow {
ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal)
signal = req_info.Signal
@ -337,7 +372,7 @@ func nodeLoop(ctx *Context, node *Node) error {
ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal)
// Send the denied response
msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(req_info.Signal.ID(), "ACL_DENIED"), req_info.Source)
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(req_info.Signal.Header().ID, "ACL_DENIED"), req_info.Source)
err := ctx.Send(msgs)
if err != nil {
ctx.Log.Logf("signal", "SEND_ERR: %s", err)
@ -355,25 +390,20 @@ func nodeLoop(ctx *Context, node *Node) error {
}
}
// Handle node signals
if signal.Type() == StopSignalType {
switch sig := signal.(type) {
case *StopSignal:
msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "stopped"), source)
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.ID, "stopped"), source)
ctx.Send(msgs)
node.Process(ctx, node.ID, NewStatusSignal("stopped", node.ID))
break
} else if signal.Type() == ReadSignalType {
read_signal, ok := signal.(*ReadSignal)
if ok == false {
ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal)
} else {
result := node.ReadFields(read_signal.Extensions)
node.Process(ctx, node.ID, NewStatusSignal(node.ID, "stopped"))
run = false
case *ReadSignal:
result := node.ReadFields(ctx, sig.Extensions)
msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewReadResultSignal(read_signal.ID(), node.ID, node.Type, result), source)
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(read_signal.ID(), "read_done"), source)
msgs = msgs.Add(ctx, node.ID, node.Key, NewReadResultSignal(sig.ID, node.ID, node.Type, result), source)
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.ID, "read_done"), source)
ctx.Send(msgs)
}
}
node.Process(ctx, source, signal)
// assume that processing a signal means that this nodes state changed
@ -401,8 +431,8 @@ type Message struct {
}
type Messages []*Message
func (msgs Messages) Add(source NodeID, principal ed25519.PrivateKey, signal Signal, dest NodeID) Messages {
msg, err := NewMessage(dest, source, principal, signal)
func (msgs Messages) Add(ctx *Context, source NodeID, principal ed25519.PrivateKey, signal Signal, dest NodeID) Messages {
msg, err := NewMessage(ctx, dest, source, principal, signal)
if err != nil {
panic(err)
} else {
@ -411,13 +441,31 @@ func (msgs Messages) Add(source NodeID, principal ed25519.PrivateKey, signal Sig
return msgs
}
func NewMessage(dest NodeID, source NodeID, principal ed25519.PrivateKey, signal Signal) (*Message, error) {
ser, err := signal.Serialize()
func NewMessage(ctx *Context, dest NodeID, source NodeID, principal ed25519.PrivateKey, signal Signal) (*Message, error) {
signal_type, exists := ctx.SignalTypes[reflect.TypeOf(signal)]
if exists == false {
return nil, fmt.Errorf("Cannot put %+v in a message, not a known signal type", reflect.TypeOf(signal))
}
signal_ser, err := SerializeSignal(ctx, signal, signal_type)
if err != nil {
return nil, err
}
ser, err := signal_ser.MarshalBinary()
if err != nil {
return nil, err
}
sig_data := append(dest.Serialize(), source.Serialize()...)
dest_ser, err := dest.MarshalBinary()
if err != nil {
return nil, err
}
source_ser, err := source.MarshalBinary()
if err != nil {
return nil, err
}
sig_data := append(dest_ser, source_ser...)
sig_data = append(sig_data, ser...)
sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0))
@ -435,7 +483,7 @@ func NewMessage(dest NodeID, source NodeID, principal ed25519.PrivateKey, signal
}
func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error {
ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, signal.Type())
ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, signal)
messages := Messages{}
for ext_type, ext := range(node.Extensions) {
ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type)
@ -449,120 +497,86 @@ func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error {
return ctx.Send(messages)
}
func GetCtx[T Extension, C any](ctx *Context) (C, error) {
var zero T
func GetCtx[T Extension, C any](ctx *Context, ext_type ExtType) (C, error) {
var zero_ctx C
ext_type := zero.Type()
type_hash := Hash(ext_type)
ext_info, ok := ctx.Extensions[type_hash]
ext_info, ok := ctx.Extensions[ext_type]
if ok == false {
return zero_ctx, fmt.Errorf("%s is not an extension in ctx", ext_type)
return zero_ctx, fmt.Errorf("%+v is not an extension in ctx", ext_type)
}
ext_ctx, ok := ext_info.Data.(C)
if ok == false {
return zero_ctx, fmt.Errorf("context for %s is %+v, not %+v", ext_type, reflect.TypeOf(ext_info.Data), reflect.TypeOf(zero))
return zero_ctx, fmt.Errorf("context for %+v is %+v, not %+v", ext_type, reflect.TypeOf(ext_info.Data), reflect.TypeOf(zero_ctx))
}
return ext_ctx, nil
}
func GetExt[T Extension](node *Node) (T, error) {
func GetExt[T Extension](node *Node, ext_type ExtType) (T, error) {
var zero T
ext_type := zero.Type()
ext, exists := node.Extensions[ext_type]
if exists == false {
return zero, fmt.Errorf("%s does not have %s extension - %+v", node.ID, ext_type, node.Extensions)
return zero, fmt.Errorf("%+v does not have %+v extension - %+v", node.ID, ext_type, node.Extensions)
}
ret, ok := ext.(T)
if ok == false {
return zero, fmt.Errorf("%s in %s is wrong type(%+v), expecting %+v", ext_type, node.ID, reflect.TypeOf(ext), reflect.TypeOf(zero))
return zero, fmt.Errorf("%+v in %+v is wrong type(%+v), expecting %+v", ext_type, node.ID, reflect.TypeOf(ext), reflect.TypeOf(zero))
}
return ret, nil
}
func (node *Node) Serialize() ([]byte, error) {
extensions := make([]ExtensionDB, len(node.Extensions))
qsignals := make([]QSignalDB, len(node.SignalQueue))
policies := make([]PolicyDB, len(node.Policies))
func (node *Node) Serialize(ctx *Context) (SerializedValue, error) {
if node == nil {
return SerializedValue{}, fmt.Errorf("Cannot serialize nil Node")
}
node_bytes := make([]byte, 8 * 3)
binary.BigEndian.PutUint64(node_bytes[0:8], uint64(len(node.Extensions)))
binary.BigEndian.PutUint64(node_bytes[8:16], uint64(len(node.Policies)))
binary.BigEndian.PutUint64(node_bytes[16:24], uint64(len(node.SignalQueue)))
key_bytes, err := x509.MarshalPKCS8PrivateKey(node.Key)
if err != nil {
return nil, err
return SerializedValue{}, err
}
node_db := NodeDB{
Header: NodeDBHeader{
Magic: NODE_DB_MAGIC,
TypeHash: Hash(node.Type),
KeyLength: uint32(len(key_bytes)),
BufferSize: node.BufferSize,
NumExtensions: uint32(len(extensions)),
NumPolicies: uint32(len(policies)),
NumQueuedSignals: uint32(len(node.SignalQueue)),
},
Extensions: extensions,
Policies: policies,
QueuedSignals: qsignals,
KeyBytes: key_bytes,
key_val := SerializedValue{
TypeStack: []uint64{uint64(NodeKeyType)},
Data: key_bytes,
}
i := 0
for ext_type, info := range(node.Extensions) {
ser, err := info.Serialize()
key_ser, err := key_val.MarshalBinary()
if err != nil {
return nil, err
}
node_db.Extensions[i] = ExtensionDB{
Header: ExtensionDBHeader{
TypeHash: Hash(ext_type),
Length: uint64(len(ser)),
},
Data: ser,
}
i += 1
return SerializedValue{}, err
}
node_bytes = append(node_bytes, key_ser...)
for i, qsignal := range(node.SignalQueue) {
ser, err := qsignal.Signal.Serialize()
for ext_type, ext := range(node.Extensions) {
ctx.Log.Logf("serialize", "SERIALIZING_EXTENSION: %+v", ext)
ext_ser, err := SerializeExtension(ctx, ext, ext_type)
if err != nil {
return nil, err
}
node_db.QueuedSignals[i] = QSignalDB{
QSignalDBHeader{
qsignal.Time,
Hash(qsignal.Signal.Type()),
uint64(len(ser)),
},
ser,
return SerializedValue{}, err
}
}
i = 0
for _, policy := range(node.Policies) {
ser, err := policy.Serialize()
ext_bytes, err := ext_ser.MarshalBinary()
if err != nil {
return nil, err
return SerializedValue{}, err
}
node_db.Policies[i] = PolicyDB{
PolicyDBHeader{
Hash(policy.Type()),
uint64(len(ser)),
},
ser,
node_bytes = append(node_bytes, ext_bytes...)
}
node_value := SerializedValue{
TypeStack: []uint64{uint64(node.Type)},
Data: node_bytes,
}
return node_db.Serialize(), nil
return node_value, nil
}
func KeyID(pub ed25519.PublicKey) NodeID {
str := uuid.NewHash(sha512.New(), ZeroUUID, pub, 3)
return NodeID{str}
id := uuid.NewHash(sha512.New(), ZeroUUID, pub, 3)
return NodeID(id)
}
// Create a new node in memory and start it's event loop
@ -584,24 +598,28 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si
panic("Attempted to create an existing node")
}
def, exists := ctx.Types[Hash(node_type)]
def, exists := ctx.Nodes[node_type]
if exists == false {
panic("Node type %s not registered in Context")
}
ext_map := map[ExtType]Extension{}
for _, ext := range(extensions) {
_, exists := ext_map[ext.Type()]
ext_type, exists := ctx.ExtensionTypes[reflect.TypeOf(ext)]
if exists == false {
panic(fmt.Sprintf("%+v is not a known Extension", reflect.TypeOf(ext)))
}
_, exists = ext_map[ext_type]
if exists == true {
panic("Cannot add the same extension to a node twice")
}
ext_map[ext.Type()] = ext
ext_map[ext_type] = ext
}
for _, required_ext := range(def.Extensions) {
_, exists := ext_map[required_ext]
if exists == false {
panic(fmt.Sprintf("%s requires %s", node_type, required_ext))
panic(fmt.Sprintf("%+v requires %+v", node_type, required_ext))
}
}
@ -610,9 +628,9 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si
}
default_policy := NewAllNodesPolicy(Tree{
ErrorSignalType.String(): nil,
ReadResultSignalType.String(): nil,
StatusSignalType.String(): nil,
uint64(ErrorSignalType): nil,
uint64(ReadResultSignalType): nil,
uint64(StatusSignalType): nil,
})
all_nodes_policy, exists := policies[AllNodesPolicyType]
@ -642,252 +660,32 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si
panic(err)
}
node.Process(ctx, ZeroID, &NewSignal)
node.Process(ctx, ZeroID, NewCreateSignal())
go runNode(ctx, node)
return node
}
type PolicyDBHeader struct {
TypeHash uint64
Length uint64
}
type PolicyDB struct {
Header PolicyDBHeader
Data []byte
}
type QSignalDBHeader struct {
Time time.Time
TypeHash uint64
Length uint64
}
type QSignalDB struct {
Header QSignalDBHeader
Data []byte
}
type ExtensionDBHeader struct {
TypeHash uint64
Length uint64
}
type ExtensionDB struct {
Header ExtensionDBHeader
Data []byte
}
// A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node
type NodeDBHeader struct {
Magic uint32
NumExtensions uint32
NumPolicies uint32
NumQueuedSignals uint32
BufferSize uint32
KeyLength uint32
TypeHash uint64
}
type NodeDB struct {
Header NodeDBHeader
Extensions []ExtensionDB
Policies []PolicyDB
QueuedSignals []QSignalDB
KeyBytes []byte
}
//TODO: add size safety checks
func NewNodeDB(data []byte) (NodeDB, error) {
var zero NodeDB
ptr := 0
magic := binary.BigEndian.Uint32(data[0:4])
num_extensions := binary.BigEndian.Uint32(data[4:8])
num_policies := binary.BigEndian.Uint32(data[8:12])
num_queued_signals := binary.BigEndian.Uint32(data[12:16])
buffer_size := binary.BigEndian.Uint32(data[16:20])
key_length := binary.BigEndian.Uint32(data[20:24])
node_type_hash := binary.BigEndian.Uint64(data[24:32])
ptr += NODE_DB_HEADER_LEN
if magic != NODE_DB_MAGIC {
return zero, fmt.Errorf("header has incorrect magic 0x%x", magic)
}
key_bytes := make([]byte, key_length)
n := copy(key_bytes, data[ptr:(ptr+int(key_length))])
if n != int(key_length) {
return zero, fmt.Errorf("not enough key bytes: %d", n)
}
ptr += int(key_length)
extensions := make([]ExtensionDB, num_extensions)
for i, _ := range(extensions) {
cur := data[ptr:]
type_hash := binary.BigEndian.Uint64(cur[0:8])
length := binary.BigEndian.Uint64(cur[8:16])
data_start := uint64(EXTENSION_DB_HEADER_LEN)
data_end := data_start + length
ext_data := cur[data_start:data_end]
extensions[i] = ExtensionDB{
Header: ExtensionDBHeader{
TypeHash: type_hash,
Length: length,
},
Data: ext_data,
}
ptr += int(EXTENSION_DB_HEADER_LEN + length)
}
policies := make([]PolicyDB, num_policies)
for i, _ := range(policies) {
cur := data[ptr:]
type_hash := binary.BigEndian.Uint64(cur[0:8])
length := binary.BigEndian.Uint64(cur[8:16])
data_start := uint64(POLICY_DB_HEADER_LEN)
data_end := data_start + length
policy_data := cur[data_start:data_end]
policies[i] = PolicyDB{
PolicyDBHeader{
type_hash,
length,
},
policy_data,
}
ptr += int(POLICY_DB_HEADER_LEN + length)
}
queued_signals := make([]QSignalDB, num_queued_signals)
for i, _ := range(queued_signals) {
cur := data[ptr:]
// TODO: load a header for each with the signal type and the signal length, so that it can be deserialized and incremented
// Right now causes segfault because any saved signal is loaded as nil
unix_milli := binary.BigEndian.Uint64(cur[0:8])
type_hash := binary.BigEndian.Uint64(cur[8:16])
signal_size := binary.BigEndian.Uint64(cur[16:24])
signal_data := cur[QSIGNAL_DB_HEADER_LEN:(QSIGNAL_DB_HEADER_LEN+signal_size)]
queued_signals[i] = QSignalDB{
QSignalDBHeader{
time.UnixMilli(int64(unix_milli)),
type_hash,
signal_size,
},
signal_data,
}
ptr += QSIGNAL_DB_HEADER_LEN + int(signal_size)
}
return NodeDB{
Header: NodeDBHeader{
Magic: magic,
TypeHash: node_type_hash,
BufferSize: buffer_size,
KeyLength: key_length,
NumExtensions: num_extensions,
NumQueuedSignals: num_queued_signals,
},
KeyBytes: key_bytes,
Extensions: extensions,
QueuedSignals: queued_signals,
}, nil
}
func (header NodeDBHeader) Serialize() []byte {
if header.Magic != NODE_DB_MAGIC {
panic(fmt.Sprintf("Serializing header with invalid magic %0x", header.Magic))
}
ret := make([]byte, NODE_DB_HEADER_LEN)
binary.BigEndian.PutUint32(ret[0:4], header.Magic)
binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions)
binary.BigEndian.PutUint32(ret[8:12], header.NumPolicies)
binary.BigEndian.PutUint32(ret[12:16], header.NumQueuedSignals)
binary.BigEndian.PutUint32(ret[16:20], header.BufferSize)
binary.BigEndian.PutUint32(ret[20:24], header.KeyLength)
binary.BigEndian.PutUint64(ret[24:32], header.TypeHash)
return ret
}
func (node NodeDB) Serialize() []byte {
ser := node.Header.Serialize()
ser = append(ser, node.KeyBytes...)
for _, extension := range(node.Extensions) {
ser = append(ser, extension.Serialize()...)
}
for _, policy := range(node.Policies) {
ser = append(ser, policy.Serialize()...)
}
for _, qsignal := range(node.QueuedSignals) {
ser = append(ser, qsignal.Serialize()...)
}
return ser
}
func (header QSignalDBHeader) Serialize() []byte {
ret := make([]byte, QSIGNAL_DB_HEADER_LEN)
binary.BigEndian.PutUint64(ret[0:8], uint64(header.Time.UnixMilli()))
binary.BigEndian.PutUint64(ret[8:16], header.TypeHash)
binary.BigEndian.PutUint64(ret[16:24], header.Length)
return ret
}
func (qsignal QSignalDB) Serialize() []byte {
header_bytes := qsignal.Header.Serialize()
return append(header_bytes, qsignal.Data...)
}
func (header ExtensionDBHeader) Serialize() []byte {
ret := make([]byte, EXTENSION_DB_HEADER_LEN)
binary.BigEndian.PutUint64(ret[0:8], header.TypeHash)
binary.BigEndian.PutUint64(ret[8:16], header.Length)
return ret
}
func (extension ExtensionDB) Serialize() []byte {
header_bytes := extension.Header.Serialize()
return append(header_bytes, extension.Data...)
}
func (header PolicyDBHeader) Serialize() []byte {
ret := make([]byte, POLICY_DB_HEADER_LEN)
binary.BigEndian.PutUint64(ret[0:8], header.TypeHash)
binary.BigEndian.PutUint64(ret[0:8], header.Length)
return ret
}
func (policy PolicyDB) Serialize() []byte {
header_bytes := policy.Header.Serialize()
return append(header_bytes, policy.Data...)
}
// Write a node to the database
func WriteNode(ctx *Context, node *Node) error {
ctx.Log.Logf("db", "DB_WRITE: %s", node.ID)
bytes, err := node.Serialize()
node_serialized, err := node.Serialize(ctx)
if err != nil {
return err
}
bytes, err := node_serialized.MarshalBinary()
if err != nil {
return err
}
ctx.Log.Logf("db_data", "DB_DATA: %+v", bytes)
id_bytes := node.ID.Serialize()
id_bytes, err := node.ID.MarshalBinary()
if err != nil {
return err
}
ctx.Log.Logf("db", "DB_WRITE_ID: %+v", id_bytes)
return ctx.DB.Update(func(txn *badger.Txn) error {
@ -895,11 +693,15 @@ func WriteNode(ctx *Context, node *Node) error {
})
}
//TODO: fix after capnp
func LoadNode(ctx * Context, id NodeID) (*Node, error) {
ctx.Log.Logf("db", "LOADING_NODE: %s", id)
var bytes []byte
err := ctx.DB.View(func(txn *badger.Txn) error {
id_bytes := id.Serialize()
id_bytes, err := id.MarshalBinary()
if err != nil {
return err
}
ctx.Log.Logf("db", "DB_READ_ID: %+v", id_bytes)
item, err := txn.Get(id_bytes)
if err != nil {
@ -917,137 +719,18 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
return nil, err
}
// Parse the bytes from the DB
node_db, err := NewNodeDB(bytes)
if err != nil {
return nil, err
}
policies := make(map[PolicyType]Policy, node_db.Header.NumPolicies)
for _, policy_db := range(node_db.Policies) {
policy_info, exists := ctx.Policies[policy_db.Header.TypeHash]
if exists == false {
return nil, fmt.Errorf("0x%x is not a known policy type", policy_db.Header.TypeHash)
}
policy, err := policy_info.Load(ctx, policy_db.Data)
if err != nil {
return nil, err
}
num_extensions := binary.BigEndian.Uint64(bytes[0:8])
num_policies := binary.BigEndian.Uint64(bytes[8:16])
num_signals := binary.BigEndian.Uint64(bytes[16:24])
print(num_extensions)
print(num_policies)
print(num_signals)
policies[policy_info.Type] = policy
}
key_raw, err := x509.ParsePKCS8PrivateKey(node_db.KeyBytes)
if err != nil {
return nil, err
}
var key ed25519.PrivateKey
switch k := key_raw.(type) {
case ed25519.PrivateKey:
key = k
default:
return nil, fmt.Errorf("Wrong type for private key loaded: %s - %s", id, reflect.TypeOf(k))
}
key_id := KeyID(key.Public().(ed25519.PublicKey))
if key_id != id {
return nil, fmt.Errorf("KeyID(%s) != %s", key_id, id)
}
node_type, known := ctx.Types[node_db.Header.TypeHash]
if known == false {
return nil, fmt.Errorf("Tried to load node %s of type 0x%x, which is not a known node type", id, node_db.Header.TypeHash)
}
signal_queue := make([]QueuedSignal, node_db.Header.NumQueuedSignals)
for i, qsignal := range(node_db.QueuedSignals) {
sig_info, exists := ctx.Signals[qsignal.Header.TypeHash]
if exists == false {
return nil, fmt.Errorf("0x%x is not a known signal type", qsignal.Header.TypeHash)
}
signal, err := sig_info.Load(ctx, qsignal.Data)
if err != nil {
return nil, err
}
signal_queue[i] = QueuedSignal{signal, qsignal.Header.Time}
}
next_signal, timeout_chan := SoonestSignal(signal_queue)
node := &Node{
Key: key,
ID: key_id,
Type: node_type.Type,
Extensions: map[ExtType]Extension{},
Policies: policies,
MsgChan: make(chan *Message, node_db.Header.BufferSize),
BufferSize: node_db.Header.BufferSize,
TimeoutChan: timeout_chan,
SignalQueue: signal_queue,
NextSignal: next_signal,
}
/*
ctx.AddNode(id, node)
found_extensions := []ExtType{}
// Parse each of the extensions from the db
for _, ext_db := range(node_db.Extensions) {
type_hash := ext_db.Header.TypeHash
def, known := ctx.Extensions[type_hash]
if known == false {
return nil, fmt.Errorf("%s tried to load extension 0x%x, which is not a known extension type", id, type_hash)
}
ctx.Log.Logf("db", "DB_EXTENSION_LOADING: %s/%s", id, def.Type)
extension, err := def.Load(ctx, ext_db.Data)
if err != nil {
return nil, err
}
node.Extensions[def.Type] = extension
found_extensions = append(found_extensions, def.Type)
ctx.Log.Logf("db", "DB_EXTENSION_LOADED: %s/%s - %+v", id, def.Type, extension)
}
missing_extensions := []ExtType{}
for _, ext := range(node_type.Extensions) {
found := false
for _, found_ext := range(found_extensions) {
if found_ext == ext {
found = true
break
}
}
if found == false {
missing_extensions = append(missing_extensions, ext)
}
}
if len(missing_extensions) > 0 {
return nil, fmt.Errorf("DB_LOAD_MISSING_EXTENSIONS: %s - %+v - %+v", id, node_type, missing_extensions)
}
extra_extensions := []ExtType{}
for _, found_ext := range(found_extensions) {
found := false
for _, ext := range(node_type.Extensions) {
if ext == found_ext {
found = true
break
}
}
if found == false {
extra_extensions = append(extra_extensions, found_ext)
}
}
if len(extra_extensions) > 0 {
ctx.Log.Logf("db", "DB_LOAD_EXTRA_EXTENSIONS: %s - %+v - %+v", id, node_type, extra_extensions)
}
ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id)
go runNode(ctx, node)
*/
return node, nil
return nil, nil
}

@ -8,21 +8,21 @@ import (
)
func TestNodeDB(t *testing.T) {
ctx := logTestContext(t, []string{"signal", "node", "db"})
node_type := NodeType("test")
ctx := logTestContext(t, []string{"signal", "node", "db", "db_data", "serialize"})
node_type := NewNodeType("test")
err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType})
fatalErr(t, err)
node := NewNode(ctx, nil, node_type, 10, nil, NewGroupExt(nil))
node := NewNode(ctx, nil, node_type, 10, nil, NewGroupExt(nil), NewLockableExt(nil))
ctx.Nodes = map[NodeID]*Node{}
_, err = ctx.GetNode(node.ID)
ctx.nodeMap = map[NodeID]*Node{}
_, err = ctx.getNode(node.ID)
fatalErr(t, err)
}
func TestNodeRead(t *testing.T) {
ctx := logTestContext(t, []string{"test"})
node_type := NodeType("TEST")
node_type := NewNodeType("TEST")
err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType, ECDHExtType})
fatalErr(t, err)
@ -38,27 +38,27 @@ func TestNodeRead(t *testing.T) {
ctx.Log.Logf("test", "N2: %s", n2_id)
n1_policy := NewPerNodePolicy(map[NodeID]Tree{
n2_id: Tree{
ReadSignalType.String(): nil,
n2_id: {
uint64(ReadSignalType): nil,
},
})
n2_listener := NewListenerExt(10)
n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewGroupExt(nil), NewECDHExt(), n2_listener)
n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewGroupExt(nil), n2_listener)
n1 := NewNode(ctx, n1_key, node_type, 10, map[PolicyType]Policy{
PerNodePolicyType: &n1_policy,
}, NewGroupExt(nil), NewECDHExt())
}, NewGroupExt(nil))
read_sig := NewReadSignal(map[ExtType][]string{
GroupExtType: []string{"members"},
GroupExtType: {"members"},
})
msgs := Messages{}
msgs = msgs.Add(n2.ID, n2.Key, read_sig, n1.ID)
msgs = msgs.Add(ctx, n2.ID, n2.Key, read_sig, n1.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
res, err := WaitForSignal(n2_listener.Chan, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool {
res, err := WaitForSignal(n2_listener.Chan, 10*time.Millisecond, func(sig *ReadResultSignal) bool {
return true
})
fatalErr(t, err)

@ -1,35 +1,26 @@
package graphvent
import (
"encoding/json"
)
const (
MemberOfPolicyType = PolicyType("USER_OF")
RequirementOfPolicyType = PolicyType("REQUIEMENT_OF")
PerNodePolicyType = PolicyType("PER_NODE")
AllNodesPolicyType = PolicyType("ALL_NODES")
)
type Policy interface {
Serializable[PolicyType]
Allows(principal_id NodeID, action Tree, node *Node)(Messages, RuleResult)
ContinueAllows(current PendingACL, signal Signal)RuleResult
Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult)
ContinueAllows(ctx *Context, current PendingACL, signal Signal)RuleResult
// Merge with another policy of the same underlying type
Merge(Policy) Policy
// Make a copy of this policy
Copy() Policy
}
func (policy *AllNodesPolicy) Allows(principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
func (policy *AllNodesPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
return nil, policy.Rules.Allows(action)
}
func (policy *AllNodesPolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult {
func (policy *AllNodesPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
return Deny
}
func (policy *PerNodePolicy) Allows(principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
func (policy *PerNodePolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
for id, actions := range(policy.NodeRules) {
if id != principal_id {
continue
@ -39,7 +30,7 @@ func (policy *PerNodePolicy) Allows(principal_id NodeID, action Tree, node *Node
return nil, Deny
}
func (policy *PerNodePolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult {
func (policy *PerNodePolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
return Deny
}
@ -57,7 +48,7 @@ func NewRequirementOfPolicy(dep_rules map[NodeID]Tree) RequirementOfPolicy {
}
}
func (policy *RequirementOfPolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult {
func (policy *RequirementOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
sig, ok := signal.(*ReadResultSignal)
if ok == false {
return Deny
@ -68,7 +59,17 @@ func (policy *RequirementOfPolicy) ContinueAllows(current PendingACL, signal Sig
return Deny
}
requirements, ok := ext["requirements"].(map[NodeID]string)
reqs_ser, ok := ext["requirements"]
if ok == false {
return Deny
}
reqs_if, err := DeserializeValue(ctx, reqs_ser)
if err != nil {
return Deny
}
requirements, ok := reqs_if.(map[NodeID]ReqState)
if ok == false {
return Deny
}
@ -96,7 +97,7 @@ func NewMemberOfPolicy(group_rules map[NodeID]Tree) MemberOfPolicy {
}
}
func (policy *MemberOfPolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult {
func (policy *MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
sig, ok := signal.(*ReadResultSignal)
if ok == false {
return Deny
@ -107,7 +108,17 @@ func (policy *MemberOfPolicy) ContinueAllows(current PendingACL, signal Signal)
return Deny
}
members, ok := group_ext_data["members"].(map[NodeID]string)
members_ser, ok := group_ext_data["members"]
if ok == false {
return Deny
}
members_if, err := DeserializeValue(ctx, members_ser)
if err != nil {
return Deny
}
members, ok := members_if.(map[NodeID]string)
if ok == false {
return Deny
}
@ -122,11 +133,11 @@ func (policy *MemberOfPolicy) ContinueAllows(current PendingACL, signal Signal)
}
// Send a read signal to Group to check if principal_id is a member of it
func (policy *MemberOfPolicy) Allows(principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) {
func (policy *MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) {
msgs := Messages{}
for id, rule := range(policy.NodeRules) {
if id == node.ID {
ext, err := GetExt[*GroupExt](node)
ext, err := GetExt[*GroupExt](node, GroupExtType)
if err == nil {
for member, _ := range(ext.Members) {
if member == principal_id {
@ -137,7 +148,7 @@ func (policy *MemberOfPolicy) Allows(principal_id NodeID, action Tree, node *Nod
}
}
} else {
msgs = msgs.Add(node.ID, node.Key, NewReadSignal(map[ExtType][]string{
msgs = msgs.Add(ctx, node.ID, node.Key, NewReadSignal(map[ExtType][]string{
GroupExtType: []string{"members"},
}), id)
}
@ -238,7 +249,7 @@ func (policy *AllNodesPolicy) Copy() Policy {
}
}
type Tree map[string]Tree
type Tree map[uint64]Tree
func (rule Tree) Allows(action Tree) RuleResult {
// If the current rule is nil, it's a wildcard and any action being processed is allowed
@ -285,14 +296,6 @@ func (policy *PerNodePolicy) Type() PolicyType {
return PerNodePolicyType
}
func (policy *PerNodePolicy) Serialize() ([]byte, error) {
return json.MarshalIndent(policy, "", " ")
}
func (policy *PerNodePolicy) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, policy)
}
func NewAllNodesPolicy(rules Tree) AllNodesPolicy {
return AllNodesPolicy{
Rules: rules,
@ -307,15 +310,7 @@ func (policy *AllNodesPolicy) Type() PolicyType {
return AllNodesPolicyType
}
func (policy *AllNodesPolicy) Serialize() ([]byte, error) {
return json.MarshalIndent(policy, "", " ")
}
func (policy *AllNodesPolicy) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, policy)
}
var DefaultPolicy = NewAllNodesPolicy(Tree{
ErrorSignalType.String(): nil,
ReadResultSignalType.String(): nil,
uint64(ErrorSignalType): nil,
uint64(ReadResultSignalType): nil,
})

@ -1,49 +1,29 @@
package graphvent
import (
"time"
"fmt"
"encoding/json"
"encoding/binary"
"crypto"
"crypto/ed25519"
"crypto/ecdh"
"crypto/rand"
"crypto/aes"
"crypto/cipher"
"time"
"capnproto.org/go/capnp/v3"
"github.com/google/uuid"
schema "github.com/mekkanized/graphvent/signal"
)
type SignalDirection int
const (
StopSignalType = SignalType("STOP")
NewSignalType = SignalType("NEW")
StartSignalType = SignalType("START")
ErrorSignalType = SignalType("ERROR")
StatusSignalType = SignalType("STATUS")
LinkSignalType = SignalType("LINK")
LockSignalType = SignalType("LOCK")
ReadSignalType = SignalType("READ")
ReadResultSignalType = SignalType("READ_RESULT")
ECDHSignalType = SignalType("ECDH")
ECDHProxySignalType = SignalType("ECDH_PROXY")
ACLTimeoutSignalType = SignalType("ACL_TIMEOUT")
Up SignalDirection = iota
Down
Direct
)
type SignalType string
func (signal_type SignalType) String() string { return string(signal_type) }
func (signal_type SignalType) Prefix() string { return "SIGNAL: " }
type SignalHeader struct {
Direction SignalDirection
ID uuid.UUID
ReqID uuid.UUID
}
type Signal interface {
Serializable[SignalType]
String() string
Direction() SignalDirection
ID() uuid.UUID
ReqID() uuid.UUID
Header() *SignalHeader
Permission() Tree
}
@ -59,7 +39,7 @@ func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UU
if signal == nil {
return nil, fmt.Errorf("LISTENER_CLOSED")
}
if signal.ReqID() == req_id {
if signal.Header().ReqID == req_id {
return signal, nil
}
case <-timeout_channel:
@ -69,7 +49,7 @@ func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UU
return nil, fmt.Errorf("UNREACHABLE")
}
func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) {
func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, check func(S)bool) (S, error) {
var zero S
var timeout_channel <- chan time.Time
if timeout > 0 {
@ -79,493 +59,396 @@ func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, signal
select {
case signal := <- listener:
if signal == nil {
return zero, fmt.Errorf("LISTENER_CLOSED: %s", signal_type)
return zero, fmt.Errorf("LISTENER_CLOSED")
}
if signal.Type() == signal_type {
sig, ok := signal.(S)
if ok == true {
if check(sig) == true {
return sig, nil
}
}
}
case <-timeout_channel:
return zero, fmt.Errorf("LISTENER_TIMEOUT: %s", signal_type)
return zero, fmt.Errorf("LISTENER_TIMEOUT")
}
}
return zero, fmt.Errorf("LOOP_ENDED")
}
type BaseSignal struct {
SignalDirection SignalDirection `json:"direction"`
SignalType SignalType `json:"type"`
UUID uuid.UUID `json:"id"`
ReqUUID uuid.UUID `json:"req_uuid"`
func NewSignalHeader(direction SignalDirection) SignalHeader {
id := uuid.New()
header := SignalHeader{
ID: id,
ReqID: id,
Direction: direction,
}
func (signal *BaseSignal) ReqID() uuid.UUID {
return signal.ReqUUID
return header
}
func (signal *BaseSignal) String() string {
ser, _ := json.Marshal(signal)
return string(ser)
func NewRespHeader(req_id uuid.UUID, direction SignalDirection) SignalHeader {
header := SignalHeader{
ID: uuid.New(),
ReqID: req_id,
Direction: direction,
}
return header
}
func (signal *BaseSignal) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, signal)
func SerializeHeader(header SignalHeader, root schema.SignalHeader) error {
root.SetDirection(uint8(header.Direction))
id_ser, err := header.ID.MarshalBinary()
if err != nil {
return err
}
root.SetId(id_ser)
func (signal *BaseSignal) ID() uuid.UUID {
return signal.UUID
req_id_ser, err := header.ReqID.MarshalBinary()
if err != nil {
return err
}
root.SetReqID(req_id_ser)
return nil
}
func (signal *BaseSignal) Type() SignalType {
return signal.SignalType
func DeserializeHeader(header schema.SignalHeader) (SignalHeader, error) {
id_ser, err := header.Id()
if err != nil {
return SignalHeader{}, err
}
id, err := uuid.FromBytes(id_ser)
if err != nil {
return SignalHeader{}, err
}
func (signal *BaseSignal) Permission() Tree {
return Tree{
string(signal.Type()): Tree{},
req_id_ser, err := header.ReqID()
if err != nil {
return SignalHeader{}, err
}
req_id, err := uuid.FromBytes(req_id_ser)
if err != nil {
return SignalHeader{}, err
}
func (signal *BaseSignal) Direction() SignalDirection {
return signal.SignalDirection
return SignalHeader{
ID: id,
ReqID: req_id,
Direction: SignalDirection(header.Direction()),
}, nil
}
func (signal *BaseSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
type CreateSignal struct {
SignalHeader
}
func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal {
id := uuid.New()
signal := BaseSignal{
UUID: id,
ReqUUID: id,
SignalDirection: direction,
SignalType: signal_type,
func (signal *CreateSignal) Header() *SignalHeader {
return &signal.SignalHeader
}
func (signal *CreateSignal) Permission() Tree {
return Tree{
uint64(CreateSignalType): nil,
}
return signal
}
func NewRespSignal(id uuid.UUID, signal_type SignalType, direction SignalDirection) BaseSignal {
signal := BaseSignal{
UUID: uuid.New(),
ReqUUID: id,
SignalDirection: direction,
SignalType: signal_type,
func NewCreateSignal() *CreateSignal {
return &CreateSignal{
NewSignalHeader(Direct),
}
return signal
}
var NewSignal = NewBaseSignal(NewSignalType, Direct)
var StartSignal = NewBaseSignal(StartSignalType, Direct)
var StopSignal = NewBaseSignal(StopSignalType, Direct)
type IDSignal struct {
BaseSignal
NodeID `json:"id"`
type StartSignal struct {
SignalHeader
}
func (signal *IDSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
func (signal *StartSignal) Header() *SignalHeader {
return &signal.SignalHeader
}
type StringSignal struct {
BaseSignal
Str string `json:"state"`
func (signal *StartSignal) Permission() Tree {
return Tree{
uint64(StartSignalType): nil,
}
func (signal *StringSignal) String() string {
ser, _ := json.Marshal(signal)
return string(ser)
}
func (signal *StringSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal)
func NewStartSignal() *StartSignal {
return &StartSignal{
NewSignalHeader(Direct),
}
type ErrorSignal struct {
BaseSignal
Error string
}
func (signal *ErrorSignal) String() string {
ser, _ := json.Marshal(signal)
return string(ser)
type StopSignal struct {
SignalHeader
}
func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) Signal {
return &ErrorSignal{
NewRespSignal(req_id, ErrorSignalType, Direct),
fmt.Sprintf(fmt_string, args...),
func (signal *StopSignal) Header() *SignalHeader {
return &signal.SignalHeader
}
func (signal *StopSignal) Permission() Tree {
return Tree{
uint64(StopSignalType): nil,
}
func NewACLTimeoutSignal(req_id uuid.UUID) Signal {
sig := NewRespSignal(req_id, ACLTimeoutSignalType, Direct)
return &sig
}
type IDStringSignal struct {
BaseSignal
NodeID NodeID `json:"node_id"`
Str string `json:"string"`
func NewStopSignal() *StopSignal {
return &StopSignal{
NewSignalHeader(Direct),
}
func (signal *IDStringSignal) String() string {
ser, _ := json.Marshal(signal)
return string(ser)
}
func (signal *IDStringSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
type ErrorSignal struct {
SignalHeader
Error string
}
func NewStatusSignal(status string, source NodeID) Signal {
return &IDStringSignal{
BaseSignal: NewBaseSignal(StatusSignalType, Up),
NodeID: source,
Str: status,
func (signal *ErrorSignal) Header() *SignalHeader {
return &signal.SignalHeader
}
func (signal *ErrorSignal) MarshalBinary() ([]byte, error) {
arena := capnp.SingleSegment(nil)
msg, seg, err := capnp.NewMessage(arena)
if err != nil {
return nil, err
}
func NewLinkSignal(state string, id NodeID) Signal {
return &IDStringSignal{
BaseSignal: NewBaseSignal(LinkSignalType, Direct),
NodeID: id,
Str: state,
}
root, err := schema.NewRootErrorSignal(seg)
if err != nil {
return nil, err
}
func NewLockSignal(state string) Signal {
return &StringSignal{
NewBaseSignal(LockSignalType, Direct),
state,
root.SetError(signal.Error)
header, err := root.NewHeader()
if err != nil {
return nil, err
}
err = SerializeHeader(signal.SignalHeader, header)
if err != nil {
return nil, err
}
func (signal *StringSignal) Permission() Tree {
return Tree{
string(signal.Type()): Tree{
signal.Str: Tree{},
},
return msg.Marshal()
}
func (signal *ErrorSignal) Deserialize(ctx *Context, data []byte) error {
msg, err := capnp.Unmarshal(data)
if err != nil {
return err
}
type ReadSignal struct {
BaseSignal
Extensions map[ExtType][]string `json:"extensions"`
root, err := schema.ReadRootErrorSignal(msg)
if err != nil {
return err
}
func (signal *ReadSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
header, err := root.Header()
if err != nil {
return err
}
func NewReadSignal(exts map[ExtType][]string) *ReadSignal {
return &ReadSignal{
NewBaseSignal(ReadSignalType, Direct),
exts,
signal.Error, err = root.Error()
if err != nil {
return err
}
signal.SignalHeader, err = DeserializeHeader(header)
if err != nil {
return err
}
func (signal *ReadSignal) Permission() Tree {
ret := Tree{}
for ext, fields := range(signal.Extensions) {
field_tree := Tree{}
for _, field := range(fields) {
field_tree[field] = Tree{}
return nil
}
ret[ext.String()] = field_tree
func (signal *ErrorSignal) Permission() Tree {
return Tree{
uint64(ErrorSignalType): nil,
}
return Tree{ReadSignalType.String(): ret}
}
type ReadResultSignal struct {
BaseSignal
NodeID NodeID
NodeType NodeType
Extensions map[ExtType]map[string]interface{} `json:"extensions"`
func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) Signal {
return &ErrorSignal{
NewRespHeader(req_id, Direct),
fmt.Sprintf(fmt_string, args...),
}
}
func (signal *ReadResultSignal) Permission() Tree {
type ACLTimeoutSignal struct {
SignalHeader
}
func (signal *ACLTimeoutSignal) Header() *SignalHeader {
return &signal.SignalHeader
}
func (signal *ACLTimeoutSignal) Permission() Tree {
return Tree{
ReadResultSignalType.String(): Tree{},
uint64(ACLTimeoutSignalType): nil,
}
}
func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]interface{}) Signal {
return &ReadResultSignal{
NewRespSignal(req_id, ReadResultSignalType, Direct),
node_id,
node_type,
exts,
func NewACLTimeoutSignal(req_id uuid.UUID) *ACLTimeoutSignal {
sig := &ACLTimeoutSignal{
NewRespHeader(req_id, Direct),
}
return sig
}
type ECDHSignal struct {
StringSignal
Time time.Time
EDDSA ed25519.PublicKey
ECDH *ecdh.PublicKey
Signature []byte
type StatusSignal struct {
SignalHeader
Source NodeID
Status string
}
type ECDHSignalJSON struct {
StringSignal
Time time.Time `json:"time"`
EDDSA []byte `json:"ecdsa_pubkey"`
ECDH []byte `json:"ecdh_pubkey"`
Signature []byte `json:"signature"`
func (signal *StatusSignal) Header() *SignalHeader {
return &signal.SignalHeader
}
func (signal *ECDHSignal) MarshalJSON() ([]byte, error) {
return json.Marshal(&ECDHSignalJSON{
StringSignal: signal.StringSignal,
Time: signal.Time,
ECDH: signal.ECDH.Bytes(),
EDDSA: signal.ECDH.Bytes(),
Signature: signal.Signature,
})
func (signal *StatusSignal) Permission() Tree {
return Tree{
uint64(StatusSignalType): nil,
}
func (signal *ECDHSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
}
func NewECDHReqSignal(node *Node) (Signal, *ecdh.PrivateKey, error) {
ec_key, err := ECDH.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
func NewStatusSignal(source NodeID, status string) *StatusSignal {
return &StatusSignal{
NewSignalHeader(Up),
source,
status,
}
now := time.Now()
time_bytes, err := now.MarshalJSON()
if err != nil {
return nil, nil, err
}
sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...)
sig, err := node.Key.Sign(rand.Reader, sig_data, crypto.Hash(0))
if err != nil {
return nil, nil, err
type LinkSignal struct {
SignalHeader
NodeID
Action string
}
return &ECDHSignal{
StringSignal: StringSignal{
BaseSignal: NewBaseSignal(ECDHSignalType, Direct),
Str: "req",
},
Time: now,
EDDSA: node.Key.Public().(ed25519.PublicKey),
ECDH: ec_key.PublicKey(),
Signature: sig,
}, ec_key, nil
func (signal *LinkSignal) Header() *SignalHeader {
return &signal.SignalHeader
}
const DEFAULT_ECDH_WINDOW = time.Second
func NewECDHRespSignal(node *Node, req *ECDHSignal) (ECDHSignal, []byte, error) {
now := time.Now()
const (
LinkActionBase = "LINK_ACTION"
LinkActionAdd = "ADD"
)
err := VerifyECDHSignal(now, req, DEFAULT_ECDH_WINDOW)
if err != nil {
return ECDHSignal{}, nil, err
func (signal *LinkSignal) Permission() Tree {
return Tree{
uint64(LinkSignalType): Tree{
Hash(LinkActionBase, signal.Action): nil,
},
}
ec_key, err := ECDH.GenerateKey(rand.Reader)
if err != nil {
return ECDHSignal{}, nil, err
}
shared_secret, err := ec_key.ECDH(req.ECDH)
if err != nil {
return ECDHSignal{}, nil, err
func NewLinkSignal(action string, id NodeID) Signal {
return &LinkSignal{
NewSignalHeader(Direct),
id,
action,
}
time_bytes, err := now.MarshalJSON()
if err != nil {
return ECDHSignal{}, nil, err
}
sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...)
sig, err := node.Key.Sign(rand.Reader, sig_data, crypto.Hash(0))
if err != nil {
return ECDHSignal{}, nil, err
type LockSignal struct {
SignalHeader
State string
}
return ECDHSignal{
StringSignal: StringSignal{
BaseSignal: NewBaseSignal(ECDHSignalType, Direct),
Str: "resp",
},
Time: now,
EDDSA: node.Key.Public().(ed25519.PublicKey),
ECDH: ec_key.PublicKey(),
Signature: sig,
}, shared_secret, nil
func (signal *LockSignal) Header() *SignalHeader {
return &signal.SignalHeader
}
func VerifyECDHSignal(now time.Time, sig *ECDHSignal, window time.Duration) error {
earliest := now.Add(-window)
latest := now.Add(window)
if sig.Time.Compare(earliest) == -1 {
return fmt.Errorf("TIME_TOO_LATE: %+v", sig.Time)
} else if sig.Time.Compare(latest) == 1 {
return fmt.Errorf("TIME_TOO_EARLY: %+v", sig.Time)
}
const (
LockStateBase = "LOCK_STATE"
)
time_bytes, err := sig.Time.MarshalJSON()
if err != nil {
return err
func (signal *LockSignal) Permission() Tree {
return Tree{
uint64(LockSignalType): Tree{
Hash(LockStateBase, signal.State): nil,
},
}
sig_data := append(sig.ECDH.Bytes(), time_bytes...)
verified := ed25519.Verify(sig.EDDSA, sig_data, sig.Signature)
if verified == false {
return fmt.Errorf("Failed to verify signature")
}
return nil
func NewLockSignal(state string) *LockSignal {
return &LockSignal{
NewSignalHeader(Direct),
state,
}
type ECDHProxySignal struct {
BaseSignal
Source NodeID
Dest NodeID
IV []byte
Data []byte
}
func NewECDHProxySignal(source, dest NodeID, signal Signal, shared_secret []byte) (Signal, error) {
if shared_secret == nil {
return nil, fmt.Errorf("need shared_secret")
type ReadSignal struct {
SignalHeader
Extensions map[ExtType][]string `json:"extensions"`
}
aes_key, err := aes.NewCipher(shared_secret[:32])
func (signal *ReadSignal) MarshalBinary() ([]byte, error) {
arena := capnp.SingleSegment(nil)
msg, seg, err := capnp.NewMessage(arena)
if err != nil {
return nil, err
}
ser, err := SerializeSignal(signal, aes_key.BlockSize())
root, err := schema.NewRootReadSignal(seg)
if err != nil {
return nil, err
}
iv := make([]byte, aes_key.BlockSize())
n, err := rand.Reader.Read(iv)
header, err := root.NewHeader()
if err != nil {
return nil, err
} else if n != len(iv) {
return nil, fmt.Errorf("Not enough bytes read for IV")
}
encrypter := cipher.NewCBCEncrypter(aes_key, iv)
encrypter.CryptBlocks(ser, ser)
return &ECDHProxySignal{
BaseSignal: NewBaseSignal(ECDHProxySignalType, Direct),
Source: source,
Dest: dest,
IV: iv,
Data: ser,
}, nil
err = SerializeHeader(signal.SignalHeader, header)
if err != nil {
return nil, err
}
type SignalHeader struct {
Magic uint32
TypeHash uint64
Length uint64
extensions, err := root.NewExtensions(int32(len(signal.Extensions)))
if err != nil {
return nil, err
}
const SIGNAL_SER_MAGIC uint32 = 0x753a64de
const SIGNAL_SER_HEADER_LENGTH = 20
func SerializeSignal(signal Signal, block_size int) ([]byte, error) {
signal_ser, err := signal.Serialize()
i := 0
for ext_type, fields := range(signal.Extensions) {
extension := extensions.At(i)
extension.SetType(uint64(ext_type))
f, err := extension.NewFields(int32(len(fields)))
if err != nil {
return nil, err
}
pad_req := 0
if block_size > 0 {
pad := block_size - ((SIGNAL_SER_HEADER_LENGTH + len(signal_ser)) % block_size)
if pad != block_size {
pad_req = pad
for j, field := range(fields) {
err := f.Set(j, field)
if err != nil {
return nil, err
}
}
header := SignalHeader{
Magic: SIGNAL_SER_MAGIC,
TypeHash: Hash(signal.Type()),
Length: uint64(len(signal_ser) + pad_req),
i += 1
}
ser := make([]byte, SIGNAL_SER_HEADER_LENGTH + len(signal_ser) + pad_req)
binary.BigEndian.PutUint32(ser[0:4], header.Magic)
binary.BigEndian.PutUint64(ser[4:12], header.TypeHash)
binary.BigEndian.PutUint64(ser[12:20], header.Length)
copy(ser[SIGNAL_SER_HEADER_LENGTH:], signal_ser)
return ser, nil
return msg.Marshal()
}
func ParseSignal(ctx *Context, data []byte) (Signal, error) {
if len(data) < SIGNAL_SER_HEADER_LENGTH {
return nil, fmt.Errorf("data shorter than header length")
func (signal *ReadSignal) Header() *SignalHeader {
return &signal.SignalHeader
}
header := SignalHeader{
Magic: binary.BigEndian.Uint32(data[0:4]),
TypeHash: binary.BigEndian.Uint64(data[4:12]),
Length: binary.BigEndian.Uint64(data[12:20]),
func (signal *ReadSignal) Permission() Tree {
ret := Tree{}
for ext, fields := range(signal.Extensions) {
field_tree := Tree{}
for _, field := range(fields) {
field_tree[Hash(FieldNameBase, field)] = nil
}
if header.Magic != SIGNAL_SER_MAGIC {
return nil, fmt.Errorf("signal magic mismatch 0x%x", header.Magic)
ret[uint64(ext)] = field_tree
}
left := len(data) - SIGNAL_SER_HEADER_LENGTH
if int(header.Length) != left {
return nil, fmt.Errorf("signal length mismatch %d/%d", header.Length, left)
return Tree{uint64(ReadSignalType): ret}
}
signal_def, exists := ctx.Signals[header.TypeHash]
if exists == false {
return nil, fmt.Errorf("0x%x is not a known signal type", header.TypeHash)
func NewReadSignal(exts map[ExtType][]string) *ReadSignal {
return &ReadSignal{
NewSignalHeader(Direct),
exts,
}
signal, err := signal_def.Load(ctx, data[SIGNAL_SER_HEADER_LENGTH:])
if err != nil {
return nil, err
}
return signal, nil
type ReadResultSignal struct {
SignalHeader
NodeID NodeID
NodeType NodeType
Extensions map[ExtType]map[string]SerializedValue
}
func ParseECDHProxySignal(ctx *Context, signal *ECDHProxySignal, shared_secret []byte) (Signal, error) {
if shared_secret == nil {
return nil, fmt.Errorf("need shared_secret")
func (signal *ReadResultSignal) Header() *SignalHeader {
return &signal.SignalHeader
}
aes_key, err := aes.NewCipher(shared_secret[:32])
if err != nil {
return nil, err
func (signal *ReadResultSignal) Permission() Tree {
return Tree{
uint64(ReadResultSignalType): nil,
}
decrypter := cipher.NewCBCDecrypter(aes_key, signal.IV)
decrypted := make([]byte, len(signal.Data))
decrypter.CryptBlocks(decrypted, signal.Data)
wrapped_signal, err := ParseSignal(ctx, decrypted)
if err != nil {
return nil, err
}
return wrapped_signal, nil
func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]SerializedValue) *ReadResultSignal {
return &ReadResultSignal{
NewRespHeader(req_id, Direct),
node_id,
node_type,
exts,
}
}

@ -0,0 +1,10 @@
module github.com/mekkanized/graphvent/signal
go 1.21.0
require capnproto.org/go/capnp/v3 v3.0.0-alpha-29
require (
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 // indirect
zenhack.net/go/util v0.0.0-20230414204917-531d38494cf5 // indirect
)

@ -0,0 +1,35 @@
using Go = import "/go.capnp";
@0x83a311a90429ef94;
$Go.package("signal");
$Go.import("github.com/mekkanized/graphvent/signal");
struct SignalHeader {
direction @0 :UInt8;
id @1 :Data;
reqID @2 :Data;
}
struct ErrorSignal {
header @0 :SignalHeader;
error @1 :Text;
}
struct LinkSignal {
header @0 :SignalHeader;
action @1 :Text;
id @2 :Data;
}
struct LockSignal {
header @0 :SignalHeader;
state @1 :Text;
}
struct ReadSignal {
header @0 :SignalHeader;
extensions @1 :List(Extension);
struct Extension {
type @0 :UInt64;
fields @1 :List(Text);
}
}