diff --git a/.gitignore b/.gitignore index a877f04..686f153 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,15 @@ # Ignore everything * +!/go-capnp # But not these files... !/.gitignore !*.go +*.capnp.go !go.sum !go.mod +!*.capnp !README.md !LICENSE diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..e69de29 diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..570ed6d --- /dev/null +++ b/Makefile @@ -0,0 +1,9 @@ +.PHONY: schema test + +test: schema + clear && go test + +schema: signal/signal.capnp.go + +%.capnp.go: %.capnp + capnp compile -I ./go-capnp/std -ogo $< diff --git a/context.go b/context.go index a839c19..c313383 100644 --- a/context.go +++ b/context.go @@ -1,124 +1,127 @@ package graphvent import ( - badger "github.com/dgraph-io/badger/v3" - "fmt" - "sync" - "errors" - "runtime" - "crypto/sha512" "crypto/ecdh" + "crypto/sha512" "encoding/binary" -) + "errors" + "fmt" + "reflect" + "runtime" + "sync" -// A Type can be Hashed by Hash -type TypeName interface { - String() string - Prefix() string -} + badger "github.com/dgraph-io/badger/v3" +) -// Hashed a Type to a uint64 -func Hash(t TypeName) uint64 { - hash := sha512.Sum512([]byte(fmt.Sprintf("%s%s", t.Prefix(), t.String()))) - return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)]) +func Hash(base string, name string) uint64 { + digest := append([]byte(base), 0x00) + digest = append(digest, []byte(name)...) + hash := sha512.Sum512(digest) + return binary.BigEndian.Uint64(hash[0:8]) } -// NodeType identifies the 'class' of a node -type NodeType string -func (node NodeType) Prefix() string { return "NODE: " } -func (node NodeType) String() string { return string(node) } +type ExtType uint64 +type NodeType uint64 +type SignalType uint64 +type PolicyType uint64 +type SerializedType uint64 -// ExtType identifies an extension on a node -type ExtType string -func (ext ExtType) Prefix() string { return "EXTENSION: " } -func (ext ExtType) String() string { return string(ext) } - -//Function to load an extension from bytes -type ExtensionLoadFunc func(*Context,[]byte) (Extension, error) -func LoadExtension[T any, E interface { - *T - Extension -}](ctx *Context, data []byte) (Extension, error) { - e := E(new(T)) - err := e.Deserialize(ctx, data) - if err != nil { - return nil, err - } +func NewExtType(name string) ExtType { + return ExtType(Hash(ExtTypeBase, name)) +} - return e, nil +func NewNodeType(name string) NodeType { + return NodeType(Hash(NodeTypeBase, name)) } -type PolicyType string -func (policy PolicyType) Prefix() string { return "POLICY: " } -func (policy PolicyType) String() string { return string(policy) } +func NewSignalType(name string) SignalType { + return SignalType(Hash(SignalTypeBase, name)) +} -type PolicyLoadFunc func(*Context,[]byte) (Policy, error) -func LoadPolicy[T any, P interface { - *T - Policy -}](ctx *Context, data []byte) (Policy, error) { - p := P(new(T)) - err := p.Deserialize(ctx, data) - if err != nil { - return nil, err - } - return p, nil +func NewPolicyType(name string) PolicyType { + return PolicyType(Hash(PolicyTypeBase, name)) } -type PolicyInfo struct { - Load PolicyLoadFunc - Type PolicyType +func NewSerializedType(name string) SerializedType { + return SerializedType(Hash(SerializedTypeBase, name)) } -// ExtType and NodeType constants const ( - ListenerExtType = ExtType("LISTENER") - LockableExtType = ExtType("LOCKABLE") - GQLExtType = ExtType("GQL") - GroupExtType = ExtType("GROUP") - ECDHExtType = ExtType("ECDH") - - GQLNodeType = NodeType("GQL") + ExtTypeBase = "ExtType" + NodeTypeBase = "NodeType" + SignalTypeBase = "SignalType" + PolicyTypeBase = "PolicyType" + SerializedTypeBase = "SerializedType" + FieldNameBase = "FieldName" ) var ( + ListenerExtType = NewExtType("LISTENER") + LockableExtType = NewExtType("LOCKABLE") + GQLExtType = NewExtType("GQL") + GroupExtType = NewExtType("GROUP") + ECDHExtType = NewExtType("ECDH") + + GQLNodeType = NewNodeType("GQL") + + StopSignalType = NewSignalType("STOP") + CreateSignalType = NewSignalType("CREATE") + StartSignalType = NewSignalType("START") + ErrorSignalType = NewSignalType("ERROR") + StatusSignalType = NewSignalType("STATUS") + LinkSignalType = NewSignalType("LINK") + LockSignalType = NewSignalType("LOCK") + ReadSignalType = NewSignalType("READ") + ReadResultSignalType = NewSignalType("READ_RESULT") + ACLTimeoutSignalType = NewSignalType("ACL_TIMEOUT") + + MemberOfPolicyType = NewPolicyType("USER_OF") + RequirementOfPolicyType = NewPolicyType("REQUIEMENT_OF") + PerNodePolicyType = NewPolicyType("PER_NODE") + AllNodesPolicyType = NewPolicyType("ALL_NODES") + + StructType = NewSerializedType("struct") + SliceType = NewSerializedType("slice") + ArrayType = NewSerializedType("array") + PointerType = NewSerializedType("pointer") + MapType = NewSerializedType("map") + ErrorType = NewSerializedType("error") + ExtensionType = NewSerializedType("extension") + + StringType = NewSerializedType("string") + NodeKeyType = NewSerializedType("node_key") + NodeNotFoundError = errors.New("Node not found in DB") ECDH = ecdh.X25519() ) -type SignalLoadFunc func(*Context,[]byte) (Signal, error) -func LoadSignal[T any, S interface{ - *T - Signal -}](ctx *Context, data []byte) (Signal, error) { - s := S(new(T)) - err := s.Deserialize(ctx, data) - if err != nil { - return nil, err - } +type ExtensionInfo struct { + Type reflect.Type + Data interface{} +} - return s, nil +type NodeInfo struct { + Extensions []ExtType } -type SignalInfo struct { - Load SignalLoadFunc - Type SignalType +type TypeSerialize func(*Context,interface{}) ([]byte, error) +type TypeDeserialize func(*Context,[]byte) (interface{}, error) +type TypeInfo struct { + Type reflect.Type + Serialize TypeSerialize + Deserialize TypeDeserialize } -// Information about a registered extension -type ExtensionInfo struct { - // Function used to load extensions of this type from the database - Load ExtensionLoadFunc - Type ExtType - // Extra context data shared between nodes of this class - Data interface{} +type Int int +func (i Int) MarshalBinary() ([]byte, error) { + ret := make([]byte, 8) + binary.BigEndian.PutUint64(ret, uint64(i)) + return ret, nil } -// Information about a registered node type -type NodeInfo struct { - Type NodeType - // Required extensions to be a valid node of this class - Extensions []ExtType +type String string +func (str String) MarshalBinary() ([]byte, error) { + return []byte(str), nil } // A Context stores all the data to run a graphvent process @@ -128,101 +131,132 @@ type Context struct { // Logging interface Log Logger // Map between database extension hashes and the registered info - Extensions map[uint64]ExtensionInfo + Extensions map[ExtType]ExtensionInfo + ExtensionTypes map[reflect.Type]ExtType // Map between databse policy hashes and the registered info - Policies map[uint64]PolicyInfo + Policies map[PolicyType]reflect.Type + PolicyTypes map[reflect.Type]PolicyType // Map between serialized signal hashes and the registered info - Signals map[uint64]SignalInfo + Signals map[SignalType]reflect.Type + SignalTypes map[reflect.Type]SignalType // Map between database type hashes and the registered info - Types map[uint64]*NodeInfo + Nodes map[NodeType]NodeInfo + // Map between go types and registered info + Types map[SerializedType]TypeInfo + TypeReflects map[reflect.Type]SerializedType + // Routing map to all the nodes local to this context - NodesLock sync.RWMutex - Nodes map[NodeID]*Node + nodeMapLock sync.RWMutex + nodeMap map[NodeID]*Node } // Register a NodeType to the context, with the list of extensions it requires func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) error { - type_hash := Hash(node_type) - _, exists := ctx.Types[type_hash] + _, exists := ctx.Nodes[node_type] if exists == true { - return fmt.Errorf("Cannot register node type %s, type already exists in context", node_type) + return fmt.Errorf("Cannot register node type %+v, type already exists in context", node_type) } ext_found := map[ExtType]bool{} for _, extension := range(extensions) { - _, in_ctx := ctx.Extensions[Hash(extension)] + _, in_ctx := ctx.Extensions[extension] if in_ctx == false { - return fmt.Errorf("Cannot register node type %s, required extension %s not in context", node_type, extension) + return fmt.Errorf("Cannot register node type %+v, required extension %+v not in context", node_type, extension) } _, duplicate := ext_found[extension] if duplicate == true { - return fmt.Errorf("Duplicate extension %s found in extension list", extension) + return fmt.Errorf("Duplicate extension %+v found in extension list", extension) } ext_found[extension] = true } - ctx.Types[type_hash] = &NodeInfo{ - Type: node_type, + ctx.Nodes[node_type] = NodeInfo{ Extensions: extensions, } return nil } -func RegisterSignal[T any, S interface { - *T - Signal -}](ctx *Context, signal_type SignalType) error { - type_hash := Hash(signal_type) - _, exists := ctx.Signals[type_hash] +func (ctx *Context) RegisterPolicy(reflect_type reflect.Type, policy_type PolicyType) error { + _, exists := ctx.Policies[policy_type] if exists == true { - return fmt.Errorf("Cannot register signal of type %s, type already exists in context", signal_type) + return fmt.Errorf("Cannot register policy of type %+v, type already exists in context", policy_type) } - ctx.Signals[type_hash] = SignalInfo{ - Load: LoadSignal[T, S], - Type: signal_type, + ctx.Policies[policy_type] = reflect_type + ctx.PolicyTypes[reflect_type] = policy_type + return nil +} + +func (ctx *Context)RegisterSignal(reflect_type reflect.Type, signal_type SignalType) error { + _, exists := ctx.Signals[signal_type] + if exists == true { + return fmt.Errorf("Cannot register signal of type %+v, type already exists in context", signal_type) } + + ctx.Signals[signal_type] = reflect_type + ctx.SignalTypes[reflect_type] = signal_type return nil } // Add a node to a context, returns an error if the def is invalid or already exists in the context -func RegisterExtension[T any, E interface{ - *T - Extension -}](ctx *Context, data interface{}) error { - var zero E - ext_type := zero.Type() - type_hash := Hash(ext_type) - _, exists := ctx.Extensions[type_hash] +func (ctx *Context)RegisterExtension(reflect_type reflect.Type, ext_type ExtType, data interface{}) error { + _, exists := ctx.Extensions[ext_type] if exists == true { - return fmt.Errorf("Cannot register extension of type %s, type already exists in context", ext_type) + return fmt.Errorf("Cannot register extension of type %+v, type already exists in context", ext_type) } - ctx.Extensions[type_hash] = ExtensionInfo{ - Load: LoadExtension[T,E], - Type: ext_type, + ctx.Extensions[ext_type] = ExtensionInfo{ + Type: reflect_type, Data: data, } + ctx.ExtensionTypes[reflect_type] = ext_type + + return nil +} + +func (ctx *Context)RegisterType(reflect_type reflect.Type, ctx_type SerializedType, serialize TypeSerialize, deserialize TypeDeserialize) error { + _, exists := ctx.Types[ctx_type] + if exists == true { + return fmt.Errorf("Cannot register field of type %+v, type already exists in context", ctx_type) + } + _, exists = ctx.TypeReflects[reflect_type] + if exists == true { + return fmt.Errorf("Cannot register field with type %+v, type already registered in context", reflect_type) + } + if deserialize == nil { + return fmt.Errorf("Cannot register field without deserialize function") + } + if serialize == nil { + return fmt.Errorf("Cannot register field without serialize function") + } + + ctx.Types[ctx_type] = TypeInfo{ + Type: reflect_type, + Serialize: serialize, + Deserialize: deserialize, + } + ctx.TypeReflects[reflect_type] = ctx_type + return nil } func (ctx *Context) AddNode(id NodeID, node *Node) { - ctx.NodesLock.Lock() - ctx.Nodes[id] = node - ctx.NodesLock.Unlock() + ctx.nodeMapLock.Lock() + ctx.nodeMap[id] = node + ctx.nodeMapLock.Unlock() } func (ctx *Context) Node(id NodeID) (*Node, bool) { - ctx.NodesLock.RLock() - node, exists := ctx.Nodes[id] - ctx.NodesLock.RUnlock() + ctx.nodeMapLock.RLock() + node, exists := ctx.nodeMap[id] + ctx.nodeMapLock.RUnlock() return node, exists } // Get a node from the context, or load from the database if not loaded -func (ctx *Context) GetNode(id NodeID) (*Node, error) { +func (ctx *Context) getNode(id NodeID) (*Node, error) { target, exists := ctx.Node(id) if exists == false { @@ -241,7 +275,7 @@ func (ctx *Context) Send(messages Messages) error { if msg.Dest == ZeroID { panic("Can't send to null ID") } - target, err := ctx.GetNode(msg.Dest) + target, err := ctx.getNode(msg.Dest) if err == nil { select { case target.MsgChan <- msg: @@ -262,55 +296,311 @@ func (ctx *Context) Send(messages Messages) error { return nil } +type defaultKind struct { + Type SerializedType + Serialize func(interface{})([]byte, error) + Deserialize func([]byte)(interface{}, error) +} + +var defaultKinds = map[reflect.Kind]defaultKind{ + reflect.Int: { + Deserialize: func(data []byte)(interface{}, error){ + if len(data) != 8 { + return nil, fmt.Errorf("invalid length: %d/8", len(data)) + } + return int(binary.BigEndian.Uint64(data)), nil + }, + Serialize: func(val interface{})([]byte, error){ + i, ok := val.(int) + if ok == false { + return nil, fmt.Errorf("invalid type %+v", reflect.TypeOf(val)) + } else { + bytes := make([]byte, 8) + binary.BigEndian.PutUint64(bytes, uint64(i)) + return bytes, nil + } + }, + }, +} + +type SerializedValue struct { + TypeStack []uint64 + Data []byte +} + +func (field SerializedValue) MarshalBinary() ([]byte, error) { + data := []byte{} + for _, t := range(field.TypeStack) { + t_ser := make([]byte, 8) + binary.BigEndian.PutUint64(t_ser, uint64(t)) + data = append(data, t_ser...) + } + data = append(data, field.Data...) + return data, nil +} + +func RecurseTypes(ctx *Context, t reflect.Type) ([]uint64, []reflect.Kind, error) { + var ctx_type uint64 = 0x00 + ctype, exists := ctx.TypeReflects[t] + if exists == true { + ctx_type = uint64(ctype) + } + + var new_types []uint64 + var new_kinds []reflect.Kind + kind := t.Kind() + switch kind { + case reflect.Array: + if ctx_type == 0x00 { + ctx_type = uint64(ArrayType) + } + elem_types, elem_kinds, err := RecurseTypes(ctx, t.Elem()) + if err != nil { + return nil, nil, err + } + new_types = append(new_types, ctx_type) + new_types = append(new_types, elem_types...) + + new_kinds = append(new_kinds, reflect.Array) + new_kinds = append(new_kinds, elem_kinds...) + case reflect.Map: + if ctx_type == 0x00 { + ctx_type = uint64(MapType) + } + key_types, key_kinds, err := RecurseTypes(ctx, t.Key()) + if err != nil { + return nil, nil, err + } + elem_types, elem_kinds, err := RecurseTypes(ctx, t.Elem()) + if err != nil { + return nil, nil, err + } + new_types = append(new_types, ctx_type) + new_types = append(new_types, key_types...) + new_types = append(new_types, elem_types...) + + new_kinds = append(new_kinds, reflect.Map) + new_kinds = append(new_kinds, key_kinds...) + new_kinds = append(new_kinds, elem_kinds...) + case reflect.Slice: + if ctx_type == 0x00 { + ctx_type = uint64(SliceType) + } + elem_types, elem_kinds, err := RecurseTypes(ctx, t.Elem()) + if err != nil { + return nil, nil, err + } + new_types = append(new_types, ctx_type) + new_types = append(new_types, elem_types...) + + new_kinds = append(new_kinds, reflect.Slice) + new_kinds = append(new_kinds, elem_kinds...) + case reflect.Pointer: + if ctx_type == 0x00 { + ctx_type = uint64(PointerType) + } + elem_types, elem_kinds, err := RecurseTypes(ctx, t.Elem()) + if err != nil { + return nil, nil, err + } + new_types = append(new_types, ctx_type) + new_types = append(new_types, elem_types...) + + new_kinds = append(new_kinds, reflect.Pointer) + new_kinds = append(new_kinds, elem_kinds...) + case reflect.String: + if ctx_type == 0x00 { + ctx_type = uint64(StringType) + } + new_types = append(new_types, ctx_type) + new_kinds = append(new_kinds, reflect.String) + default: + return nil, nil, fmt.Errorf("unhandled kind: %+v - %+v", kind, t) + } + return new_types, new_kinds, nil +} + +func serializeValue(ctx *Context, kind_stack []reflect.Kind, value reflect.Value) ([]byte, error) { + kind := kind_stack[len(kind_stack) - 1] + switch kind { + default: + return nil, fmt.Errorf("unhandled kind: %+v", kind) + } +} + +func SerializeValue(ctx *Context, value reflect.Value) (SerializedValue, error) { + if value.IsValid() == false { + return SerializedValue{}, fmt.Errorf("Cannot serialize invalid value: %+v", value) + } + + type_stack, kind_stack, err := RecurseTypes(ctx, value.Type()) + if err != nil { + return SerializedValue{}, err + } + + bytes, err := serializeValue(ctx, kind_stack, value) + if err != nil { + return SerializedValue{}, err + } + return SerializedValue{ + type_stack, + bytes, + }, nil +} + +/* + default: + kind_def, handled := defaultKinds[kind] + if handled == false { + ctx_type, handled := ctx.TypeReflects[value.Type()] + if handled == false { + err = fmt.Errorf("%+v is not a handled reflect type", value.Type()) + break + } + type_info, handled := ctx.Types[ctx_type] + if handled == false { + err = fmt.Errorf("%+v is not a handled reflect type(INTERNAL_ERROR)", value.Type()) + break + } + field_ser, err := type_info.Serialize(ctx, value.Interface()) + if err != nil { + err = fmt.Errorf(err.Error()) + break + } + ret = SerializedValue{ + []uint64{uint64(ctx_type)}, + field_ser, + } + } + field_ser, err := kind_def.Serialize(value.Interface()) + if err != nil { + err = fmt.Errorf(err.Error()) + } else { + ret = SerializedValue{ + []uint64{uint64(kind_def.Type)}, + field_ser, + } + } +*/ + +func SerializeField(ctx *Context, ext Extension, field_name string) (SerializedValue, error) { + if ext == nil { + return SerializedValue{}, fmt.Errorf("Cannot get fields on nil Extension") + } + ext_value := reflect.ValueOf(ext).Elem() + field := ext_value.FieldByName(field_name) + if field.IsValid() == false { + return SerializedValue{}, fmt.Errorf("%s is not a field in %+v", field_name, ext) + } else { + return SerializeValue(ctx, field) + } +} + +func SerializeSignal(ctx *Context, signal Signal, ctx_type SignalType) (SerializedValue, error) { + return SerializedValue{}, nil +} + +func SerializeExtension(ctx *Context, ext Extension, ctx_type ExtType) (SerializedValue, error) { + if ext == nil { + return SerializedValue{}, fmt.Errorf("Cannot serialize nil Extension ") + } + ext_type := reflect.TypeOf(ext).Elem() + ext_value := reflect.ValueOf(ext).Elem() + + m := map[string]SerializedValue{} + for _, field := range(reflect.VisibleFields(ext_type)) { + ext_tag, tagged_ext := field.Tag.Lookup("ext") + if tagged_ext == false { + continue + } else { + field_value := ext_value.FieldByIndex(field.Index) + var err error + m[ext_tag], err = SerializeValue(ctx, field_value) + if err != nil { + return SerializedValue{}, err + } + } + } + map_value := reflect.ValueOf(m) + map_ser, err := SerializeValue(ctx, map_value) + if err != nil { + return SerializedValue{}, err + } + return SerializedValue{ + append([]uint64{uint64(ctx_type)}, map_ser.TypeStack...), + map_ser.Data, + }, nil +} + +func DeserializeValue(ctx *Context, value SerializedValue) (interface{}, error) { + // TODO: do the opposite of SerializeValue. + // 1) Check the type to handle special types(array, list, map, pointer) + // 2) Check if the type is registered in the context, handle if so + // 3) Check if the type is a default type, handle if so + // 4) Return error if we don't know how to deserialize the type + return nil, fmt.Errorf("Undefined") +} + // Create a new Context with the base library content added func NewContext(db * badger.DB, log Logger) (*Context, error) { ctx := &Context{ DB: db, Log: log, - Extensions: map[uint64]ExtensionInfo{}, - Types: map[uint64]*NodeInfo{}, - Signals: map[uint64]SignalInfo{}, - Nodes: map[NodeID]*Node{}, + Policies: map[PolicyType]reflect.Type{}, + PolicyTypes: map[reflect.Type]PolicyType{}, + Extensions: map[ExtType]ExtensionInfo{}, + ExtensionTypes: map[reflect.Type]ExtType{}, + Signals: map[SignalType]reflect.Type{}, + SignalTypes: map[reflect.Type]SignalType{}, + Nodes: map[NodeType]NodeInfo{}, + nodeMap: map[NodeID]*Node{}, + Types: map[SerializedType]TypeInfo{}, + TypeReflects: map[reflect.Type]SerializedType{}, } var err error - err = RegisterExtension[LockableExt,*LockableExt](ctx, nil) + err = ctx.RegisterExtension(reflect.TypeOf((*LockableExt)(nil)), LockableExtType, nil) if err != nil { return nil, err } - err = RegisterExtension[ListenerExt,*ListenerExt](ctx, nil) + err = ctx.RegisterExtension(reflect.TypeOf((*ListenerExt)(nil)), ListenerExtType, nil) if err != nil { return nil, err } - err = RegisterExtension[ECDHExt,*ECDHExt](ctx, nil) + err = ctx.RegisterExtension(reflect.TypeOf((*GroupExt)(nil)), GroupExtType, nil) if err != nil { return nil, err } - err = RegisterExtension[GroupExt,*GroupExt](ctx, nil) + gql_ctx := NewGQLExtContext() + err = ctx.RegisterExtension(reflect.TypeOf((*GQLExt)(nil)), GQLExtType, gql_ctx) if err != nil { return nil, err } - gql_ctx := NewGQLExtContext() - err = RegisterExtension[GQLExt,*GQLExt](ctx, gql_ctx) + err = ctx.RegisterSignal(reflect.TypeOf((*StopSignal)(nil)), StopSignalType) + if err != nil { + return nil, err + } + + err = ctx.RegisterSignal(reflect.TypeOf((*CreateSignal)(nil)), CreateSignalType) if err != nil { return nil, err } - err = RegisterSignal[BaseSignal, *BaseSignal](ctx, StopSignalType) + err = ctx.RegisterSignal(reflect.TypeOf((*StartSignal)(nil)), StartSignalType) if err != nil { return nil, err } - err = RegisterSignal[BaseSignal, *BaseSignal](ctx, NewSignalType) + err = ctx.RegisterSignal(reflect.TypeOf((*ReadSignal)(nil)), ReadSignalType) if err != nil { return nil, err } - err = RegisterSignal[BaseSignal, *BaseSignal](ctx, StartSignalType) + err = ctx.RegisterSignal(reflect.TypeOf((*ReadResultSignal)(nil)), ReadResultSignalType) if err != nil { return nil, err } diff --git a/ecdh.go b/ecdh.go deleted file mode 100644 index 028c99b..0000000 --- a/ecdh.go +++ /dev/null @@ -1,167 +0,0 @@ -package graphvent - -import ( - "fmt" - "time" - "encoding/json" - "crypto/ecdsa" - "crypto/x509" - "crypto/ecdh" -) - -type ECDHState struct { - ECKey *ecdh.PrivateKey - SharedSecret []byte -} - -type ECDHStateJSON struct { - ECKey []byte `json:"ec_key"` - SharedSecret []byte `json:"shared_secret"` -} - -func (state *ECDHState) MarshalJSON() ([]byte, error) { - var key_bytes []byte - var err error - if state.ECKey != nil { - key_bytes, err = x509.MarshalPKCS8PrivateKey(state.ECKey) - if err != nil { - return nil, err - } - } - - return json.Marshal(&ECDHStateJSON{ - ECKey: key_bytes, - SharedSecret: state.SharedSecret, - }) -} - -func (state *ECDHState) UnmarshalJSON(data []byte) error { - var j ECDHStateJSON - err := json.Unmarshal(data, &j) - if err != nil { - return err - } - - state.SharedSecret = j.SharedSecret - if len(j.ECKey) == 0 { - state.ECKey = nil - } else { - tmp_key, err := x509.ParsePKCS8PrivateKey(j.ECKey) - if err != nil { - return err - } - - ecdsa_key, ok := tmp_key.(*ecdsa.PrivateKey) - if ok == false { - return fmt.Errorf("Parsed wrong key type from DB for ECDHState") - } - - state.ECKey, err = ecdsa_key.ECDH() - if err != nil { - return err - } - } - - return nil -} - -type ECDHMap map[NodeID]ECDHState - -func (m ECDHMap) MarshalJSON() ([]byte, error) { - tmp := map[string]ECDHState{} - for id, state := range(m) { - tmp[id.String()] = state - } - - return json.Marshal(tmp) -} - -type ECDHExt struct { - ECDHStates ECDHMap -} - -func NewECDHExt() *ECDHExt { - return &ECDHExt{ - ECDHStates: ECDHMap{}, - } -} - -func ResolveFields[T Extension](t T, name string, field_funcs map[string]func(T)interface{})interface{} { - var zero T - field_func, ok := field_funcs[name] - if ok == false { - return fmt.Errorf("%s is not a field of %s", name, zero.Type()) - } - return field_func(t) -} - -func (ext *ECDHExt) Field(name string) interface{} { - return ResolveFields(ext, name, map[string]func(*ECDHExt)interface{}{ - "ecdh_states": func(ext *ECDHExt) interface{} { - return ext.ECDHStates - }, - }) -} - -func (ext *ECDHExt) HandleECDHSignal(log Logger, node *Node, signal *ECDHSignal) Messages { - source := KeyID(signal.EDDSA) - - messages := Messages{} - switch signal.Str { - case "req": - state, exists := ext.ECDHStates[source] - if exists == false { - state = ECDHState{nil, nil} - } - resp, shared_secret, err := NewECDHRespSignal(node, signal) - if err == nil { - state.SharedSecret = shared_secret - ext.ECDHStates[source] = state - log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) - messages = messages.Add(node.ID, node.Key, &resp, source) - } else { - log.Logf("ecdh", "ECDH_REQ_ERR: %s", err) - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), err.Error()), source) - } - case "resp": - state, exists := ext.ECDHStates[source] - if exists == false || state.ECKey == nil { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "no_req"), source) - } else { - err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW) - if err == nil { - shared_secret, err := state.ECKey.ECDH(signal.ECDH) - if err == nil { - state.SharedSecret = shared_secret - state.ECKey = nil - ext.ECDHStates[source] = state - log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) - } - } - } - default: - log.Logf("ecdh", "unknown echd state %s", signal.Str) - } - return messages -} - -func (ext *ECDHExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { - switch signal.Type() { - case ECDHSignalType: - sig := signal.(*ECDHSignal) - return ext.HandleECDHSignal(ctx.Log, node, sig) - } - return nil -} - -func (ext *ECDHExt) Type() ExtType { - return ECDHExtType -} - -func (ext *ECDHExt) Serialize() ([]byte, error) { - return json.Marshal(ext) -} - -func (ext *ECDHExt) Deserialize(ctx *Context, data []byte) error { - return json.Unmarshal(data, &ext) -} diff --git a/go.mod b/go.mod index ab994a3..05f8319 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/mekkanized/graphvent -go 1.20 +go 1.21.0 + +replace github.com/mekkanized/graphvent/signal v0.0.0 => ./signal require ( github.com/dgraph-io/badger/v3 v3.2103.5 @@ -11,6 +13,7 @@ require ( ) require ( + capnproto.org/go/capnp/v3 v3.0.0-alpha-29 // indirect github.com/cespare/xxhash v1.1.0 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/dgraph-io/badger/v4 v4.1.0 // indirect @@ -28,8 +31,11 @@ require ( github.com/klauspost/compress v1.12.3 // indirect github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect + github.com/mekkanized/graphvent/signal v0.0.0 // indirect github.com/pkg/errors v0.9.1 // indirect go.opencensus.io v0.22.5 // indirect golang.org/x/net v0.7.0 // indirect + golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 // indirect golang.org/x/sys v0.6.0 // indirect + zenhack.net/go/util v0.0.0-20230414204917-531d38494cf5 // indirect ) diff --git a/go.sum b/go.sum index 07f7de7..2122dfe 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +capnproto.org/go/capnp/v3 v3.0.0-alpha-29 h1:ICLhiy4Jmp0d7hLQO+HzFAVIft/oxpPAUPV8tqx+eUE= +capnproto.org/go/capnp/v3 v3.0.0-alpha-29/go.mod h1:+ysMHvOh1EWNOyorxJWs1omhRFiDoKxKkWQACp54jKM= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= @@ -134,6 +136,7 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -171,3 +174,5 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +zenhack.net/go/util v0.0.0-20230414204917-531d38494cf5 h1:yksDCGMVzyn3vlyf0GZ3huiF5FFaMGQpQ3UJvR0EoGA= +zenhack.net/go/util v0.0.0-20230414204917-531d38494cf5/go.mod h1:1LtNdPAs8WH+BTcQiZAOo2MIKD/5jyK/u7sZ9ZPe5SE= diff --git a/gql.go b/gql.go index c6ffc3f..598e5fc 100644 --- a/gql.go +++ b/gql.go @@ -205,7 +205,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ if err != nil { return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from id_bytes %+v", id_bytes) } - auth_id := NodeID{auth_uuid} + auth_id := NodeID(auth_uuid) key_bytes, err := base64.StdEncoding.DecodeString(key_b64) if err != nil { @@ -234,7 +234,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ Ext: gql_ext, Chans: map[uuid.UUID]chan Signal{}, Context: ctx, - GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext), + GQLContext: ctx.Extensions[GQLExtType].Data.(*GQLExtContext), Server: server, User: key_id, Key: key, @@ -270,7 +270,7 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW query := GQLPayload{} json.Unmarshal(str, &query) - gql_context := ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext) + gql_context := ctx.Extensions[GQLExtType].Data.(*GQLExtContext) params := graphql.Params{ Schema: gql_context.Schema, @@ -401,7 +401,7 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon } } else if msg.Type == "subscribe" { ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload) - gql_context := ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext) + gql_context := ctx.Extensions[GQLExtType].Data.(*GQLExtContext) params := graphql.Params{ Schema: gql_context.Schema, Context: req_ctx, @@ -543,7 +543,7 @@ func BuildSchema(ctx *GQLExtContext) (graphql.Schema, error) { return graphql.NewSchema(schemaConfig) } -func RegisterField[T any](ctx *GQLExtContext, gql_type graphql.Type, gql_name string, ext_type ExtType, acl_name string, resolve_fn func(graphql.ResolveParams, T)(interface{}, error)) error { +func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string, ext_type ExtType, acl_name string, resolve_fn func(graphql.ResolveParams, SerializedValue)(interface{}, error)) error { if ctx == nil { return fmt.Errorf("ctx is nil") } @@ -561,21 +561,19 @@ func RegisterField[T any](ctx *GQLExtContext, gql_type graphql.Type, gql_name st return ResolveNodeResult(p, func(p graphql.ResolveParams, result NodeResult) (interface{}, error) { ext, exists := result.Result.Extensions[ext_type] if exists == false { - return nil, fmt.Errorf("%s is not in the extensions of the result", ext_type) + return nil, fmt.Errorf("%+v is not in the extensions of the result", ext_type) } - val_if, exists := ext[acl_name] + val_ser, exists := ext[acl_name] if exists == false { - return nil, fmt.Errorf("%s is not in the fields of %s in the result", acl_name, ext_type) + return nil, fmt.Errorf("%s is not in the fields of %+v in the result", acl_name, ext_type) } - var zero T - val, ok := val_if.(T) - if ok == false { - return nil, fmt.Errorf("%s.%s is not %s", ext_type, acl_name, reflect.TypeOf(zero)) + if val_ser.TypeStack[0] == uint64(ErrorType) { + return nil, fmt.Errorf(string(val_ser.Data)) } - return resolve_fn(p, val) + return resolve_fn(p, val_ser) }) } @@ -681,8 +679,8 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in for field_name, field := range(self_fields) { self_field := field - err := RegisterField(ctx, ctx_interface.Interface, field_name, self_field.Extension, self_field.ACLName, - func(p graphql.ResolveParams, val interface{})(interface{}, error) { + err := ctx.RegisterField(ctx_interface.Interface, field_name, self_field.Extension, self_field.ACLName, + func(p graphql.ResolveParams, val SerializedValue)(interface{}, error) { ctx, err := PrepResolve(p) if err != nil { return nil, err @@ -715,7 +713,7 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in for field_name, field := range(list_fields) { list_field := field - resolve_fn := func(p graphql.ResolveParams, val interface{})(interface{}, error) { + resolve_fn := func(p graphql.ResolveParams, val SerializedValue)(interface{}, error) { ctx, err := PrepResolve(p) if err != nil { return nil, err @@ -736,7 +734,7 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in return nodes, nil } - err := RegisterField(ctx, ctx_interface.List, field_name, list_field.Extension, list_field.ACLName, resolve_fn) + err := ctx.RegisterField(ctx_interface.List, field_name, list_field.Extension, list_field.ACLName, resolve_fn) if err != nil { return err } @@ -764,7 +762,7 @@ func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, name string, inte _, exists := ctx.NodeTypes[node_type] if exists == true { - return fmt.Errorf("%s already in GQLExtContext.NodeTypes", node_type) + return fmt.Errorf("%+v already in GQLExtContext.NodeTypes", node_type) } node_interfaces, err := GQLInterfaces(ctx, interface_names) @@ -830,19 +828,20 @@ func NewGQLExtContext() *GQLExtContext { panic(err) } - err = RegisterField(&context, context.Interfaces["Node"].List, "Members", GroupExtType, "members", - func(p graphql.ResolveParams, val map[NodeID]string)(interface{}, error) { + err = context.RegisterField(context.Interfaces["Node"].List, "Members", GroupExtType, "members", + func(p graphql.ResolveParams, val SerializedValue)(interface{}, error) { ctx, err := PrepResolve(p) if err != nil { return nil, err } - node_list := make([]NodeID, len(val)) + /*node_list := make([]NodeID, len(val)) i := 0 for id, _ := range(val) { node_list[i] = id i += 1 } - + */ + node_list := []NodeID{} nodes, err := ResolveNodes(ctx, p, node_list) if err != nil { return nil, err @@ -895,7 +894,7 @@ func NewGQLExtContext() *GQLExtContext { panic(err) } - err = RegisterField(&context, graphql.String, "Listen", GQLExtType, "listen", func(p graphql.ResolveParams, listen string) (interface{}, error) { + err = context.RegisterField(graphql.String, "Listen", GQLExtType, "listen", func(p graphql.ResolveParams, listen SerializedValue) (interface{}, error) { return listen, nil }) if err != nil { @@ -1000,14 +999,6 @@ type GQLExt struct { Listen string `json:"listen"` } -func (ext *GQLExt) Field(name string) interface{} { - return ResolveFields(ext, name, map[string]func(*GQLExt)interface{}{ - "listen": func(ext *GQLExt) interface{} { - return ext.Listen - }, - }) -} - func (ext *GQLExt) FindResponseChannel(req_id uuid.UUID) chan Signal { ext.resolver_response_lock.RLock() response_chan, _ := ext.resolver_response[req_id] @@ -1036,11 +1027,10 @@ func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal { func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { // Process ReadResultSignalType by forwarding it to the waiting resolver - messages := Messages{} - if signal.Type() == ErrorSignalType { + switch sig := signal.(type) { + case *ErrorSignal: // TODO: Forward to resolver if waiting for it - sig := signal.(*ErrorSignal) - response_chan := ext.FreeResponseChannel(sig.ReqID()) + response_chan := ext.FreeResponseChannel(sig.Header().ReqID) if response_chan != nil { select { case response_chan <- sig: @@ -1052,9 +1042,8 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa } else { ctx.Log.Logf("gql", "received error signal response %+v with no mapped resolver", sig) } - } else if signal.Type() == ReadResultSignalType { - sig := signal.(*ReadResultSignal) - response_chan := ext.FindResponseChannel(sig.ReqID()) + case *ReadResultSignal: + response_chan := ext.FindResponseChannel(sig.ReqID) if response_chan != nil { select { case response_chan <- sig: @@ -1065,23 +1054,23 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa } else { ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig) } - } else if signal.Type() == StartSignalType { + case *StartSignal: ctx.Log.Logf("gql", "starting gql server %s", node.ID) err := ext.StartGQLServer(ctx, node) if err == nil { - node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID)) + node.QueueSignal(time.Now(), NewStatusSignal(node.ID, "server_started")) } else { ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err) } } - return messages + return nil } func (ext *GQLExt) Type() ExtType { return GQLExtType } -func (ext *GQLExt) Serialize() ([]byte, error) { +func (ext *GQLExt) MarshalBinary() ([]byte, error) { return json.Marshal(ext) } diff --git a/gql_interfaces.go b/gql_interfaces.go index abae636..d213d2f 100644 --- a/gql_interfaces.go +++ b/gql_interfaces.go @@ -36,7 +36,7 @@ func NodeInterfaceDefaultIsType(required_extensions []ExtType) func(graphql.IsTy return false } - node_type_def, exists := ctx.Context.Types[Hash(node.Result.NodeType)] + node_type_def, exists := ctx.Context.Nodes[node.Result.NodeType] if exists == false { return false } else { @@ -73,7 +73,7 @@ func NodeInterfaceResolveType(required_extensions []ExtType, default_type **grap gql_type, exists := ctx.GQLContext.NodeTypes[node.Result.NodeType] ctx.Context.Log.Logf("gql", "GQL_INTERFACE_RESOLVE_TYPE(%+v): %+v - %t - %+v - %+v", node, gql_type, exists, required_extensions, *default_type) if exists == false { - node_type_def, exists := ctx.Context.Types[Hash(node.Result.NodeType)] + node_type_def, exists := ctx.Context.Nodes[node.Result.NodeType] if exists == false { return nil } else { diff --git a/gql_query.go b/gql_query.go index 20070e3..bc936b0 100644 --- a/gql_query.go +++ b/gql_query.go @@ -51,16 +51,16 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error read_signal := NewReadSignal(ext_fields) msgs := Messages{} - msgs = msgs.Add(ctx.Server.ID, ctx.Key, read_signal, id) + msgs = msgs.Add(ctx.Context, ctx.Server.ID, ctx.Key, read_signal, id) - response_chan := ctx.Ext.GetResponseChannel(read_signal.ID()) - resp_channels[read_signal.ID()] = response_chan - node_ids[read_signal.ID()] = id + response_chan := ctx.Ext.GetResponseChannel(read_signal.ID) + resp_channels[read_signal.ID] = response_chan + node_ids[read_signal.ID] = id // TODO: Send all at once instead of createing Messages for each err = ctx.Context.Send(msgs) if err != nil { - ctx.Ext.FreeResponseChannel(read_signal.ID()) + ctx.Ext.FreeResponseChannel(read_signal.ID) return nil, err } } @@ -68,8 +68,8 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ responses := []NodeResult{} for sig_id, response_chan := range(resp_channels) { // Wait for the response, returning an error on timeout - response, err := WaitForSignal(response_chan, time.Millisecond*100, ReadResultSignalType, func(sig *ReadResultSignal)bool{ - return sig.ReqID() == sig_id + response, err := WaitForSignal(response_chan, time.Millisecond*100, func(sig *ReadResultSignal)bool{ + return sig.ReqID == sig_id }) if err != nil { return nil, err diff --git a/gql_resolvers.go b/gql_resolvers.go index 658af9e..ccb2d8f 100644 --- a/gql_resolvers.go +++ b/gql_resolvers.go @@ -81,6 +81,6 @@ func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) { func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) { return ResolveNodeResult(p, func(p graphql.ResolveParams, node NodeResult) (interface{}, error) { - return Hash(node.Result.NodeType), nil + return uint64(node.Result.NodeType), nil }) } diff --git a/gql_test.go b/gql_test.go index 7bac95e..f157db6 100644 --- a/gql_test.go +++ b/gql_test.go @@ -21,7 +21,7 @@ import ( func TestGQLServer(t *testing.T) { ctx := logTestContext(t, []string{"test", "gql", "policy", "pending"}) - TestNodeType := NodeType("TEST") + TestNodeType := NewNodeType("TEST") err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType}) fatalErr(t, err) @@ -30,33 +30,33 @@ func TestGQLServer(t *testing.T) { gql_id := KeyID(pub) group_policy_1 := NewAllNodesPolicy(Tree{ - ReadSignalType.String(): Tree{ - GroupExtType.String(): Tree{ - "members": Tree{}, + uint64(ReadSignalType): Tree{ + uint64(GroupExtType): Tree{ + Hash(FieldNameBase, "members"): Tree{}, }, }, - ReadResultSignalType.String(): nil, - ErrorSignalType.String(): nil, + uint64(ReadResultSignalType): nil, + uint64(ErrorSignalType): nil, }) group_policy_2 := NewMemberOfPolicy(map[NodeID]Tree{ gql_id: Tree{ - LinkSignalType.String(): nil, - LockSignalType.String(): nil, - StatusSignalType.String(): nil, - ReadSignalType.String(): nil, + uint64(LinkSignalType): nil, + uint64(LockSignalType): nil, + uint64(StatusSignalType): nil, + uint64(ReadSignalType): nil, }, }) user_policy_1 := NewAllNodesPolicy(Tree{ - ReadResultSignalType.String(): nil, - ErrorSignalType.String(): nil, + uint64(ReadResultSignalType): nil, + uint64(ErrorSignalType): nil, }) user_policy_2 := NewMemberOfPolicy(map[NodeID]Tree{ gql_id: Tree{ - LinkSignalType.String(): nil, - ReadSignalType.String(): nil, + uint64(LinkSignalType): nil, + uint64(ReadSignalType): nil, }, }) @@ -80,8 +80,8 @@ func TestGQLServer(t *testing.T) { ctx.Log.Logf("test", "GQL: %s", gql.ID) ctx.Log.Logf("test", "NODE: %s", n1.ID) - _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { - return sig.Str == "server_started" + _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StatusSignal) bool { + return sig.Status == "server_started" }) fatalErr(t, err) @@ -107,7 +107,9 @@ func TestGQLServer(t *testing.T) { }, } - auth_username := base64.StdEncoding.EncodeToString(n1.ID.Serialize()) + n1_id_bytes, err := n1.ID.MarshalBinary() + fatalErr(t, err) + auth_username := base64.StdEncoding.EncodeToString(n1_id_bytes) key_bytes, err := x509.MarshalPKCS8PrivateKey(n1.Key) fatalErr(t, err) auth_password := base64.StdEncoding.EncodeToString(key_bytes) @@ -196,11 +198,11 @@ func TestGQLServer(t *testing.T) { SubGQL(sub_1) msgs := Messages{} - msgs = msgs.Add(gql.ID, gql.Key, &StopSignal, gql.ID) + msgs = msgs.Add(ctx, gql.ID, gql.Key, NewStopSignal(), gql.ID) err = ctx.Send(msgs) fatalErr(t, err) - _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { - return sig.Str == "stopped" + _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StatusSignal) bool { + return sig.Status == "stopped" }) fatalErr(t, err) } @@ -208,7 +210,7 @@ func TestGQLServer(t *testing.T) { func TestGQLDB(t *testing.T) { ctx := logTestContext(t, []string{"test"}) - TestUserNodeType := NodeType("TEST_USER") + TestUserNodeType := NewNodeType("TEST_USER") err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) fatalErr(t, err) u1 := NewNode(ctx, nil, TestUserNodeType, 10, nil) @@ -225,33 +227,31 @@ func TestGQLDB(t *testing.T) { ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) msgs := Messages{} - msgs = msgs.Add(gql.ID, gql.Key, &StopSignal, gql.ID) + msgs = msgs.Add(ctx, gql.ID, gql.Key, NewStopSignal(), gql.ID) err = ctx.Send(msgs) fatalErr(t, err) - _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { - return sig.Str == "stopped" && sig.NodeID == gql.ID + _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StatusSignal) bool { + return sig.Status == "stopped" && sig.Source == gql.ID }) fatalErr(t, err) - ser1, err := gql.Serialize() - ser2, err := u1.Serialize() - ser3, err := StopSignal.Serialize() + ser1, err := gql.Serialize(ctx) + ser2, err := u1.Serialize(ctx) ctx.Log.Logf("test", "SER_1: \n%s\n\n", ser1) ctx.Log.Logf("test", "SER_2: \n%s\n\n", ser2) - ctx.Log.Logf("test", "SER_3: \n%s\n\n", ser3) // Clear all loaded nodes from the context so it loads them from the database - ctx.Nodes = map[NodeID]*Node{} + ctx.nodeMap = map[NodeID]*Node{} gql_loaded, err := LoadNode(ctx, gql.ID) fatalErr(t, err) - listener_ext, err = GetExt[*ListenerExt](gql_loaded) + listener_ext, err = GetExt[*ListenerExt](gql_loaded, GQLExtType) fatalErr(t, err) msgs = Messages{} - msgs = msgs.Add(gql_loaded.ID, gql_loaded.Key, &StopSignal, gql_loaded.ID) + msgs = msgs.Add(ctx, gql_loaded.ID, gql_loaded.Key, NewStopSignal(), gql_loaded.ID) err = ctx.Send(msgs) fatalErr(t, err) - _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { - return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID + _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StatusSignal) bool { + return sig.Status == "stopped" && sig.Source == gql_loaded.ID }) fatalErr(t, err) } diff --git a/gql_types.go b/gql_types.go index f63678a..38ae233 100644 --- a/gql_types.go +++ b/gql_types.go @@ -16,10 +16,10 @@ func AddNodeFields(object *graphql.Object) { }) } -func NewGQLNodeType(node_type NodeType, interfaces []*graphql.Interface, init func(*Type)) *Type { +func NewGQLNodeType(gql_name string, node_type NodeType, interfaces []*graphql.Interface, init func(*Type)) *Type { var gql Type gql.Type = graphql.NewObject(graphql.ObjectConfig{ - Name: string(node_type), + Name: gql_name, Interfaces: interfaces, IsTypeOf: func(p graphql.IsTypeOfParams) bool { node, ok := p.Value.(NodeResult) diff --git a/graph_test.go b/graph_test.go index f90a7bc..7481fd4 100644 --- a/graph_test.go +++ b/graph_test.go @@ -6,7 +6,7 @@ import ( badger "github.com/dgraph-io/badger/v3" ) -const SimpleListenerNodeType = NodeType("SIMPLE_LISTENER") +var SimpleListenerNodeType = NewNodeType("SIMPLE_LISTENER") func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) { listener_extension := NewListenerExt(buffer) diff --git a/group.go b/group.go index e784fbe..792c3b0 100644 --- a/group.go +++ b/group.go @@ -12,18 +12,10 @@ func (ext *GroupExt) Type() ExtType { return GroupExtType } -func (ext *GroupExt) Serialize() ([]byte, error) { +func (ext *GroupExt) MarshalBinary() ([]byte, error) { return json.Marshal(ext) } -func (ext *GroupExt) Field(name string) interface{} { - return ResolveFields(ext, name, map[string]func(*GroupExt)interface{}{ - "members": func(ext *GroupExt) interface{} { - return ext.Members - }, - }) -} - func NewGroupExt(members map[NodeID]string) *GroupExt { if members == nil { members = map[NodeID]string{} diff --git a/listener.go b/listener.go index 10d026c..603585c 100644 --- a/listener.go +++ b/listener.go @@ -18,19 +18,8 @@ func NewListenerExt(buffer int) *ListenerExt { } } -func (ext *ListenerExt) Field(name string) interface{} { - return ResolveFields(ext, name, map[string]func(*ListenerExt)interface{}{ - "buffer": func(ext *ListenerExt) interface{} { - return ext.Buffer - }, - "chan": func(ext *ListenerExt) interface{} { - return ext.Chan - }, - }) -} - // Simple load function, unmarshal the buffer int from json -func (ext *ListenerExt) Deserialize(ctx *Context, data []byte) error { +func (ext *ListenerExt) DeserializeListenerExt(ctx *Context, data []byte) error { err := json.Unmarshal(data, &ext.Buffer) ext.Chan = make(chan Signal, ext.Buffer) return err @@ -51,6 +40,6 @@ func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal return nil } -func (ext *ListenerExt) Serialize() ([]byte, error) { +func (ext *ListenerExt) MarshalBinary() ([]byte, error) { return json.Marshal(ext.Buffer) } diff --git a/lockable.go b/lockable.go index c3ff000..ccacfa8 100644 --- a/lockable.go +++ b/lockable.go @@ -1,7 +1,6 @@ package graphvent import ( - "encoding/binary" "github.com/google/uuid" ) @@ -15,119 +14,17 @@ const ( ) type LockableExt struct{ - State ReqState - ReqID uuid.UUID - Owner *NodeID - PendingOwner *NodeID - Requirements map[NodeID]ReqState -} - -func (ext *LockableExt) Field(name string) interface{} { - return ResolveFields(ext, name, map[string]func(*LockableExt)interface{}{ - "owner": func(ext *LockableExt) interface{} { - return ext.Owner - }, - "pending_owner": func(ext *LockableExt) interface{} { - return ext.PendingOwner - }, - "requirements": func(ext *LockableExt) interface{} { - return ext.Requirements - }, - }) + State ReqState `ext:""` + ReqID *uuid.UUID `ext:""` + Owner *NodeID `ext:""` + PendingOwner *NodeID `ext:""` + Requirements map[NodeID]ReqState `ext:""` } func (ext *LockableExt) Type() ExtType { return LockableExtType } -func (ext *LockableExt) Serialize() ([]byte, error) { - ret := make([]byte, 9 + (16 * 2) + (17 * len(ext.Requirements))) - if ext.Owner != nil { - bytes, err := ext.Owner.MarshalBinary() - if err != nil { - return nil, err - } - copy(ret[0:16], bytes) - } - - if ext.PendingOwner != nil { - bytes, err := ext.PendingOwner.MarshalBinary() - if err != nil { - return nil, err - } - copy(ret[16:32], bytes) - } - - binary.BigEndian.PutUint64(ret[32:40], uint64(len(ext.Requirements))) - ret[40] = byte(ext.State) - cur := 41 - for req, state := range(ext.Requirements) { - bytes, err := req.MarshalBinary() - if err != nil { - return nil, err - } - copy(ret[cur:cur+16], bytes) - ret[cur+16] = byte(state) - cur += 17 - } - - return ret, nil -} - -func (ext *LockableExt) Deserialize(ctx *Context, data []byte) error { - cur := 0 - all_zero := true - for _, b := range(data[cur:cur+16]) { - if all_zero == true && b != 0x00 { - all_zero = false - } - } - if all_zero == false { - tmp, err := IDFromBytes(data[cur:cur+16]) - if err != nil { - return err - } - ext.Owner = &tmp - } - cur += 16 - - all_zero = true - for _, b := range(data[cur:cur+16]) { - if all_zero == true && b != 0x00 { - all_zero = false - } - } - if all_zero == false { - tmp, err := IDFromBytes(data[cur:cur+16]) - if err != nil { - return err - } - ext.PendingOwner = &tmp - } - cur += 16 - - num_requirements := int(binary.BigEndian.Uint64(data[cur:cur+8])) - cur += 8 - - ext.State = ReqState(data[cur]) - cur += 1 - - if num_requirements != 0 { - ext.Requirements = map[NodeID]ReqState{} - } - for i := 0; i < num_requirements; i++ { - id, err := IDFromBytes(data[cur:cur+16]) - if err != nil { - return err - } - cur += 16 - state := ReqState(data[cur]) - cur += 1 - ext.Requirements[id] = state - } - return nil -} - func NewLockableExt(requirements []NodeID) *LockableExt { var reqs map[NodeID]ReqState = nil if requirements != nil { @@ -148,21 +45,21 @@ func NewLockableExt(requirements []NodeID) *LockableExt { func UnlockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) { msgs := Messages{} signal := NewLockSignal("unlock") - msgs = msgs.Add(owner.ID, owner.Key, signal, target) - return signal.ID(), ctx.Send(msgs) + msgs = msgs.Add(ctx, owner.ID, owner.Key, signal, target) + return signal.Header().ID, ctx.Send(msgs) } // Send the signal to lock a node from itself func LockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) { msgs := Messages{} signal := NewLockSignal("lock") - msgs = msgs.Add(owner.ID, owner.Key, signal, target) - return signal.ID(), ctx.Send(msgs) + msgs = msgs.Add(ctx, owner.ID, owner.Key, signal, target) + return signal.Header().ID, ctx.Send(msgs) } -func (ext *LockableExt) HandleErrorSignal(log Logger, node *Node, source NodeID, signal *ErrorSignal) Messages { +func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) Messages { str := signal.Error - log.Logf("lockable", "ERROR_SIGNAL: %s->%s %+v", source, node.ID, str) + ctx.Log.Logf("lockable", "ERROR_SIGNAL: %s->%s %+v", source, node.ID, str) msgs := Messages {} switch str { @@ -173,7 +70,7 @@ func (ext *LockableExt) HandleErrorSignal(log Logger, node *Node, source NodeID, for id, state := range(ext.Requirements) { if state == Locked { ext.Requirements[id] = Unlocking - msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), id) + msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlock"), id) } } } @@ -185,51 +82,48 @@ func (ext *LockableExt) HandleErrorSignal(log Logger, node *Node, source NodeID, return msgs } -func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID, signal *IDStringSignal) Messages { - id := signal.NodeID - action := signal.Str +func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID, signal *LinkSignal) Messages { msgs := Messages {} if ext.State == Unlocked { - switch action { + switch signal.Action { case "add": - _, exists := ext.Requirements[id] + _, exists := ext.Requirements[signal.NodeID] if exists == true { - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_requirement"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "already_requirement"), source) } else { if ext.Requirements == nil { ext.Requirements = map[NodeID]ReqState{} } - ext.Requirements[id] = Unlocked - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "req_added"), source) + ext.Requirements[signal.NodeID] = Unlocked + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "req_added"), source) } case "remove": - _, exists := ext.Requirements[id] + _, exists := ext.Requirements[signal.NodeID] if exists == false { - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_requirement"), source) } else { - delete(ext.Requirements, id) - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "req_removed"), source) + delete(ext.Requirements, signal.NodeID) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "req_removed"), source) } default: - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "unknown_action"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "unknown_action"), source) } } else { - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocked"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_unlocked"), source) } return msgs } // Handle a LockSignal and update the extensions owner/requirement states -func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) Messages { - state := signal.Str - log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, state) +func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) Messages { + ctx.Log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal.State) msgs := Messages{} - switch state { + switch signal.State { case "locked": state, found := ext.Requirements[source] if found == false { - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_requirement"), source) } else if state == Locking { if ext.State == Locking { ext.Requirements[source] = Locked @@ -245,19 +139,19 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, if locked == reqs { ext.State = Locked ext.Owner = ext.PendingOwner - msgs = msgs.Add(node.ID, node.Key, NewLockSignal("locked"), *ext.Owner) + msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("locked"), *ext.Owner) } else { - log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, reqs) + ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, reqs) } } else if ext.State == AbortingLock { ext.Requirements[source] = Unlocking - msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlock"), source) } } case "unlocked": state, found := ext.Requirements[source] if found == false { - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_requirement"), source) } else if state == Unlocking { ext.Requirements[source] = Unlocked reqs := 0 @@ -274,13 +168,14 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, ext.State = Unlocked if old_state == Unlocking { ext.Owner = ext.PendingOwner - msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlocked"), *ext.Owner) + ext.ReqID = nil + msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlocked"), *ext.Owner) } else if old_state == AbortingLock { - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(ext.ReqID, "not_unlocked"), *ext.PendingOwner) + msgs = msgs.Add(ctx ,node.ID, node.Key, NewErrorSignal(*ext.ReqID, "not_unlocked"), *ext.PendingOwner) ext.PendingOwner = ext.Owner } } else { - log.Logf("lockable", "PARTIAL UNLOCK: %s - %d/%d", node.ID, unlocked, reqs) + ctx.Log.Logf("lockable", "PARTIAL UNLOCK: %s - %d/%d", node.ID, unlocked, reqs) } } case "lock": @@ -290,23 +185,24 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, new_owner := source ext.PendingOwner = &new_owner ext.Owner = &new_owner - msgs = msgs.Add(node.ID, node.Key, NewLockSignal("locked"), new_owner) + msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("locked"), new_owner) } else { ext.State = Locking - ext.ReqID = signal.ID() + id := signal.ID + ext.ReqID = &id new_owner := source ext.PendingOwner = &new_owner for id, state := range(ext.Requirements) { if state != Unlocked { - log.Logf("lockable", "REQ_NOT_UNLOCKED_WHEN_LOCKING") + ctx.Log.Logf("lockable", "REQ_NOT_UNLOCKED_WHEN_LOCKING") } ext.Requirements[id] = Locking lock_signal := NewLockSignal("lock") - msgs = msgs.Add(node.ID, node.Key, lock_signal, id) + msgs = msgs.Add(ctx, node.ID, node.Key, lock_signal, id) } } } else { - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocked"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_unlocked"), source) } case "unlock": if ext.State == Locked { @@ -315,25 +211,26 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, new_owner := source ext.PendingOwner = nil ext.Owner = nil - msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlocked"), new_owner) + msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlocked"), new_owner) } else if source == *ext.Owner { ext.State = Unlocking - ext.ReqID = signal.ID() + id := signal.ID + ext.ReqID = &id ext.PendingOwner = nil for id, state := range(ext.Requirements) { if state != Locked { - log.Logf("lockable", "REQ_NOT_LOCKED_WHEN_UNLOCKING") + ctx.Log.Logf("lockable", "REQ_NOT_LOCKED_WHEN_UNLOCKING") } ext.Requirements[id] = Unlocking lock_signal := NewLockSignal("unlock") - msgs = msgs.Add(node.ID, node.Key, lock_signal, id) + msgs = msgs.Add(ctx, node.ID, node.Key, lock_signal, id) } } } else { - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locked"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_locked"), source) } default: - log.Logf("lockable", "LOCK_ERR: unkown state %s", state) + ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", signal.State) } return msgs } @@ -342,25 +239,25 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, // LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { messages := Messages{} - switch signal.Direction() { + switch signal.Header().Direction { case Up: if ext.Owner != nil { if *ext.Owner != node.ID { - messages = messages.Add(node.ID, node.Key, signal, *ext.Owner) + messages = messages.Add(ctx, node.ID, node.Key, signal, *ext.Owner) } } case Down: for requirement, _ := range(ext.Requirements) { - messages = messages.Add(node.ID, node.Key, signal, requirement) + messages = messages.Add(ctx, node.ID, node.Key, signal, requirement) } case Direct: - switch signal.Type() { - case LinkSignalType: - messages = ext.HandleLinkSignal(ctx.Log, node, source, signal.(*IDStringSignal)) - case LockSignalType: - messages = ext.HandleLockSignal(ctx.Log, node, source, signal.(*StringSignal)) - case ErrorSignalType: - messages = ext.HandleErrorSignal(ctx.Log, node, source, signal.(*ErrorSignal)) + switch sig := signal.(type) { + case *LinkSignal: + messages = ext.HandleLinkSignal(ctx, node, source, sig) + case *LockSignal: + messages = ext.HandleLockSignal(ctx, node, source, sig) + case *ErrorSignal: + messages = ext.HandleErrorSignal(ctx, node, source, sig) default: } default: diff --git a/lockable_test.go b/lockable_test.go index e10cb90..1a0242d 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -7,7 +7,7 @@ import ( "crypto/rand" ) -const TestLockableType = NodeType("TEST_LOCKABLE") +var TestLockableType = NewNodeType("TEST_LOCKABLE") func lockableTestContext(t *testing.T, logs []string) *Context { ctx := logTestContext(t, logs) @@ -43,57 +43,24 @@ func TestLink(t *testing.T) { ) msgs := Messages{} - msgs = msgs.Add(l1.ID, l1.Key, NewLinkSignal("add", l2.ID), l1.ID) + msgs = msgs.Add(ctx, l1.ID, l1.Key, NewLinkSignal("add", l2.ID), l1.ID) err = ctx.Send(msgs) fatalErr(t, err) - _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, ErrorSignalType, func(sig *ErrorSignal) bool { + _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, func(sig *ErrorSignal) bool { return sig.Error == "req_added" }) fatalErr(t, err) msgs = Messages{} - s := NewBaseSignal("TEST", Down) - msgs = msgs.Add(l1.ID, l1.Key, &s, l1.ID) + msgs = msgs.Add(ctx, l1.ID, l1.Key, NewLinkSignal("remove", l2.ID), l1.ID) err = ctx.Send(msgs) fatalErr(t, err) - _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, "TEST", func(sig *BaseSignal) bool { - return sig.ID() == s.ID() - }) - fatalErr(t, err) - - _, err = WaitForSignal(l2_listener.Chan, time.Millisecond*10, "TEST", func(sig *BaseSignal) bool { - return sig.ID() == s.ID() - }) - fatalErr(t, err) - - msgs = Messages{} - msgs = msgs.Add(l1.ID, l1.Key, NewLinkSignal("remove", l2.ID), l1.ID) - err = ctx.Send(msgs) - fatalErr(t, err) - - _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, ErrorSignalType, func(sig *ErrorSignal) bool { + _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, func(sig *ErrorSignal) bool { return sig.Error == "req_removed" }) fatalErr(t, err) - - msgs = Messages{} - s = NewBaseSignal("TEST", Down) - msgs = msgs.Add(l1.ID, l1.Key, &s, l1.ID) - err = ctx.Send(msgs) - fatalErr(t, err) - - _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, "TEST", func(sig *BaseSignal) bool { - return sig.ID() == s.ID() - }) - fatalErr(t, err) - - select { - case <- l2_listener.Chan: - t.Fatal("Recevied message on l2 after removing link") - default: - } } func Test10KLink(t *testing.T) { @@ -104,7 +71,7 @@ func Test10KLink(t *testing.T) { listener_id := KeyID(l_pub) child_policy := NewPerNodePolicy(map[NodeID]Tree{ listener_id: Tree{ - LockSignalType.String(): nil, + uint64(LockSignalType): nil, }, }) NewLockable := func()(*Node) { @@ -125,7 +92,7 @@ func Test10KLink(t *testing.T) { ctx.Log.Logf("test", "CREATED_10K") l_policy := NewAllNodesPolicy(Tree{ - LockSignalType.String(): nil, + uint64(LockSignalType): nil, }) listener := NewListenerExt(100000) node := NewNode(ctx, listener_key, TestLockableType, 10000, @@ -140,14 +107,14 @@ func Test10KLink(t *testing.T) { _, err = LockLockable(ctx, node, node.ID) fatalErr(t, err) - _, err = WaitForSignal(listener.Chan, time.Millisecond*1000, LockSignalType, func(sig *StringSignal) bool { - return sig.Str == "locked" + _, err = WaitForSignal(listener.Chan, time.Millisecond*1000, func(sig *LockSignal) bool { + return sig.State == "locked" }) fatalErr(t, err) for _, _ = range(reqs) { - _, err := WaitForSignal(listener.Chan, time.Millisecond*100, LockSignalType, func(sig *StringSignal) bool { - return sig.Str == "locked" + _, err := WaitForSignal(listener.Chan, time.Millisecond*100, func(sig *LockSignal) bool { + return sig.State == "locked" }) fatalErr(t, err) } @@ -178,36 +145,36 @@ func TestLock(t *testing.T) { l0, l0_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID}) l1, l1_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID}) - locked := func(sig *StringSignal) bool { - return sig.Str == "locked" + locked := func(sig *LockSignal) bool { + return sig.State == "locked" } - unlocked := func(sig *StringSignal) bool { - return sig.Str == "unlocked" + unlocked := func(sig *LockSignal) bool { + return sig.State == "unlocked" } _, err := LockLockable(ctx, l0, l5.ID) fatalErr(t, err) - _, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, LockSignalType, locked) + _, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, locked) fatalErr(t, err) id, err := LockLockable(ctx, l1, l1.ID) fatalErr(t, err) - _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, ErrorSignalType, func(sig *ErrorSignal) bool { - return sig.Error == "not_unlocked" && sig.ReqID() == id + _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, func(sig *ErrorSignal) bool { + return sig.Error == "not_unlocked" && sig.Header().ReqID == id }) fatalErr(t, err) _, err = UnlockLockable(ctx, l0, l5.ID) fatalErr(t, err) - _, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, LockSignalType, unlocked) + _, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, unlocked) fatalErr(t, err) _, err = LockLockable(ctx, l1, l1.ID) fatalErr(t, err) for i := 0; i < 4; i++ { - _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, func(sig *StringSignal) bool { - return sig.Str == "locked" + _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, func(sig *LockSignal) bool { + return sig.State == "locked" }) fatalErr(t, err) } diff --git a/node.go b/node.go index dc4b780..3f0f1a7 100644 --- a/node.go +++ b/node.go @@ -29,22 +29,20 @@ const ( var ( // Base NodeID, used as a special value ZeroUUID = uuid.UUID{} - ZeroID = NodeID{ZeroUUID} + ZeroID = NodeID(ZeroUUID) ) // A NodeID uniquely identifies a Node -type NodeID struct { - uuid.UUID +type NodeID uuid.UUID +func (id NodeID) MarshalBinary() ([]byte, error) { + return (uuid.UUID)(id).MarshalBinary() } - -func (id NodeID) Serialize() []byte { - ser, _ := id.MarshalBinary() - return ser +func (id NodeID) String() string { + return (uuid.UUID)(id).String() } - func IDFromBytes(bytes []byte) (NodeID, error) { - id, err := uuid.FromBytes(bytes[:]) - return NodeID{id}, err + id, err := uuid.FromBytes(bytes) + return NodeID(id), err } // Parse an ID from a string @@ -53,26 +51,17 @@ func ParseID(str string) (NodeID, error) { if err != nil { return NodeID{}, err } - return NodeID{id_uuid}, nil + return NodeID(id_uuid), nil } // Generate a random NodeID func RandID() NodeID { - return NodeID{uuid.New()} -} - -// A Serializable has a type that can be used to map to it, and a function to serialize` the current state -type Serializable[I comparable] interface { - Serialize()([]byte,error) - Deserialize(*Context,[]byte)error - Type() I + return NodeID(uuid.New()) } // Extensions are data attached to nodes that process signals type Extension interface { - Serializable[ExtType] - Field(string)interface{} - Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages + Process(*Context, *Node, NodeID, Signal) Messages } // A QueuedSignal is a Signal that has been Queued to trigger at a set time @@ -130,10 +119,10 @@ const ( Pending ) -func (node *Node) Allows(principal_id NodeID, action Tree)(map[PolicyType]Messages, RuleResult) { +func (node *Node) Allows(ctx *Context, principal_id NodeID, action Tree)(map[PolicyType]Messages, RuleResult) { pends := map[PolicyType]Messages{} for policy_type, policy := range(node.Policies) { - msgs, resp := policy.Allows(principal_id, action, node) + msgs, resp := policy.Allows(ctx, principal_id, action, node) if resp == Allow { return nil, Allow } else if resp == Pending { @@ -154,7 +143,7 @@ func (node *Node) QueueSignal(time time.Time, signal Signal) { func (node *Node) DequeueSignal(id uuid.UUID) error { idx := -1 for i, q := range(node.SignalQueue) { - if q.Signal.ID() == id { + if q.Signal.Header().ID == id { idx = i break } @@ -202,16 +191,43 @@ func runNode(ctx *Context, node *Node) { ctx.Log.Logf("node", "RUN_STOP: %s", node.ID) } -func (node *Node) ReadFields(reqs map[ExtType][]string)map[ExtType]map[string]interface{} { - exts := map[ExtType]map[string]interface{}{} +type StringError string +func (err StringError) String() string { + return string(err) +} +func (err StringError) Error() string { + return err.String() +} +func (err StringError) MarshalBinary() ([]byte, error) { + return []byte(string(err)), nil +} +func NewErrorField(fstring string, args ...interface{}) SerializedValue { + str := StringError(fmt.Sprintf(fstring, args...)) + str_ser, err := str.MarshalBinary() + if err != nil { + panic(err) + } + return SerializedValue{ + TypeStack: []uint64{uint64(ErrorType)}, + Data: str_ser, + } +} + +func (node *Node) ReadFields(ctx *Context, reqs map[ExtType][]string)map[ExtType]map[string]SerializedValue { + exts := map[ExtType]map[string]SerializedValue{} for ext_type, field_reqs := range(reqs) { - fields := map[string]interface{}{} + fields := map[string]SerializedValue{} for _, req := range(field_reqs) { ext, exists := node.Extensions[ext_type] if exists == false { - fields[req] = fmt.Errorf("%s does not have %s extension", node.ID, ext_type) + fields[req] = NewErrorField("%+v does not have %+v extension", node.ID, ext_type) } else { - fields[req] = ext.Field(req) + f, err := SerializeField(ctx, ext, req) + if err != nil { + fields[req] = NewErrorField(err.Error()) + } else { + fields[req] = f + } } } exts[ext_type] = fields @@ -227,21 +243,40 @@ func nodeLoop(ctx *Context, node *Node) error { } // Perform startup actions - node.Process(ctx, ZeroID, &StartSignal) - - for true { + node.Process(ctx, ZeroID, NewStartSignal()) + run := true + for run == true { var signal Signal var source NodeID select { case msg := <- node.MsgChan: ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal) - ser, err := msg.Signal.Serialize() + signal_type, exists := ctx.SignalTypes[reflect.TypeOf(msg.Signal).Elem()] + if exists == false { + ctx.Log.Logf("signal", "SIGNAL_NOT_REGISTERED: %+v", reflect.TypeOf(msg.Signal).Elem()) + } + + signal_ser, err := SerializeSignal(ctx, signal, signal_type) if err != nil { - ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", node.ID, msg.Signal) + ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", err, msg.Signal) + } + ser, err := signal_ser.MarshalBinary() + if err != nil { + ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", err, signal_ser) continue } - sig_data := append(msg.Dest.Serialize(), msg.Source.Serialize()...) + dst_id_ser, err := msg.Dest.MarshalBinary() + if err != nil { + ctx.Log.Logf("signal", "SIGNAL_DEST_ID_SER_ERR: %e", err) + continue + } + src_id_ser, err := msg.Source.MarshalBinary() + if err != nil { + ctx.Log.Logf("signal", "SIGNAL_SRC_ID_SER_ERR: %e", err) + continue + } + sig_data := append(dst_id_ser, src_id_ser...) sig_data = append(sig_data, ser...) validated := ed25519.Verify(msg.Principal, sig_data, msg.Signature) if validated == false { @@ -251,26 +286,26 @@ func nodeLoop(ctx *Context, node *Node) error { princ_id := KeyID(msg.Principal) if princ_id != node.ID { - pends, resp := node.Allows(princ_id, msg.Signal.Permission()) + pends, resp := node.Allows(ctx, princ_id, msg.Signal.Permission()) if resp == Deny { ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %s", princ_id, node.ID, msg.Signal.Permission()) ctx.Log.Logf("policy", "SIGNAL_POLICY_SOURCE: %s", msg.Source) msgs := Messages{} - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(msg.Signal.ID(), "acl denied"), msg.Source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(msg.Signal.Header().ID, "acl denied"), msg.Source) ctx.Send(msgs) continue } else if resp == Pending { ctx.Log.Logf("policy", "SIGNAL_POLICY_PENDING: %s->%s - %s - %+v", princ_id, node.ID, msg.Signal.Permission(), pends) - timeout_signal := NewACLTimeoutSignal(msg.Signal.ID()) + timeout_signal := NewACLTimeoutSignal(msg.Signal.Header().ID) node.QueueSignal(time.Now().Add(100*time.Millisecond), timeout_signal) msgs := Messages{} for policy_type, sigs := range(pends) { for _, m := range(sigs) { msgs = append(msgs, m) - node.PendingSignals[m.Signal.ID()] = PendingSignal{policy_type, false, msg.Signal.ID()} + node.PendingSignals[m.Signal.Header().ID] = PendingSignal{policy_type, false, msg.Signal.Header().ID} } } - node.PendingACLs[msg.Signal.ID()] = PendingACL{len(msgs), timeout_signal.ID(), msg.Signal.Permission(), princ_id, msgs, []Signal{}, msg.Signal, msg.Source} + node.PendingACLs[msg.Signal.Header().ID] = PendingACL{len(msgs), timeout_signal.ID, msg.Signal.Permission(), princ_id, msgs, []Signal{}, msg.Signal, msg.Source} ctx.Send(msgs) continue } else if resp == Allow { @@ -290,7 +325,7 @@ func nodeLoop(ctx *Context, node *Node) error { t := node.NextSignal.Time i := -1 for j, queued := range(node.SignalQueue) { - if queued.Signal.ID() == node.NextSignal.Signal.ID() { + if queued.Signal.Header().ID == node.NextSignal.Signal.Header().ID { i = j break } @@ -304,26 +339,26 @@ func nodeLoop(ctx *Context, node *Node) error { node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue) if node.NextSignal == nil { - ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL nil@%+v", node.ID, signal.Type(), t, node.TimeoutChan) + ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL nil@%+v", node.ID, signal, t, node.TimeoutChan) } else { - ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL: %s@%s", node.ID, signal.Type(), t, node.NextSignal, node.NextSignal.Time) + ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL: %s@%s", node.ID, signal, t, node.NextSignal, node.NextSignal.Time) } } ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) - info, waiting := node.PendingSignals[signal.ReqID()] + info, waiting := node.PendingSignals[signal.Header().ReqID] if waiting == true { if info.Found == false { info.Found = true - node.PendingSignals[signal.ReqID()] = info + node.PendingSignals[signal.Header().ReqID] = info ctx.Log.Logf("pending", "FOUND_PENDING_SIGNAL: %s - %s", node.ID, signal) req_info, exists := node.PendingACLs[info.ID] if exists == true { req_info.Counter -= 1 req_info.Responses = append(req_info.Responses, signal) - allowed := node.Policies[info.Policy].ContinueAllows(req_info, signal) + allowed := node.Policies[info.Policy].ContinueAllows(ctx, req_info, signal) if allowed == Allow { ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal) signal = req_info.Signal @@ -337,7 +372,7 @@ func nodeLoop(ctx *Context, node *Node) error { ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal) // Send the denied response msgs := Messages{} - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(req_info.Signal.ID(), "ACL_DENIED"), req_info.Source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(req_info.Signal.Header().ID, "ACL_DENIED"), req_info.Source) err := ctx.Send(msgs) if err != nil { ctx.Log.Logf("signal", "SEND_ERR: %s", err) @@ -355,24 +390,19 @@ func nodeLoop(ctx *Context, node *Node) error { } } - // Handle node signals - if signal.Type() == StopSignalType { + switch sig := signal.(type) { + case *StopSignal: msgs := Messages{} - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "stopped"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.ID, "stopped"), source) + ctx.Send(msgs) + node.Process(ctx, node.ID, NewStatusSignal(node.ID, "stopped")) + run = false + case *ReadSignal: + result := node.ReadFields(ctx, sig.Extensions) + msgs := Messages{} + msgs = msgs.Add(ctx, node.ID, node.Key, NewReadResultSignal(sig.ID, node.ID, node.Type, result), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.ID, "read_done"), source) ctx.Send(msgs) - node.Process(ctx, node.ID, NewStatusSignal("stopped", node.ID)) - break - } else if signal.Type() == ReadSignalType { - read_signal, ok := signal.(*ReadSignal) - if ok == false { - ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal) - } else { - result := node.ReadFields(read_signal.Extensions) - msgs := Messages{} - msgs = msgs.Add(node.ID, node.Key, NewReadResultSignal(read_signal.ID(), node.ID, node.Type, result), source) - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(read_signal.ID(), "read_done"), source) - ctx.Send(msgs) - } } node.Process(ctx, source, signal) @@ -401,8 +431,8 @@ type Message struct { } type Messages []*Message -func (msgs Messages) Add(source NodeID, principal ed25519.PrivateKey, signal Signal, dest NodeID) Messages { - msg, err := NewMessage(dest, source, principal, signal) +func (msgs Messages) Add(ctx *Context, source NodeID, principal ed25519.PrivateKey, signal Signal, dest NodeID) Messages { + msg, err := NewMessage(ctx, dest, source, principal, signal) if err != nil { panic(err) } else { @@ -411,13 +441,31 @@ func (msgs Messages) Add(source NodeID, principal ed25519.PrivateKey, signal Sig return msgs } -func NewMessage(dest NodeID, source NodeID, principal ed25519.PrivateKey, signal Signal) (*Message, error) { - ser, err := signal.Serialize() +func NewMessage(ctx *Context, dest NodeID, source NodeID, principal ed25519.PrivateKey, signal Signal) (*Message, error) { + signal_type, exists := ctx.SignalTypes[reflect.TypeOf(signal)] + if exists == false { + return nil, fmt.Errorf("Cannot put %+v in a message, not a known signal type", reflect.TypeOf(signal)) + } + + signal_ser, err := SerializeSignal(ctx, signal, signal_type) + if err != nil { + return nil, err + } + + ser, err := signal_ser.MarshalBinary() if err != nil { return nil, err } - sig_data := append(dest.Serialize(), source.Serialize()...) + dest_ser, err := dest.MarshalBinary() + if err != nil { + return nil, err + } + source_ser, err := source.MarshalBinary() + if err != nil { + return nil, err + } + sig_data := append(dest_ser, source_ser...) sig_data = append(sig_data, ser...) sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0)) @@ -435,7 +483,7 @@ func NewMessage(dest NodeID, source NodeID, principal ed25519.PrivateKey, signal } func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error { - ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, signal.Type()) + ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, signal) messages := Messages{} for ext_type, ext := range(node.Extensions) { ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type) @@ -449,120 +497,86 @@ func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error { return ctx.Send(messages) } -func GetCtx[T Extension, C any](ctx *Context) (C, error) { - var zero T +func GetCtx[T Extension, C any](ctx *Context, ext_type ExtType) (C, error) { var zero_ctx C - ext_type := zero.Type() - type_hash := Hash(ext_type) - ext_info, ok := ctx.Extensions[type_hash] + ext_info, ok := ctx.Extensions[ext_type] if ok == false { - return zero_ctx, fmt.Errorf("%s is not an extension in ctx", ext_type) + return zero_ctx, fmt.Errorf("%+v is not an extension in ctx", ext_type) } ext_ctx, ok := ext_info.Data.(C) if ok == false { - return zero_ctx, fmt.Errorf("context for %s is %+v, not %+v", ext_type, reflect.TypeOf(ext_info.Data), reflect.TypeOf(zero)) + return zero_ctx, fmt.Errorf("context for %+v is %+v, not %+v", ext_type, reflect.TypeOf(ext_info.Data), reflect.TypeOf(zero_ctx)) } return ext_ctx, nil } -func GetExt[T Extension](node *Node) (T, error) { +func GetExt[T Extension](node *Node, ext_type ExtType) (T, error) { var zero T - ext_type := zero.Type() ext, exists := node.Extensions[ext_type] if exists == false { - return zero, fmt.Errorf("%s does not have %s extension - %+v", node.ID, ext_type, node.Extensions) + return zero, fmt.Errorf("%+v does not have %+v extension - %+v", node.ID, ext_type, node.Extensions) } ret, ok := ext.(T) if ok == false { - return zero, fmt.Errorf("%s in %s is wrong type(%+v), expecting %+v", ext_type, node.ID, reflect.TypeOf(ext), reflect.TypeOf(zero)) + return zero, fmt.Errorf("%+v in %+v is wrong type(%+v), expecting %+v", ext_type, node.ID, reflect.TypeOf(ext), reflect.TypeOf(zero)) } return ret, nil } -func (node *Node) Serialize() ([]byte, error) { - extensions := make([]ExtensionDB, len(node.Extensions)) - qsignals := make([]QSignalDB, len(node.SignalQueue)) - policies := make([]PolicyDB, len(node.Policies)) +func (node *Node) Serialize(ctx *Context) (SerializedValue, error) { + if node == nil { + return SerializedValue{}, fmt.Errorf("Cannot serialize nil Node") + } + + node_bytes := make([]byte, 8 * 3) + binary.BigEndian.PutUint64(node_bytes[0:8], uint64(len(node.Extensions))) + binary.BigEndian.PutUint64(node_bytes[8:16], uint64(len(node.Policies))) + binary.BigEndian.PutUint64(node_bytes[16:24], uint64(len(node.SignalQueue))) key_bytes, err := x509.MarshalPKCS8PrivateKey(node.Key) if err != nil { - return nil, err + return SerializedValue{}, err } - node_db := NodeDB{ - Header: NodeDBHeader{ - Magic: NODE_DB_MAGIC, - TypeHash: Hash(node.Type), - KeyLength: uint32(len(key_bytes)), - BufferSize: node.BufferSize, - NumExtensions: uint32(len(extensions)), - NumPolicies: uint32(len(policies)), - NumQueuedSignals: uint32(len(node.SignalQueue)), - }, - Extensions: extensions, - Policies: policies, - QueuedSignals: qsignals, - KeyBytes: key_bytes, + key_val := SerializedValue{ + TypeStack: []uint64{uint64(NodeKeyType)}, + Data: key_bytes, + } + key_ser, err := key_val.MarshalBinary() + if err != nil { + return SerializedValue{}, err } + node_bytes = append(node_bytes, key_ser...) - i := 0 - for ext_type, info := range(node.Extensions) { - ser, err := info.Serialize() + for ext_type, ext := range(node.Extensions) { + ctx.Log.Logf("serialize", "SERIALIZING_EXTENSION: %+v", ext) + ext_ser, err := SerializeExtension(ctx, ext, ext_type) if err != nil { - return nil, err - } - node_db.Extensions[i] = ExtensionDB{ - Header: ExtensionDBHeader{ - TypeHash: Hash(ext_type), - Length: uint64(len(ser)), - }, - Data: ser, + return SerializedValue{}, err } - i += 1 - } - - for i, qsignal := range(node.SignalQueue) { - ser, err := qsignal.Signal.Serialize() + ext_bytes, err := ext_ser.MarshalBinary() if err != nil { - return nil, err + return SerializedValue{}, err } - node_db.QueuedSignals[i] = QSignalDB{ - QSignalDBHeader{ - qsignal.Time, - Hash(qsignal.Signal.Type()), - uint64(len(ser)), - }, - ser, - } + node_bytes = append(node_bytes, ext_bytes...) } - i = 0 - for _, policy := range(node.Policies) { - ser, err := policy.Serialize() - if err != nil { - return nil, err - } - - node_db.Policies[i] = PolicyDB{ - PolicyDBHeader{ - Hash(policy.Type()), - uint64(len(ser)), - }, - ser, - } + node_value := SerializedValue{ + TypeStack: []uint64{uint64(node.Type)}, + Data: node_bytes, } - return node_db.Serialize(), nil + return node_value, nil } func KeyID(pub ed25519.PublicKey) NodeID { - str := uuid.NewHash(sha512.New(), ZeroUUID, pub, 3) - return NodeID{str} + id := uuid.NewHash(sha512.New(), ZeroUUID, pub, 3) + return NodeID(id) } // Create a new node in memory and start it's event loop @@ -584,24 +598,28 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si panic("Attempted to create an existing node") } - def, exists := ctx.Types[Hash(node_type)] + def, exists := ctx.Nodes[node_type] if exists == false { panic("Node type %s not registered in Context") } ext_map := map[ExtType]Extension{} for _, ext := range(extensions) { - _, exists := ext_map[ext.Type()] + ext_type, exists := ctx.ExtensionTypes[reflect.TypeOf(ext)] + if exists == false { + panic(fmt.Sprintf("%+v is not a known Extension", reflect.TypeOf(ext))) + } + _, exists = ext_map[ext_type] if exists == true { panic("Cannot add the same extension to a node twice") } - ext_map[ext.Type()] = ext + ext_map[ext_type] = ext } for _, required_ext := range(def.Extensions) { _, exists := ext_map[required_ext] if exists == false { - panic(fmt.Sprintf("%s requires %s", node_type, required_ext)) + panic(fmt.Sprintf("%+v requires %+v", node_type, required_ext)) } } @@ -610,9 +628,9 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si } default_policy := NewAllNodesPolicy(Tree{ - ErrorSignalType.String(): nil, - ReadResultSignalType.String(): nil, - StatusSignalType.String(): nil, + uint64(ErrorSignalType): nil, + uint64(ReadResultSignalType): nil, + uint64(StatusSignalType): nil, }) all_nodes_policy, exists := policies[AllNodesPolicyType] @@ -642,252 +660,32 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si panic(err) } - node.Process(ctx, ZeroID, &NewSignal) + node.Process(ctx, ZeroID, NewCreateSignal()) go runNode(ctx, node) return node } -type PolicyDBHeader struct { - TypeHash uint64 - Length uint64 -} - -type PolicyDB struct { - Header PolicyDBHeader - Data []byte -} - -type QSignalDBHeader struct { - Time time.Time - TypeHash uint64 - Length uint64 -} - -type QSignalDB struct { - Header QSignalDBHeader - Data []byte -} - -type ExtensionDBHeader struct { - TypeHash uint64 - Length uint64 -} - -type ExtensionDB struct { - Header ExtensionDBHeader - Data []byte -} - -// A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node -type NodeDBHeader struct { - Magic uint32 - NumExtensions uint32 - NumPolicies uint32 - NumQueuedSignals uint32 - BufferSize uint32 - KeyLength uint32 - TypeHash uint64 -} - -type NodeDB struct { - Header NodeDBHeader - Extensions []ExtensionDB - Policies []PolicyDB - QueuedSignals []QSignalDB - KeyBytes []byte -} - -//TODO: add size safety checks -func NewNodeDB(data []byte) (NodeDB, error) { - var zero NodeDB - - ptr := 0 - - magic := binary.BigEndian.Uint32(data[0:4]) - num_extensions := binary.BigEndian.Uint32(data[4:8]) - num_policies := binary.BigEndian.Uint32(data[8:12]) - num_queued_signals := binary.BigEndian.Uint32(data[12:16]) - buffer_size := binary.BigEndian.Uint32(data[16:20]) - key_length := binary.BigEndian.Uint32(data[20:24]) - node_type_hash := binary.BigEndian.Uint64(data[24:32]) - - ptr += NODE_DB_HEADER_LEN - - if magic != NODE_DB_MAGIC { - return zero, fmt.Errorf("header has incorrect magic 0x%x", magic) - } - - key_bytes := make([]byte, key_length) - n := copy(key_bytes, data[ptr:(ptr+int(key_length))]) - if n != int(key_length) { - return zero, fmt.Errorf("not enough key bytes: %d", n) - } - - ptr += int(key_length) - - extensions := make([]ExtensionDB, num_extensions) - for i, _ := range(extensions) { - cur := data[ptr:] - - type_hash := binary.BigEndian.Uint64(cur[0:8]) - length := binary.BigEndian.Uint64(cur[8:16]) - - data_start := uint64(EXTENSION_DB_HEADER_LEN) - data_end := data_start + length - ext_data := cur[data_start:data_end] - - extensions[i] = ExtensionDB{ - Header: ExtensionDBHeader{ - TypeHash: type_hash, - Length: length, - }, - Data: ext_data, - } - - ptr += int(EXTENSION_DB_HEADER_LEN + length) - } - - policies := make([]PolicyDB, num_policies) - for i, _ := range(policies) { - cur := data[ptr:] - type_hash := binary.BigEndian.Uint64(cur[0:8]) - length := binary.BigEndian.Uint64(cur[8:16]) - - data_start := uint64(POLICY_DB_HEADER_LEN) - data_end := data_start + length - policy_data := cur[data_start:data_end] - - policies[i] = PolicyDB{ - PolicyDBHeader{ - type_hash, - length, - }, - policy_data, - } - ptr += int(POLICY_DB_HEADER_LEN + length) - } - - queued_signals := make([]QSignalDB, num_queued_signals) - for i, _ := range(queued_signals) { - cur := data[ptr:] - // TODO: load a header for each with the signal type and the signal length, so that it can be deserialized and incremented - // Right now causes segfault because any saved signal is loaded as nil - unix_milli := binary.BigEndian.Uint64(cur[0:8]) - type_hash := binary.BigEndian.Uint64(cur[8:16]) - signal_size := binary.BigEndian.Uint64(cur[16:24]) - - signal_data := cur[QSIGNAL_DB_HEADER_LEN:(QSIGNAL_DB_HEADER_LEN+signal_size)] - - queued_signals[i] = QSignalDB{ - QSignalDBHeader{ - time.UnixMilli(int64(unix_milli)), - type_hash, - signal_size, - }, - signal_data, - } - - ptr += QSIGNAL_DB_HEADER_LEN + int(signal_size) - } - - return NodeDB{ - Header: NodeDBHeader{ - Magic: magic, - TypeHash: node_type_hash, - BufferSize: buffer_size, - KeyLength: key_length, - NumExtensions: num_extensions, - NumQueuedSignals: num_queued_signals, - }, - KeyBytes: key_bytes, - Extensions: extensions, - QueuedSignals: queued_signals, - }, nil -} - -func (header NodeDBHeader) Serialize() []byte { - if header.Magic != NODE_DB_MAGIC { - panic(fmt.Sprintf("Serializing header with invalid magic %0x", header.Magic)) - } - - ret := make([]byte, NODE_DB_HEADER_LEN) - binary.BigEndian.PutUint32(ret[0:4], header.Magic) - binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions) - binary.BigEndian.PutUint32(ret[8:12], header.NumPolicies) - binary.BigEndian.PutUint32(ret[12:16], header.NumQueuedSignals) - binary.BigEndian.PutUint32(ret[16:20], header.BufferSize) - binary.BigEndian.PutUint32(ret[20:24], header.KeyLength) - binary.BigEndian.PutUint64(ret[24:32], header.TypeHash) - return ret -} - -func (node NodeDB) Serialize() []byte { - ser := node.Header.Serialize() - ser = append(ser, node.KeyBytes...) - for _, extension := range(node.Extensions) { - ser = append(ser, extension.Serialize()...) - } - for _, policy := range(node.Policies) { - ser = append(ser, policy.Serialize()...) - } - for _, qsignal := range(node.QueuedSignals) { - ser = append(ser, qsignal.Serialize()...) - } - - return ser -} - -func (header QSignalDBHeader) Serialize() []byte { - ret := make([]byte, QSIGNAL_DB_HEADER_LEN) - binary.BigEndian.PutUint64(ret[0:8], uint64(header.Time.UnixMilli())) - binary.BigEndian.PutUint64(ret[8:16], header.TypeHash) - binary.BigEndian.PutUint64(ret[16:24], header.Length) - return ret -} - -func (qsignal QSignalDB) Serialize() []byte { - header_bytes := qsignal.Header.Serialize() - return append(header_bytes, qsignal.Data...) -} - -func (header ExtensionDBHeader) Serialize() []byte { - ret := make([]byte, EXTENSION_DB_HEADER_LEN) - binary.BigEndian.PutUint64(ret[0:8], header.TypeHash) - binary.BigEndian.PutUint64(ret[8:16], header.Length) - return ret -} - -func (extension ExtensionDB) Serialize() []byte { - header_bytes := extension.Header.Serialize() - return append(header_bytes, extension.Data...) -} - -func (header PolicyDBHeader) Serialize() []byte { - ret := make([]byte, POLICY_DB_HEADER_LEN) - binary.BigEndian.PutUint64(ret[0:8], header.TypeHash) - binary.BigEndian.PutUint64(ret[0:8], header.Length) - return ret -} - -func (policy PolicyDB) Serialize() []byte { - header_bytes := policy.Header.Serialize() - return append(header_bytes, policy.Data...) -} - // Write a node to the database func WriteNode(ctx *Context, node *Node) error { ctx.Log.Logf("db", "DB_WRITE: %s", node.ID) - bytes, err := node.Serialize() + node_serialized, err := node.Serialize(ctx) + if err != nil { + return err + } + bytes, err := node_serialized.MarshalBinary() if err != nil { return err } ctx.Log.Logf("db_data", "DB_DATA: %+v", bytes) - id_bytes := node.ID.Serialize() + id_bytes, err := node.ID.MarshalBinary() + if err != nil { + return err + } ctx.Log.Logf("db", "DB_WRITE_ID: %+v", id_bytes) return ctx.DB.Update(func(txn *badger.Txn) error { @@ -895,11 +693,15 @@ func WriteNode(ctx *Context, node *Node) error { }) } +//TODO: fix after capnp func LoadNode(ctx * Context, id NodeID) (*Node, error) { ctx.Log.Logf("db", "LOADING_NODE: %s", id) var bytes []byte err := ctx.DB.View(func(txn *badger.Txn) error { - id_bytes := id.Serialize() + id_bytes, err := id.MarshalBinary() + if err != nil { + return err + } ctx.Log.Logf("db", "DB_READ_ID: %+v", id_bytes) item, err := txn.Get(id_bytes) if err != nil { @@ -917,137 +719,18 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return nil, err } - // Parse the bytes from the DB - node_db, err := NewNodeDB(bytes) - if err != nil { - return nil, err - } - - policies := make(map[PolicyType]Policy, node_db.Header.NumPolicies) - for _, policy_db := range(node_db.Policies) { - policy_info, exists := ctx.Policies[policy_db.Header.TypeHash] - if exists == false { - return nil, fmt.Errorf("0x%x is not a known policy type", policy_db.Header.TypeHash) - } + num_extensions := binary.BigEndian.Uint64(bytes[0:8]) + num_policies := binary.BigEndian.Uint64(bytes[8:16]) + num_signals := binary.BigEndian.Uint64(bytes[16:24]) + print(num_extensions) + print(num_policies) + print(num_signals) - policy, err := policy_info.Load(ctx, policy_db.Data) - if err != nil { - return nil, err - } - - policies[policy_info.Type] = policy - } - - key_raw, err := x509.ParsePKCS8PrivateKey(node_db.KeyBytes) - if err != nil { - return nil, err - } - - var key ed25519.PrivateKey - switch k := key_raw.(type) { - case ed25519.PrivateKey: - key = k - default: - return nil, fmt.Errorf("Wrong type for private key loaded: %s - %s", id, reflect.TypeOf(k)) - } - - key_id := KeyID(key.Public().(ed25519.PublicKey)) - if key_id != id { - return nil, fmt.Errorf("KeyID(%s) != %s", key_id, id) - } - - node_type, known := ctx.Types[node_db.Header.TypeHash] - if known == false { - return nil, fmt.Errorf("Tried to load node %s of type 0x%x, which is not a known node type", id, node_db.Header.TypeHash) - } - - signal_queue := make([]QueuedSignal, node_db.Header.NumQueuedSignals) - for i, qsignal := range(node_db.QueuedSignals) { - sig_info, exists := ctx.Signals[qsignal.Header.TypeHash] - if exists == false { - return nil, fmt.Errorf("0x%x is not a known signal type", qsignal.Header.TypeHash) - } - - signal, err := sig_info.Load(ctx, qsignal.Data) - if err != nil { - return nil, err - } - - signal_queue[i] = QueuedSignal{signal, qsignal.Header.Time} - } - - next_signal, timeout_chan := SoonestSignal(signal_queue) - node := &Node{ - Key: key, - ID: key_id, - Type: node_type.Type, - Extensions: map[ExtType]Extension{}, - Policies: policies, - MsgChan: make(chan *Message, node_db.Header.BufferSize), - BufferSize: node_db.Header.BufferSize, - TimeoutChan: timeout_chan, - SignalQueue: signal_queue, - NextSignal: next_signal, - } + /* ctx.AddNode(id, node) - - found_extensions := []ExtType{} - // Parse each of the extensions from the db - for _, ext_db := range(node_db.Extensions) { - type_hash := ext_db.Header.TypeHash - def, known := ctx.Extensions[type_hash] - if known == false { - return nil, fmt.Errorf("%s tried to load extension 0x%x, which is not a known extension type", id, type_hash) - } - ctx.Log.Logf("db", "DB_EXTENSION_LOADING: %s/%s", id, def.Type) - extension, err := def.Load(ctx, ext_db.Data) - if err != nil { - return nil, err - } - node.Extensions[def.Type] = extension - found_extensions = append(found_extensions, def.Type) - ctx.Log.Logf("db", "DB_EXTENSION_LOADED: %s/%s - %+v", id, def.Type, extension) - } - - missing_extensions := []ExtType{} - for _, ext := range(node_type.Extensions) { - found := false - for _, found_ext := range(found_extensions) { - if found_ext == ext { - found = true - break - } - } - if found == false { - missing_extensions = append(missing_extensions, ext) - } - } - - if len(missing_extensions) > 0 { - return nil, fmt.Errorf("DB_LOAD_MISSING_EXTENSIONS: %s - %+v - %+v", id, node_type, missing_extensions) - } - - extra_extensions := []ExtType{} - for _, found_ext := range(found_extensions) { - found := false - for _, ext := range(node_type.Extensions) { - if ext == found_ext { - found = true - break - } - } - if found == false { - extra_extensions = append(extra_extensions, found_ext) - } - } - - if len(extra_extensions) > 0 { - ctx.Log.Logf("db", "DB_LOAD_EXTRA_EXTENSIONS: %s - %+v - %+v", id, node_type, extra_extensions) - } - ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id) - go runNode(ctx, node) + */ - return node, nil + return nil, nil } diff --git a/node_test.go b/node_test.go index 6525884..6603c56 100644 --- a/node_test.go +++ b/node_test.go @@ -8,21 +8,21 @@ import ( ) func TestNodeDB(t *testing.T) { - ctx := logTestContext(t, []string{"signal", "node", "db"}) - node_type := NodeType("test") + ctx := logTestContext(t, []string{"signal", "node", "db", "db_data", "serialize"}) + node_type := NewNodeType("test") err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) fatalErr(t, err) - node := NewNode(ctx, nil, node_type, 10, nil, NewGroupExt(nil)) + node := NewNode(ctx, nil, node_type, 10, nil, NewGroupExt(nil), NewLockableExt(nil)) - ctx.Nodes = map[NodeID]*Node{} - _, err = ctx.GetNode(node.ID) + ctx.nodeMap = map[NodeID]*Node{} + _, err = ctx.getNode(node.ID) fatalErr(t, err) } func TestNodeRead(t *testing.T) { ctx := logTestContext(t, []string{"test"}) - node_type := NodeType("TEST") + node_type := NewNodeType("TEST") err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType, ECDHExtType}) fatalErr(t, err) @@ -38,27 +38,27 @@ func TestNodeRead(t *testing.T) { ctx.Log.Logf("test", "N2: %s", n2_id) n1_policy := NewPerNodePolicy(map[NodeID]Tree{ - n2_id: Tree{ - ReadSignalType.String(): nil, + n2_id: { + uint64(ReadSignalType): nil, }, }) n2_listener := NewListenerExt(10) - n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewGroupExt(nil), NewECDHExt(), n2_listener) + n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewGroupExt(nil), n2_listener) n1 := NewNode(ctx, n1_key, node_type, 10, map[PolicyType]Policy{ PerNodePolicyType: &n1_policy, - }, NewGroupExt(nil), NewECDHExt()) + }, NewGroupExt(nil)) read_sig := NewReadSignal(map[ExtType][]string{ - GroupExtType: []string{"members"}, + GroupExtType: {"members"}, }) msgs := Messages{} - msgs = msgs.Add(n2.ID, n2.Key, read_sig, n1.ID) + msgs = msgs.Add(ctx, n2.ID, n2.Key, read_sig, n1.ID) err = ctx.Send(msgs) fatalErr(t, err) - res, err := WaitForSignal(n2_listener.Chan, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool { + res, err := WaitForSignal(n2_listener.Chan, 10*time.Millisecond, func(sig *ReadResultSignal) bool { return true }) fatalErr(t, err) diff --git a/policy.go b/policy.go index 59a74a2..9bbdf79 100644 --- a/policy.go +++ b/policy.go @@ -1,35 +1,26 @@ package graphvent import ( - "encoding/json" -) - -const ( - MemberOfPolicyType = PolicyType("USER_OF") - RequirementOfPolicyType = PolicyType("REQUIEMENT_OF") - PerNodePolicyType = PolicyType("PER_NODE") - AllNodesPolicyType = PolicyType("ALL_NODES") ) type Policy interface { - Serializable[PolicyType] - Allows(principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) - ContinueAllows(current PendingACL, signal Signal)RuleResult + Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) + ContinueAllows(ctx *Context, current PendingACL, signal Signal)RuleResult // Merge with another policy of the same underlying type Merge(Policy) Policy // Make a copy of this policy Copy() Policy } -func (policy *AllNodesPolicy) Allows(principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { +func (policy *AllNodesPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { return nil, policy.Rules.Allows(action) } -func (policy *AllNodesPolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult { +func (policy *AllNodesPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { return Deny } -func (policy *PerNodePolicy) Allows(principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { +func (policy *PerNodePolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { for id, actions := range(policy.NodeRules) { if id != principal_id { continue @@ -39,7 +30,7 @@ func (policy *PerNodePolicy) Allows(principal_id NodeID, action Tree, node *Node return nil, Deny } -func (policy *PerNodePolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult { +func (policy *PerNodePolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { return Deny } @@ -57,7 +48,7 @@ func NewRequirementOfPolicy(dep_rules map[NodeID]Tree) RequirementOfPolicy { } } -func (policy *RequirementOfPolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult { +func (policy *RequirementOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { sig, ok := signal.(*ReadResultSignal) if ok == false { return Deny @@ -68,7 +59,17 @@ func (policy *RequirementOfPolicy) ContinueAllows(current PendingACL, signal Sig return Deny } - requirements, ok := ext["requirements"].(map[NodeID]string) + reqs_ser, ok := ext["requirements"] + if ok == false { + return Deny + } + + reqs_if, err := DeserializeValue(ctx, reqs_ser) + if err != nil { + return Deny + } + + requirements, ok := reqs_if.(map[NodeID]ReqState) if ok == false { return Deny } @@ -96,7 +97,7 @@ func NewMemberOfPolicy(group_rules map[NodeID]Tree) MemberOfPolicy { } } -func (policy *MemberOfPolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult { +func (policy *MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { sig, ok := signal.(*ReadResultSignal) if ok == false { return Deny @@ -107,7 +108,17 @@ func (policy *MemberOfPolicy) ContinueAllows(current PendingACL, signal Signal) return Deny } - members, ok := group_ext_data["members"].(map[NodeID]string) + members_ser, ok := group_ext_data["members"] + if ok == false { + return Deny + } + + members_if, err := DeserializeValue(ctx, members_ser) + if err != nil { + return Deny + } + + members, ok := members_if.(map[NodeID]string) if ok == false { return Deny } @@ -122,11 +133,11 @@ func (policy *MemberOfPolicy) ContinueAllows(current PendingACL, signal Signal) } // Send a read signal to Group to check if principal_id is a member of it -func (policy *MemberOfPolicy) Allows(principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) { +func (policy *MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) { msgs := Messages{} for id, rule := range(policy.NodeRules) { if id == node.ID { - ext, err := GetExt[*GroupExt](node) + ext, err := GetExt[*GroupExt](node, GroupExtType) if err == nil { for member, _ := range(ext.Members) { if member == principal_id { @@ -137,7 +148,7 @@ func (policy *MemberOfPolicy) Allows(principal_id NodeID, action Tree, node *Nod } } } else { - msgs = msgs.Add(node.ID, node.Key, NewReadSignal(map[ExtType][]string{ + msgs = msgs.Add(ctx, node.ID, node.Key, NewReadSignal(map[ExtType][]string{ GroupExtType: []string{"members"}, }), id) } @@ -238,22 +249,22 @@ func (policy *AllNodesPolicy) Copy() Policy { } } -type Tree map[string]Tree +type Tree map[uint64]Tree func (rule Tree) Allows(action Tree) RuleResult { // If the current rule is nil, it's a wildcard and any action being processed is allowed if rule == nil { return Allow - // If the rule isn't "allow all" but the action is "request all", deny + // If the rule isn't "allow all" but the action is "request all", deny } else if action == nil { return Deny - // If the current action has no children, it's allowed + // If the current action has no children, it's allowed } else if len(action) == 0 { return Allow - // If the current rule has no children but the action goes further, it's not allowed + // If the current rule has no children but the action goes further, it's not allowed } else if len(rule) == 0 { return Deny - // If the current rule and action have children, all the children of action must be allowed by rule + // If the current rule and action have children, all the children of action must be allowed by rule } else { for sub, subtree := range(action) { subrule, exists := rule[sub] @@ -285,14 +296,6 @@ func (policy *PerNodePolicy) Type() PolicyType { return PerNodePolicyType } -func (policy *PerNodePolicy) Serialize() ([]byte, error) { - return json.MarshalIndent(policy, "", " ") -} - -func (policy *PerNodePolicy) Deserialize(ctx *Context, data []byte) error { - return json.Unmarshal(data, policy) -} - func NewAllNodesPolicy(rules Tree) AllNodesPolicy { return AllNodesPolicy{ Rules: rules, @@ -307,15 +310,7 @@ func (policy *AllNodesPolicy) Type() PolicyType { return AllNodesPolicyType } -func (policy *AllNodesPolicy) Serialize() ([]byte, error) { - return json.MarshalIndent(policy, "", " ") -} - -func (policy *AllNodesPolicy) Deserialize(ctx *Context, data []byte) error { - return json.Unmarshal(data, policy) -} - var DefaultPolicy = NewAllNodesPolicy(Tree{ - ErrorSignalType.String(): nil, - ReadResultSignalType.String(): nil, + uint64(ErrorSignalType): nil, + uint64(ReadResultSignalType): nil, }) diff --git a/signal.go b/signal.go index 3508fe4..c4a266e 100644 --- a/signal.go +++ b/signal.go @@ -1,49 +1,29 @@ package graphvent import ( - "time" "fmt" - "encoding/json" - "encoding/binary" - "crypto" - "crypto/ed25519" - "crypto/ecdh" - "crypto/rand" - "crypto/aes" - "crypto/cipher" + "time" + + "capnproto.org/go/capnp/v3" "github.com/google/uuid" + schema "github.com/mekkanized/graphvent/signal" ) type SignalDirection int const ( - StopSignalType = SignalType("STOP") - NewSignalType = SignalType("NEW") - StartSignalType = SignalType("START") - ErrorSignalType = SignalType("ERROR") - StatusSignalType = SignalType("STATUS") - LinkSignalType = SignalType("LINK") - LockSignalType = SignalType("LOCK") - ReadSignalType = SignalType("READ") - ReadResultSignalType = SignalType("READ_RESULT") - ECDHSignalType = SignalType("ECDH") - ECDHProxySignalType = SignalType("ECDH_PROXY") - ACLTimeoutSignalType = SignalType("ACL_TIMEOUT") - Up SignalDirection = iota Down Direct ) -type SignalType string -func (signal_type SignalType) String() string { return string(signal_type) } -func (signal_type SignalType) Prefix() string { return "SIGNAL: " } +type SignalHeader struct { + Direction SignalDirection + ID uuid.UUID + ReqID uuid.UUID +} type Signal interface { - Serializable[SignalType] - String() string - Direction() SignalDirection - ID() uuid.UUID - ReqID() uuid.UUID + Header() *SignalHeader Permission() Tree } @@ -59,7 +39,7 @@ func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UU if signal == nil { return nil, fmt.Errorf("LISTENER_CLOSED") } - if signal.ReqID() == req_id { + if signal.Header().ReqID == req_id { return signal, nil } case <-timeout_channel: @@ -69,7 +49,7 @@ func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UU return nil, fmt.Errorf("UNREACHABLE") } -func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) { +func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, check func(S)bool) (S, error) { var zero S var timeout_channel <- chan time.Time if timeout > 0 { @@ -79,493 +59,396 @@ func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, signal select { case signal := <- listener: if signal == nil { - return zero, fmt.Errorf("LISTENER_CLOSED: %s", signal_type) + return zero, fmt.Errorf("LISTENER_CLOSED") } - if signal.Type() == signal_type { - sig, ok := signal.(S) - if ok == true { - if check(sig) == true { - return sig, nil - } + sig, ok := signal.(S) + if ok == true { + if check(sig) == true { + return sig, nil } } case <-timeout_channel: - return zero, fmt.Errorf("LISTENER_TIMEOUT: %s", signal_type) + return zero, fmt.Errorf("LISTENER_TIMEOUT") } } return zero, fmt.Errorf("LOOP_ENDED") } -type BaseSignal struct { - SignalDirection SignalDirection `json:"direction"` - SignalType SignalType `json:"type"` - UUID uuid.UUID `json:"id"` - ReqUUID uuid.UUID `json:"req_uuid"` -} - -func (signal *BaseSignal) ReqID() uuid.UUID { - return signal.ReqUUID -} - -func (signal *BaseSignal) String() string { - ser, _ := json.Marshal(signal) - return string(ser) -} - -func (signal *BaseSignal) Deserialize(ctx *Context, data []byte) error { - return json.Unmarshal(data, signal) -} - -func (signal *BaseSignal) ID() uuid.UUID { - return signal.UUID -} - -func (signal *BaseSignal) Type() SignalType { - return signal.SignalType -} - -func (signal *BaseSignal) Permission() Tree { - return Tree{ - string(signal.Type()): Tree{}, - } -} - -func (signal *BaseSignal) Direction() SignalDirection { - return signal.SignalDirection -} - -func (signal *BaseSignal) Serialize() ([]byte, error) { - return json.Marshal(signal) -} - -func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal { +func NewSignalHeader(direction SignalDirection) SignalHeader { id := uuid.New() - signal := BaseSignal{ - UUID: id, - ReqUUID: id, - SignalDirection: direction, - SignalType: signal_type, + header := SignalHeader{ + ID: id, + ReqID: id, + Direction: direction, } - return signal + return header } -func NewRespSignal(id uuid.UUID, signal_type SignalType, direction SignalDirection) BaseSignal { - signal := BaseSignal{ - UUID: uuid.New(), - ReqUUID: id, - SignalDirection: direction, - SignalType: signal_type, +func NewRespHeader(req_id uuid.UUID, direction SignalDirection) SignalHeader { + header := SignalHeader{ + ID: uuid.New(), + ReqID: req_id, + Direction: direction, } - return signal + return header } -var NewSignal = NewBaseSignal(NewSignalType, Direct) -var StartSignal = NewBaseSignal(StartSignalType, Direct) -var StopSignal = NewBaseSignal(StopSignalType, Direct) - -type IDSignal struct { - BaseSignal - NodeID `json:"id"` -} - -func (signal *IDSignal) Serialize() ([]byte, error) { - return json.Marshal(signal) -} - -type StringSignal struct { - BaseSignal - Str string `json:"state"` -} - -func (signal *StringSignal) String() string { - ser, _ := json.Marshal(signal) - return string(ser) -} - -func (signal *StringSignal) Serialize() ([]byte, error) { - return json.Marshal(&signal) -} - -type ErrorSignal struct { - BaseSignal - Error string -} +func SerializeHeader(header SignalHeader, root schema.SignalHeader) error { + root.SetDirection(uint8(header.Direction)) + id_ser, err := header.ID.MarshalBinary() + if err != nil { + return err + } + root.SetId(id_ser) -func (signal *ErrorSignal) String() string { - ser, _ := json.Marshal(signal) - return string(ser) + req_id_ser, err := header.ReqID.MarshalBinary() + if err != nil { + return err + } + root.SetReqID(req_id_ser) + return nil } -func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) Signal { - return &ErrorSignal{ - NewRespSignal(req_id, ErrorSignalType, Direct), - fmt.Sprintf(fmt_string, args...), +func DeserializeHeader(header schema.SignalHeader) (SignalHeader, error) { + id_ser, err := header.Id() + if err != nil { + return SignalHeader{}, err + } + id, err := uuid.FromBytes(id_ser) + if err != nil { + return SignalHeader{}, err } -} -func NewACLTimeoutSignal(req_id uuid.UUID) Signal { - sig := NewRespSignal(req_id, ACLTimeoutSignalType, Direct) - return &sig -} + req_id_ser, err := header.ReqID() + if err != nil { + return SignalHeader{}, err + } + req_id, err := uuid.FromBytes(req_id_ser) + if err != nil { + return SignalHeader{}, err + } -type IDStringSignal struct { - BaseSignal - NodeID NodeID `json:"node_id"` - Str string `json:"string"` + return SignalHeader{ + ID: id, + ReqID: req_id, + Direction: SignalDirection(header.Direction()), + }, nil } -func (signal *IDStringSignal) String() string { - ser, _ := json.Marshal(signal) - return string(ser) +type CreateSignal struct { + SignalHeader } -func (signal *IDStringSignal) Serialize() ([]byte, error) { - return json.Marshal(signal) +func (signal *CreateSignal) Header() *SignalHeader { + return &signal.SignalHeader } - -func NewStatusSignal(status string, source NodeID) Signal { - return &IDStringSignal{ - BaseSignal: NewBaseSignal(StatusSignalType, Up), - NodeID: source, - Str: status, +func (signal *CreateSignal) Permission() Tree { + return Tree{ + uint64(CreateSignalType): nil, } } -func NewLinkSignal(state string, id NodeID) Signal { - return &IDStringSignal{ - BaseSignal: NewBaseSignal(LinkSignalType, Direct), - NodeID: id, - Str: state, +func NewCreateSignal() *CreateSignal { + return &CreateSignal{ + NewSignalHeader(Direct), } } -func NewLockSignal(state string) Signal { - return &StringSignal{ - NewBaseSignal(LockSignalType, Direct), - state, - } +type StartSignal struct { + SignalHeader } - -func (signal *StringSignal) Permission() Tree { +func (signal *StartSignal) Header() *SignalHeader { + return &signal.SignalHeader +} +func (signal *StartSignal) Permission() Tree { return Tree{ - string(signal.Type()): Tree{ - signal.Str: Tree{}, - }, + uint64(StartSignalType): nil, } } - -type ReadSignal struct { - BaseSignal - Extensions map[ExtType][]string `json:"extensions"` -} - -func (signal *ReadSignal) Serialize() ([]byte, error) { - return json.Marshal(signal) -} - -func NewReadSignal(exts map[ExtType][]string) *ReadSignal { - return &ReadSignal{ - NewBaseSignal(ReadSignalType, Direct), - exts, +func NewStartSignal() *StartSignal { + return &StartSignal{ + NewSignalHeader(Direct), } } -func (signal *ReadSignal) Permission() Tree { - ret := Tree{} - for ext, fields := range(signal.Extensions) { - field_tree := Tree{} - for _, field := range(fields) { - field_tree[field] = Tree{} - } - ret[ext.String()] = field_tree - } - return Tree{ReadSignalType.String(): ret} +type StopSignal struct { + SignalHeader } - -type ReadResultSignal struct { - BaseSignal - NodeID NodeID - NodeType NodeType - Extensions map[ExtType]map[string]interface{} `json:"extensions"` +func (signal *StopSignal) Header() *SignalHeader { + return &signal.SignalHeader } - -func (signal *ReadResultSignal) Permission() Tree { +func (signal *StopSignal) Permission() Tree { return Tree{ - ReadResultSignalType.String(): Tree{}, + uint64(StopSignalType): nil, } } - -func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]interface{}) Signal { - return &ReadResultSignal{ - NewRespSignal(req_id, ReadResultSignalType, Direct), - node_id, - node_type, - exts, +func NewStopSignal() *StopSignal { + return &StopSignal{ + NewSignalHeader(Direct), } } -type ECDHSignal struct { - StringSignal - Time time.Time - EDDSA ed25519.PublicKey - ECDH *ecdh.PublicKey - Signature []byte -} - -type ECDHSignalJSON struct { - StringSignal - Time time.Time `json:"time"` - EDDSA []byte `json:"ecdsa_pubkey"` - ECDH []byte `json:"ecdh_pubkey"` - Signature []byte `json:"signature"` +type ErrorSignal struct { + SignalHeader + Error string } - -func (signal *ECDHSignal) MarshalJSON() ([]byte, error) { - return json.Marshal(&ECDHSignalJSON{ - StringSignal: signal.StringSignal, - Time: signal.Time, - ECDH: signal.ECDH.Bytes(), - EDDSA: signal.ECDH.Bytes(), - Signature: signal.Signature, - }) +func (signal *ErrorSignal) Header() *SignalHeader { + return &signal.SignalHeader } - -func (signal *ECDHSignal) Serialize() ([]byte, error) { - return json.Marshal(signal) -} - -func NewECDHReqSignal(node *Node) (Signal, *ecdh.PrivateKey, error) { - ec_key, err := ECDH.GenerateKey(rand.Reader) +func (signal *ErrorSignal) MarshalBinary() ([]byte, error) { + arena := capnp.SingleSegment(nil) + msg, seg, err := capnp.NewMessage(arena) if err != nil { - return nil, nil, err + return nil, err } - now := time.Now() - time_bytes, err := now.MarshalJSON() + root, err := schema.NewRootErrorSignal(seg) if err != nil { - return nil, nil, err + return nil, err } - sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...) + root.SetError(signal.Error) - sig, err := node.Key.Sign(rand.Reader, sig_data, crypto.Hash(0)) + header, err := root.NewHeader() if err != nil { - return nil, nil, err + return nil, err } - return &ECDHSignal{ - StringSignal: StringSignal{ - BaseSignal: NewBaseSignal(ECDHSignalType, Direct), - Str: "req", - }, - Time: now, - EDDSA: node.Key.Public().(ed25519.PublicKey), - ECDH: ec_key.PublicKey(), - Signature: sig, - }, ec_key, nil -} - -const DEFAULT_ECDH_WINDOW = time.Second - -func NewECDHRespSignal(node *Node, req *ECDHSignal) (ECDHSignal, []byte, error) { - now := time.Now() - - err := VerifyECDHSignal(now, req, DEFAULT_ECDH_WINDOW) + err = SerializeHeader(signal.SignalHeader, header) if err != nil { - return ECDHSignal{}, nil, err + return nil, err } - ec_key, err := ECDH.GenerateKey(rand.Reader) + return msg.Marshal() +} +func (signal *ErrorSignal) Deserialize(ctx *Context, data []byte) error { + msg, err := capnp.Unmarshal(data) if err != nil { - return ECDHSignal{}, nil, err + return err } - shared_secret, err := ec_key.ECDH(req.ECDH) + root, err := schema.ReadRootErrorSignal(msg) if err != nil { - return ECDHSignal{}, nil, err + return err } - time_bytes, err := now.MarshalJSON() + header, err := root.Header() if err != nil { - return ECDHSignal{}, nil, err + return err } - sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...) - - sig, err := node.Key.Sign(rand.Reader, sig_data, crypto.Hash(0)) + signal.Error, err = root.Error() if err != nil { - return ECDHSignal{}, nil, err - } - - return ECDHSignal{ - StringSignal: StringSignal{ - BaseSignal: NewBaseSignal(ECDHSignalType, Direct), - Str: "resp", - }, - Time: now, - EDDSA: node.Key.Public().(ed25519.PublicKey), - ECDH: ec_key.PublicKey(), - Signature: sig, - }, shared_secret, nil -} - -func VerifyECDHSignal(now time.Time, sig *ECDHSignal, window time.Duration) error { - earliest := now.Add(-window) - latest := now.Add(window) - - if sig.Time.Compare(earliest) == -1 { - return fmt.Errorf("TIME_TOO_LATE: %+v", sig.Time) - } else if sig.Time.Compare(latest) == 1 { - return fmt.Errorf("TIME_TOO_EARLY: %+v", sig.Time) + return err } - - time_bytes, err := sig.Time.MarshalJSON() + signal.SignalHeader, err = DeserializeHeader(header) if err != nil { return err } - sig_data := append(sig.ECDH.Bytes(), time_bytes...) - - verified := ed25519.Verify(sig.EDDSA, sig_data, sig.Signature) - if verified == false { - return fmt.Errorf("Failed to verify signature") - } - return nil } - -type ECDHProxySignal struct { - BaseSignal - Source NodeID - Dest NodeID - IV []byte - Data []byte +func (signal *ErrorSignal) Permission() Tree { + return Tree{ + uint64(ErrorSignalType): nil, + } } - -func NewECDHProxySignal(source, dest NodeID, signal Signal, shared_secret []byte) (Signal, error) { - if shared_secret == nil { - return nil, fmt.Errorf("need shared_secret") +func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) Signal { + return &ErrorSignal{ + NewRespHeader(req_id, Direct), + fmt.Sprintf(fmt_string, args...), } +} - aes_key, err := aes.NewCipher(shared_secret[:32]) - if err != nil { - return nil, err +type ACLTimeoutSignal struct { + SignalHeader +} +func (signal *ACLTimeoutSignal) Header() *SignalHeader { + return &signal.SignalHeader +} +func (signal *ACLTimeoutSignal) Permission() Tree { + return Tree{ + uint64(ACLTimeoutSignalType): nil, } - - ser, err := SerializeSignal(signal, aes_key.BlockSize()) - if err != nil { - return nil, err +} +func NewACLTimeoutSignal(req_id uuid.UUID) *ACLTimeoutSignal { + sig := &ACLTimeoutSignal{ + NewRespHeader(req_id, Direct), } + return sig +} - iv := make([]byte, aes_key.BlockSize()) - n, err := rand.Reader.Read(iv) - if err != nil { - return nil, err - } else if n != len(iv) { - return nil, fmt.Errorf("Not enough bytes read for IV") +type StatusSignal struct { + SignalHeader + Source NodeID + Status string +} +func (signal *StatusSignal) Header() *SignalHeader { + return &signal.SignalHeader +} +func (signal *StatusSignal) Permission() Tree { + return Tree{ + uint64(StatusSignalType): nil, + } +} +func NewStatusSignal(source NodeID, status string) *StatusSignal { + return &StatusSignal{ + NewSignalHeader(Up), + source, + status, } - - encrypter := cipher.NewCBCEncrypter(aes_key, iv) - encrypter.CryptBlocks(ser, ser) - - return &ECDHProxySignal{ - BaseSignal: NewBaseSignal(ECDHProxySignalType, Direct), - Source: source, - Dest: dest, - IV: iv, - Data: ser, - }, nil } -type SignalHeader struct { - Magic uint32 - TypeHash uint64 - Length uint64 +type LinkSignal struct { + SignalHeader + NodeID + Action string +} +func (signal *LinkSignal) Header() *SignalHeader { + return &signal.SignalHeader } -const SIGNAL_SER_MAGIC uint32 = 0x753a64de -const SIGNAL_SER_HEADER_LENGTH = 20 -func SerializeSignal(signal Signal, block_size int) ([]byte, error) { - signal_ser, err := signal.Serialize() - if err != nil { - return nil, err - } +const ( + LinkActionBase = "LINK_ACTION" + LinkActionAdd = "ADD" +) - pad_req := 0 - if block_size > 0 { - pad := block_size - ((SIGNAL_SER_HEADER_LENGTH + len(signal_ser)) % block_size) - if pad != block_size { - pad_req = pad - } +func (signal *LinkSignal) Permission() Tree { + return Tree{ + uint64(LinkSignalType): Tree{ + Hash(LinkActionBase, signal.Action): nil, + }, } - - header := SignalHeader{ - Magic: SIGNAL_SER_MAGIC, - TypeHash: Hash(signal.Type()), - Length: uint64(len(signal_ser) + pad_req), +} +func NewLinkSignal(action string, id NodeID) Signal { + return &LinkSignal{ + NewSignalHeader(Direct), + id, + action, } +} - ser := make([]byte, SIGNAL_SER_HEADER_LENGTH + len(signal_ser) + pad_req) - binary.BigEndian.PutUint32(ser[0:4], header.Magic) - binary.BigEndian.PutUint64(ser[4:12], header.TypeHash) - binary.BigEndian.PutUint64(ser[12:20], header.Length) +type LockSignal struct { + SignalHeader + State string +} +func (signal *LockSignal) Header() *SignalHeader { + return &signal.SignalHeader +} - copy(ser[SIGNAL_SER_HEADER_LENGTH:], signal_ser) +const ( + LockStateBase = "LOCK_STATE" +) - return ser, nil +func (signal *LockSignal) Permission() Tree { + return Tree{ + uint64(LockSignalType): Tree{ + Hash(LockStateBase, signal.State): nil, + }, + } } -func ParseSignal(ctx *Context, data []byte) (Signal, error) { - if len(data) < SIGNAL_SER_HEADER_LENGTH { - return nil, fmt.Errorf("data shorter than header length") +func NewLockSignal(state string) *LockSignal { + return &LockSignal{ + NewSignalHeader(Direct), + state, } +} - header := SignalHeader{ - Magic: binary.BigEndian.Uint32(data[0:4]), - TypeHash: binary.BigEndian.Uint64(data[4:12]), - Length: binary.BigEndian.Uint64(data[12:20]), +type ReadSignal struct { + SignalHeader + Extensions map[ExtType][]string `json:"extensions"` +} +func (signal *ReadSignal) MarshalBinary() ([]byte, error) { + arena := capnp.SingleSegment(nil) + msg, seg, err := capnp.NewMessage(arena) + if err != nil { + return nil, err } - if header.Magic != SIGNAL_SER_MAGIC { - return nil, fmt.Errorf("signal magic mismatch 0x%x", header.Magic) + root, err := schema.NewRootReadSignal(seg) + if err != nil { + return nil, err } - left := len(data) - SIGNAL_SER_HEADER_LENGTH - if int(header.Length) != left { - return nil, fmt.Errorf("signal length mismatch %d/%d", header.Length, left) + header, err := root.NewHeader() + if err != nil { + return nil, err } - signal_def, exists := ctx.Signals[header.TypeHash] - if exists == false { - return nil, fmt.Errorf("0x%x is not a known signal type", header.TypeHash) + err = SerializeHeader(signal.SignalHeader, header) + if err != nil { + return nil, err } - signal, err := signal_def.Load(ctx, data[SIGNAL_SER_HEADER_LENGTH:]) + extensions, err := root.NewExtensions(int32(len(signal.Extensions))) if err != nil { return nil, err } - return signal, nil -} + i := 0 + for ext_type, fields := range(signal.Extensions) { + extension := extensions.At(i) + extension.SetType(uint64(ext_type)) + f, err := extension.NewFields(int32(len(fields))) + if err != nil { + return nil, err + } -func ParseECDHProxySignal(ctx *Context, signal *ECDHProxySignal, shared_secret []byte) (Signal, error) { - if shared_secret == nil { - return nil, fmt.Errorf("need shared_secret") - } + for j, field := range(fields) { + err := f.Set(j, field) + if err != nil { + return nil, err + } + } - aes_key, err := aes.NewCipher(shared_secret[:32]) - if err != nil { - return nil, err + i += 1 } - decrypter := cipher.NewCBCDecrypter(aes_key, signal.IV) - decrypted := make([]byte, len(signal.Data)) - decrypter.CryptBlocks(decrypted, signal.Data) + return msg.Marshal() +} +func (signal *ReadSignal) Header() *SignalHeader { + return &signal.SignalHeader +} - wrapped_signal, err := ParseSignal(ctx, decrypted) - if err != nil { - return nil, err +func (signal *ReadSignal) Permission() Tree { + ret := Tree{} + for ext, fields := range(signal.Extensions) { + field_tree := Tree{} + for _, field := range(fields) { + field_tree[Hash(FieldNameBase, field)] = nil + } + ret[uint64(ext)] = field_tree } + return Tree{uint64(ReadSignalType): ret} +} +func NewReadSignal(exts map[ExtType][]string) *ReadSignal { + return &ReadSignal{ + NewSignalHeader(Direct), + exts, + } +} - return wrapped_signal, nil +type ReadResultSignal struct { + SignalHeader + NodeID NodeID + NodeType NodeType + Extensions map[ExtType]map[string]SerializedValue +} +func (signal *ReadResultSignal) Header() *SignalHeader { + return &signal.SignalHeader +} +func (signal *ReadResultSignal) Permission() Tree { + return Tree{ + uint64(ReadResultSignalType): nil, + } } +func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]SerializedValue) *ReadResultSignal { + return &ReadResultSignal{ + NewRespHeader(req_id, Direct), + node_id, + node_type, + exts, + } +} + diff --git a/signal/go.mod b/signal/go.mod new file mode 100644 index 0000000..74f0555 --- /dev/null +++ b/signal/go.mod @@ -0,0 +1,10 @@ +module github.com/mekkanized/graphvent/signal + +go 1.21.0 + +require capnproto.org/go/capnp/v3 v3.0.0-alpha-29 + +require ( + golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 // indirect + zenhack.net/go/util v0.0.0-20230414204917-531d38494cf5 // indirect +) diff --git a/signal/signal.capnp b/signal/signal.capnp new file mode 100644 index 0000000..94c89b3 --- /dev/null +++ b/signal/signal.capnp @@ -0,0 +1,35 @@ +using Go = import "/go.capnp"; +@0x83a311a90429ef94; +$Go.package("signal"); +$Go.import("github.com/mekkanized/graphvent/signal"); + +struct SignalHeader { + direction @0 :UInt8; + id @1 :Data; + reqID @2 :Data; +} + +struct ErrorSignal { + header @0 :SignalHeader; + error @1 :Text; +} + +struct LinkSignal { + header @0 :SignalHeader; + action @1 :Text; + id @2 :Data; +} + +struct LockSignal { + header @0 :SignalHeader; + state @1 :Text; +} + +struct ReadSignal { + header @0 :SignalHeader; + extensions @1 :List(Extension); + struct Extension { + type @0 :UInt64; + fields @1 :List(Text); + } +}