diff --git a/context.go b/context.go index 96a3e57..33f7865 100644 --- a/context.go +++ b/context.go @@ -1,16 +1,17 @@ package graphvent import ( - "crypto/ecdh" - "crypto/sha512" - "encoding/binary" - "errors" - "fmt" - "reflect" - "runtime" - "sync" - - badger "github.com/dgraph-io/badger/v3" + "crypto/ecdh" + "crypto/sha512" + "encoding/binary" + "errors" + "fmt" + "reflect" + "runtime" + "sync" + "strconv" + + badger "github.com/dgraph-io/badger/v3" ) func Hash(base string, name string) uint64 { @@ -49,6 +50,7 @@ func NewSerializedType(name string) SerializedType { } const ( + TagBase = "GraphventTag" ExtTypeBase = "ExtType" NodeTypeBase = "NodeType" SignalTypeBase = "SignalType" @@ -91,7 +93,10 @@ var ( ExtensionType = NewSerializedType("extension") StringType = NewSerializedType("string") + IntType = NewSerializedType("int") Uint8Type = NewSerializedType("uint8") + Uint32Type = NewSerializedType("uint32") + Uint64Type = NewSerializedType("uint64") NodeKeyType = NewSerializedType("node_key") NodeNotFoundError = errors.New("Node not found in DB") @@ -107,8 +112,8 @@ type NodeInfo struct { Extensions []ExtType } -type TypeSerialize func(*Context,interface{}) ([]byte, error) -type TypeDeserialize func(*Context,[]byte) (interface{}, error) +type TypeSerialize func(*Context,uint64,reflect.Type,*reflect.Value) (SerializedValue, error) +type TypeDeserialize func(*Context,SerializedValue) (interface{}, []byte, error) type TypeInfo struct { Type reflect.Type Serialize TypeSerialize @@ -136,6 +141,9 @@ type Context struct { Types map[SerializedType]TypeInfo TypeReflects map[reflect.Type]SerializedType + Kinds map[reflect.Kind]KindInfo + KindTypes map[SerializedType]reflect.Kind + // Routing map to all the nodes local to this context nodeMapLock sync.RWMutex nodeMap map[NodeID]*Node @@ -207,6 +215,32 @@ func (ctx *Context)RegisterExtension(reflect_type reflect.Type, ext_type ExtType return nil } +func (ctx *Context)RegisterKind(kind reflect.Kind, ctx_type SerializedType, serialize TypeSerialize, deserialize TypeDeserialize) error { + _, exists := ctx.Kinds[kind] + if exists == true { + return fmt.Errorf("Cannot register kind %+v, kind already exists in context", kind) + } + _, exists = ctx.KindTypes[ctx_type] + if exists == true { + return fmt.Errorf("0x%x is already registered, cannot use for %+v", ctx_type, kind) + } + if deserialize == nil { + return fmt.Errorf("Cannot register field without deserialize function") + } + if serialize == nil { + return fmt.Errorf("Cannot register field without serialize function") + } + + ctx.Kinds[kind] = KindInfo{ + ctx_type, + serialize, + deserialize, + } + ctx.KindTypes[ctx_type] = kind + + return nil +} + func (ctx *Context)RegisterType(reflect_type reflect.Type, ctx_type SerializedType, serialize TypeSerialize, deserialize TypeDeserialize) error { _, exists := ctx.Types[ctx_type] if exists == true { @@ -287,31 +321,10 @@ func (ctx *Context) Send(messages Messages) error { return nil } -type defaultKind struct { +type KindInfo struct { Type SerializedType - Serialize func(interface{})([]byte, error) - Deserialize func([]byte)(interface{}, error) -} - -var defaultKinds = map[reflect.Kind]defaultKind{ - reflect.Int: { - Deserialize: func(data []byte)(interface{}, error){ - if len(data) != 8 { - return nil, fmt.Errorf("invalid length: %d/8", len(data)) - } - return int(binary.BigEndian.Uint64(data)), nil - }, - Serialize: func(val interface{})([]byte, error){ - i, ok := val.(int) - if ok == false { - return nil, fmt.Errorf("invalid type %+v", reflect.TypeOf(val)) - } else { - bytes := make([]byte, 8) - binary.BigEndian.PutUint64(bytes, uint64(i)) - return bytes, nil - } - }, - }, + Serialize TypeSerialize + Deserialize TypeDeserialize } type SerializedValue struct { @@ -321,33 +334,364 @@ type SerializedValue struct { func SerializeValue(ctx *Context, value reflect.Value) (SerializedValue, error) { val, err := serializeValue(ctx, value.Type(), &value) - ctx.Log.Logf("serialize", "SERIALIZED_VALUE(%+v): %+v - %s", value, val, err) + ctx.Log.Logf("serialize", "SERIALIZED_VALUE(%+v): %+v - %+v", value.Type(), val.TypeStack, val.Data) return val, err } func serializeValue(ctx *Context, t reflect.Type, value *reflect.Value) (SerializedValue, error) { var ctx_type uint64 = 0x00 ctype, exists := ctx.TypeReflects[t] - ctx.Log.Logf("serialize", "TYPE_REFLECTS: %+v", ctx.TypeReflects) if exists == true { type_info := ctx.Types[ctype] ctx_type = uint64(ctype) - val_ser, err := type_info.Serialize(ctx, value.Interface()) + if type_info.Serialize != nil { + return type_info.Serialize(ctx, ctx_type, t, value) + } + } + + kind := t.Kind() + kind_info, handled := ctx.Kinds[kind] + if handled == false { + return SerializedValue{}, fmt.Errorf("Don't know how to serialize kind %+v", kind) + } else if ctx_type == 0x00 { + ctx_type = uint64(kind_info.Type) + } + + return kind_info.Serialize(ctx, ctx_type, t, value) + +} + +func SerializeField(ctx *Context, ext Extension, field_name string) (SerializedValue, error) { + if ext == nil { + return SerializedValue{}, fmt.Errorf("Cannot get fields on nil Extension") + } + ext_value := reflect.ValueOf(ext).Elem() + field := ext_value.FieldByName(field_name) + if field.IsValid() == false { + return SerializedValue{}, fmt.Errorf("%s is not a field in %+v", field_name, ext) + } else { + return SerializeValue(ctx, field) + } +} + +func (value SerializedValue) MarshalBinary() ([]byte, error) { + data := make([]byte, value.SerializedSize()) + binary.BigEndian.PutUint64(data[0:8], uint64(len(value.TypeStack))) + binary.BigEndian.PutUint64(data[8:16], uint64(len(value.Data))) + + for i, t := range(value.TypeStack) { + type_start := (i+2)*8 + type_end := (i+3)*8 + binary.BigEndian.PutUint64(data[type_start:type_end], t) + } + + return append(data, value.Data...), nil +} + +func (value SerializedValue) SerializedSize() uint64 { + return uint64((len(value.TypeStack) + 2) * 8) +} + +func ParseSerializedValue(ctx *Context, data []byte) (SerializedValue, error) { + if len(data) < 8 { + return SerializedValue{}, fmt.Errorf("SerializedValue required to have at least 8 bytes when serialized") + } + num_types := int(binary.BigEndian.Uint64(data[0:8])) + data_size := int(binary.BigEndian.Uint64(data[8:16])) + type_stack := make([]uint64, num_types) + for i := 0; i < num_types; i += 1 { + type_start := (i+2) * 8 + type_end := (i+3) * 8 + type_stack[i] = binary.BigEndian.Uint64(data[type_start:type_end]) + } + + types_end := 8*(num_types + 1) + return SerializedValue{ + type_stack, + data[types_end:(types_end+data_size)], + }, nil +} + +func DeserializeValue(ctx *Context, value SerializedValue, n int) ([]interface{}, []byte, error) { + ret := make([]interface{}, n) + + var deserialize TypeDeserialize = nil + + ctx_type := value.TypeStack[0] + type_info, exists := ctx.Types[SerializedType(ctx_type)] + if exists == true { + deserialize = type_info.Deserialize + } else { + kind, exists := ctx.KindTypes[SerializedType(ctx_type)] + if exists == false { + return nil, nil, fmt.Errorf("Cannot deserialize 0x%x: unknown type/kind", ctx_type) + } + kind_info := ctx.Kinds[kind] + deserialize = kind_info.Deserialize + } + + remaining_data := value.Data + for i := 0; i < n; i += 1 { + var elem interface{} = nil + var err error = nil + elem, remaining_data, err = deserialize(ctx, value) + if err != nil { + return nil, nil, err + } + if len(remaining_data) == 0 { + remaining_data = nil + } + ret[i] = elem + } + return ret, remaining_data, nil +} + +// Create a new Context with the base library content added +func NewContext(db * badger.DB, log Logger) (*Context, error) { + ctx := &Context{ + DB: db, + Log: log, + Policies: map[PolicyType]reflect.Type{}, + PolicyTypes: map[reflect.Type]PolicyType{}, + Extensions: map[ExtType]ExtensionInfo{}, + ExtensionTypes: map[reflect.Type]ExtType{}, + Signals: map[SignalType]reflect.Type{}, + SignalTypes: map[reflect.Type]SignalType{}, + Nodes: map[NodeType]NodeInfo{}, + nodeMap: map[NodeID]*Node{}, + Types: map[SerializedType]TypeInfo{}, + TypeReflects: map[reflect.Type]SerializedType{}, + Kinds: map[reflect.Kind]KindInfo{}, + KindTypes: map[SerializedType]reflect.Kind{}, + } + + var err error + err = ctx.RegisterKind(reflect.Pointer, NewSerializedType("pointer"), + func(ctx *Context, ctx_type uint64, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { + var data []byte + var elem_value *reflect.Value = nil + if value == nil { + data = nil + } else if value.IsZero() { + data = []byte{0x01} + } else { + data = []byte{0x00} + ev := value.Elem() + elem_value = &ev + } + elem, err := serializeValue(ctx, reflect_type.Elem(), elem_value) + if err != nil { + return SerializedValue{}, err + } + if elem.Data != nil { + data = append(data, elem.Data...) + } + return SerializedValue{ + append([]uint64{ctx_type}, elem.TypeStack...), + data, + }, nil + }, func(ctx *Context, value SerializedValue) (interface{}, []byte, error) { + return nil, nil, fmt.Errorf("deserialize pointer unimplemented") + }) + if err != nil { + return nil, err + } + + err = ctx.RegisterKind(reflect.Struct, NewSerializedType("struct"), + func(ctx *Context, ctx_type uint64, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ + var m map[int][]byte = nil + if value != nil { + m = map[int][]byte{} + } + num_fields := 0 + for _, field := range(reflect.VisibleFields(reflect_type)) { + gv_tag, tagged_gv := field.Tag.Lookup("gv") + if tagged_gv == false { + continue + } else if gv_tag == "" { + continue + } else if m != nil { + field_index, err := strconv.Atoi(gv_tag) + if err != nil { + return SerializedValue{}, err + } + num_fields += 1 + + field_value := value.FieldByIndex(field.Index) + field_ser, err := serializeValue(ctx, field.Type, &field_value) + if err != nil { + return SerializedValue{}, err + } + + m[field_index], err = field_ser.MarshalBinary() + if err != nil { + return SerializedValue{}, nil + } + } + } + field_list := make([][]byte, num_fields) + for i := range(field_list) { + var exists bool = false + field_list[i], exists = m[i] + if exists == false { + return SerializedValue{}, fmt.Errorf("%+v missing gv:%d", reflect_type, i) + } + } + + list_value := reflect.ValueOf(field_list) + list_serial, err := serializeValue(ctx, list_value.Type(), &list_value) if err != nil { return SerializedValue{}, err } return SerializedValue{ []uint64{ctx_type}, - val_ser, + list_serial.Data, }, nil + }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ + return nil, nil, fmt.Errorf("deserialize struct not implemented") + }) + if err != nil { + return nil, err } - kind := t.Kind() - switch kind { - case reflect.Map: - if ctx_type == 0x00 { - ctx_type = uint64(MapType) + err = ctx.RegisterKind(reflect.Int, NewSerializedType("int"), + func(ctx *Context, ctx_type uint64, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ + var data []byte = nil + if value != nil { + data = make([]byte, 8) + binary.BigEndian.PutUint64(data, value.Uint()) + } + return SerializedValue{ + []uint64{ctx_type}, + data, + }, nil + }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ + if len(value.Data) < 8 { + return reflect.Value{}, nil, fmt.Errorf("invalid length: %d/8", len(value.Data)) + } + remaining_data := value.Data[8:] + if len(remaining_data) == 0 { + remaining_data = nil + } + return int(binary.BigEndian.Uint64(value.Data[0:8])), remaining_data, nil + }) + if err != nil { + return nil, err + } + + err = ctx.RegisterKind(reflect.Uint32, NewSerializedType("uint32"), + func(ctx *Context, ctx_type uint64, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ + data := make([]byte, 4) + if value != nil { + binary.BigEndian.PutUint32(data, uint32(value.Uint())) + } + return SerializedValue{ + []uint64{ctx_type}, + data, + }, nil + }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ + return nil, nil, fmt.Errorf("deserialize uint32 unimplemented") + }) + if err != nil { + return nil, err + } + + err = ctx.RegisterKind(reflect.String, NewSerializedType("string"), + func(ctx *Context, ctx_type uint64, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ + if value == nil { + return SerializedValue{ + []uint64{ctx_type}, + nil, + }, nil + } + + data := make([]byte, 8) + str := value.String() + binary.BigEndian.PutUint64(data, uint64(len(str))) + return SerializedValue{ + []uint64{uint64(ctx_type)}, + append(data, []byte(str)...), + }, nil + }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ + return nil, nil, fmt.Errorf("deserialize string unimplemented") + }) + if err != nil { + return nil, err + } + + + err = ctx.RegisterKind(reflect.Array, NewSerializedType("array"), + func(ctx *Context, ctx_type uint64, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ + var data []byte + if value == nil { + data = nil + } else if value.Len() == 0 { + data = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + } else { + data := make([]byte, 8) + binary.BigEndian.PutUint64(data, uint64(value.Len())) + + var type_stack []uint64 = nil + for i := 0; i < value.Len(); i += 1 { + val := value.Index(i) + element, err := serializeValue(ctx, reflect_type.Elem(), &val) + if err != nil { + return SerializedValue{}, err + } + if type_stack == nil { + type_stack = append([]uint64{ctx_type}, element.TypeStack...) + } + data = append(data, element.Data...) + } + } + + elem, err := serializeValue(ctx, reflect_type.Elem(), nil) + if err != nil { + return SerializedValue{}, err } + + return SerializedValue{ + append([]uint64{ctx_type}, elem.TypeStack...), + data, + }, nil + }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ + return nil, nil, fmt.Errorf("deserialize array unimplemented") + }) + if err != nil { + return nil, err + } + + err = ctx.RegisterKind(reflect.Interface, NewSerializedType("interface"), + func(ctx *Context, ctx_type uint64, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ + var data []byte + type_stack := []uint64{} + if value == nil { + data = nil + } else if value.IsZero() { + return SerializedValue{}, fmt.Errorf("Cannot serialize nil interfaces") + } else { + elem_value := value.Elem() + elem, err := serializeValue(ctx, value.Elem().Type(), &elem_value) + if err != nil { + return SerializedValue{}, err + } + data = elem.Data + type_stack = elem.TypeStack + } + return SerializedValue{ + append([]uint64{ctx_type}, type_stack...), + data, + }, nil + }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ + return nil, nil, fmt.Errorf("deserialize interface unimplemented") + }) + if err != nil { + return nil, err + } + + + err = ctx.RegisterKind(reflect.Map, NewSerializedType("map"), + func(ctx *Context, ctx_type uint64, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ var data []byte if value == nil { data = nil @@ -367,11 +711,11 @@ func serializeValue(ctx *Context, t reflect.Type, value *reflect.Value) (Seriali key_value := map_iter.Key() val_value := map_iter.Value() - key, err := serializeValue(ctx, t.Key(), &key_value) + key, err := serializeValue(ctx, reflect_type.Key(), &key_value) if err != nil { return SerializedValue{}, err } - val, err := serializeValue(ctx, t.Elem(), &val_value) + val, err := serializeValue(ctx, reflect_type.Elem(), &val_value) if err != nil { return SerializedValue{}, err } @@ -398,11 +742,11 @@ func serializeValue(ctx *Context, t reflect.Type, value *reflect.Value) (Seriali data, }, nil } - key, err := serializeValue(ctx, t.Key(), nil) + key, err := serializeValue(ctx, reflect_type.Key(), nil) if err != nil { return SerializedValue{}, err } - elem, err := serializeValue(ctx, t.Elem(), nil) + elem, err := serializeValue(ctx, reflect_type.Elem(), nil) if err != nil { return SerializedValue{}, err } @@ -413,10 +757,51 @@ func serializeValue(ctx *Context, t reflect.Type, value *reflect.Value) (Seriali type_stack, data, }, nil - case reflect.Slice: - if ctx_type == 0x00 { - ctx_type = uint64(SliceType) + }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ + return nil, nil, fmt.Errorf("deserialize map unimplemented") + }) + if err != nil { + return nil, err + } + + + err = ctx.RegisterKind(reflect.Uint8, NewSerializedType("uint8"), + func(ctx *Context, ctx_type uint64, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ + var data []byte = nil + if value != nil { + data = []byte{uint8(value.Uint())} + } + return SerializedValue{ + []uint64{ctx_type}, + data, + }, nil + }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ + return nil, nil, fmt.Errorf("deserialize uint8 unimplemented") + }) + if err != nil { + return nil, err + } + + err = ctx.RegisterKind(reflect.Uint64, NewSerializedType("uint64"), + func(ctx *Context, ctx_type uint64, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ + var data []byte = nil + if value != nil { + data = make([]byte, 8) + binary.BigEndian.PutUint64(data, value.Uint()) } + return SerializedValue{ + []uint64{ctx_type}, + data, + }, nil + }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ + return nil, nil, fmt.Errorf("deserialize uint64 unimplemented") + }) + if err != nil { + return nil, err + } + + err = ctx.RegisterKind(reflect.Slice, NewSerializedType("slice"), + func(ctx *Context, ctx_type uint64, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ var data []byte if value == nil { data = nil @@ -427,217 +812,74 @@ func serializeValue(ctx *Context, t reflect.Type, value *reflect.Value) (Seriali } else { data := make([]byte, 8) binary.BigEndian.PutUint64(data, uint64(value.Len())) - var elem SerializedValue + var type_stack []uint64 for i := 0; i < value.Len(); i += 1 { val := value.Index(i) - element, err := serializeValue(ctx, t.Elem(), &val) + element, err := serializeValue(ctx, reflect_type.Elem(), &val) if err != nil { return SerializedValue{}, err } - if i == 0 { - elem = element + if type_stack == nil { + type_stack = append([]uint64{ctx_type}, element.TypeStack...) } - data = append(data, elem.Data...) + data = append(data, element.Data...) } return SerializedValue{ - append([]uint64{ctx_type}, elem.TypeStack...), + append([]uint64{ctx_type}, type_stack...), data, }, nil } - elem, err := serializeValue(ctx, t.Elem(), nil) + elem, err := serializeValue(ctx, reflect_type.Elem(), nil) if err != nil { return SerializedValue{}, err } return SerializedValue{ - append([]uint64{ctx_type}, elem.TypeStack...), + elem.TypeStack, data, }, nil - case reflect.Pointer: - if ctx_type == 0x00 { - ctx_type = uint64(PointerType) - } - var data []byte - var elem_value *reflect.Value = nil - if value == nil { - data = nil - } else if value.IsZero() { - data = []byte{0x01} - } else { - data = []byte{0x00} - ev := value.Elem() - elem_value = &ev - } - elem, err := serializeValue(ctx, t.Elem(), elem_value) - if err != nil { - return SerializedValue{}, err - } - if elem.Data != nil { - data = append(data, elem.Data...) + }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ + return nil, nil, fmt.Errorf("not implemented") + }) + if err != nil { + return nil, err + } + + err = ctx.RegisterType(reflect.TypeOf(Up), NewSerializedType("SignalDirection"), + func(ctx *Context, ctx_type uint64, t reflect.Type, value *reflect.Value) (SerializedValue, error) { + var data []byte = nil + if value != nil { + val := value.Interface().(SignalDirection) + data = []byte{byte(val)} } return SerializedValue{ - append([]uint64{uint64(ctx_type)}, elem.TypeStack...), + []uint64{ctx_type}, data, }, nil - case reflect.String: - if ctx_type == 0x00 { - ctx_type = uint64(StringType) - } - if value == nil { - return SerializedValue{ - []uint64{ctx_type}, - nil, - }, nil - } + }, func(ctx *Context, value SerializedValue) (interface{}, []byte, error) { + return nil, nil, fmt.Errorf("unimplemented") + }) + if err != nil { + return nil, err + } - data := make([]byte, 8) - str := value.String() - binary.BigEndian.PutUint64(data, uint64(len(str))) - return SerializedValue{ - []uint64{uint64(ctx_type)}, - append(data, []byte(str)...), - }, nil - case reflect.Uint8: - if ctx_type == 0x00 { - ctx_type = uint64(Uint8Type) + err = ctx.RegisterType(reflect.TypeOf(ReqState(0)), NewSerializedType("ReqState"), + func(ctx *Context, ctx_type uint64, t reflect.Type, value *reflect.Value) (SerializedValue, error) { + var data []byte = nil + if value != nil { + val := value.Interface().(ReqState) + data = []byte{byte(val)} } return SerializedValue{ - []uint64{uint64(ctx_type)}, - []byte{uint8(value.Uint())}, + []uint64{ctx_type}, + data, }, nil - default: - return SerializedValue{}, fmt.Errorf("unhandled kind: %+v - %+v", kind, t) - } -} - -/* - default: - kind_def, handled := defaultKinds[kind] - if handled == false { - ctx_type, handled := ctx.TypeReflects[value.Type()] - if handled == false { - err = fmt.Errorf("%+v is not a handled reflect type", value.Type()) - break - } - type_info, handled := ctx.Types[ctx_type] - if handled == false { - err = fmt.Errorf("%+v is not a handled reflect type(INTERNAL_ERROR)", value.Type()) - break - } - field_ser, err := type_info.Serialize(ctx, value.Interface()) - if err != nil { - err = fmt.Errorf(err.Error()) - break - } - ret = SerializedValue{ - []uint64{uint64(ctx_type)}, - field_ser, - } - } - field_ser, err := kind_def.Serialize(value.Interface()) - if err != nil { - err = fmt.Errorf(err.Error()) - } else { - ret = SerializedValue{ - []uint64{uint64(kind_def.Type)}, - field_ser, - } - } -*/ - -func SerializeField(ctx *Context, ext Extension, field_name string) (SerializedValue, error) { - if ext == nil { - return SerializedValue{}, fmt.Errorf("Cannot get fields on nil Extension") - } - ext_value := reflect.ValueOf(ext).Elem() - field := ext_value.FieldByName(field_name) - if field.IsValid() == false { - return SerializedValue{}, fmt.Errorf("%s is not a field in %+v", field_name, ext) - } else { - return SerializeValue(ctx, field) - } -} - -func SerializeSignal(ctx *Context, signal Signal, ctx_type SignalType) (SerializedValue, error) { - return SerializedValue{}, nil -} - -func SerializeExtension(ctx *Context, ext Extension, ctx_type ExtType) (SerializedValue, error) { - if ext == nil { - return SerializedValue{}, fmt.Errorf("Cannot serialize nil Extension ") - } - ext_type := reflect.TypeOf(ext).Elem() - ext_value := reflect.ValueOf(ext).Elem() - - m := map[string]SerializedValue{} - for _, field := range(reflect.VisibleFields(ext_type)) { - ext_tag, tagged_ext := field.Tag.Lookup("ext") - if tagged_ext == false { - continue - } else { - field_value := ext_value.FieldByIndex(field.Index) - var err error - m[ext_tag], err = SerializeValue(ctx, field_value) - if err != nil { - return SerializedValue{}, err - } - } - } - map_value := reflect.ValueOf(m) - map_ser, err := SerializeValue(ctx, map_value) + }, func(ctx *Context, value SerializedValue) (interface{}, []byte, error) { + return nil, nil, fmt.Errorf("unimplemented") + }) if err != nil { - return SerializedValue{}, err - } - return SerializedValue{ - append([]uint64{uint64(ctx_type)}, map_ser.TypeStack...), - map_ser.Data, - }, nil -} - -func (value SerializedValue) MarshalBinary() ([]byte, error) { - return nil, fmt.Errorf("SerializedValue.MarshalBinary Undefined") -} - -func ParseSerializedValue(ctx *Context, data []byte) (SerializedValue, []byte, error) { - return SerializedValue{}, nil, fmt.Errorf("ParseSerializedValue Undefined") -} - -func DeserializeValue(ctx *Context, value SerializedValue) (interface{}, error) { - return nil, fmt.Errorf("DeserializeValue Undefined") -} - -// Create a new Context with the base library content added -func NewContext(db * badger.DB, log Logger) (*Context, error) { - ctx := &Context{ - DB: db, - Log: log, - Policies: map[PolicyType]reflect.Type{}, - PolicyTypes: map[reflect.Type]PolicyType{}, - Extensions: map[ExtType]ExtensionInfo{}, - ExtensionTypes: map[reflect.Type]ExtType{}, - Signals: map[SignalType]reflect.Type{}, - SignalTypes: map[reflect.Type]SignalType{}, - Nodes: map[NodeType]NodeInfo{}, - nodeMap: map[NodeID]*Node{}, - Types: map[SerializedType]TypeInfo{}, - TypeReflects: map[reflect.Type]SerializedType{}, + return nil, err } - var err error - err = ctx.RegisterType(reflect.TypeOf(SerializedValue{}), NewSerializedType("SerializedValue"), - func(ctx *Context, val interface{}) ([]byte, error) { - value := val.(SerializedValue) - return value.MarshalBinary() - }, func(ctx *Context, data []byte) (interface{}, error) { - value, data, err := ParseSerializedValue(ctx, data) - if err != nil { - return nil, err - } - if data != nil { - return nil, fmt.Errorf("%+v remaining after parse", data) - } - return value, nil - }) - err = ctx.RegisterExtension(reflect.TypeOf((*LockableExt)(nil)), LockableExtType, nil) if err != nil { return nil, err diff --git a/gql_test.go b/gql_test.go index 771b7dd..cb29881 100644 --- a/gql_test.go +++ b/gql_test.go @@ -16,6 +16,7 @@ import ( "bytes" "golang.org/x/net/websocket" "github.com/google/uuid" + "reflect" ) func TestGQLServer(t *testing.T) { @@ -210,7 +211,7 @@ func TestGQLServer(t *testing.T) { } func TestGQLDB(t *testing.T) { - ctx := logTestContext(t, []string{"test"}) + ctx := logTestContext(t, []string{"test", "signal", "node"}) TestUserNodeType := NewNodeType("TEST_USER") err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) @@ -239,10 +240,12 @@ func TestGQLDB(t *testing.T) { }) fatalErr(t, err) - ser1, err := gql.Serialize(ctx) - ser2, err := u1.Serialize(ctx) - ctx.Log.Logf("test", "SER_1: \n%s\n\n", ser1) - ctx.Log.Logf("test", "SER_2: \n%s\n\n", ser2) + ser1, err := SerializeValue(ctx, reflect.ValueOf(gql)) + fatalErr(t, err) + ctx.Log.Logf("test", "SER_1: \n%+v\n\n", ser1) + ser2, err := SerializeValue(ctx, reflect.ValueOf(u1)) + fatalErr(t, err) + ctx.Log.Logf("test", "SER_2: \n%+v\n\n", ser2) // Clear all loaded nodes from the context so it loads them from the database ctx.nodeMap = map[NodeID]*Node{} diff --git a/group.go b/group.go index 0237e48..14321b2 100644 --- a/group.go +++ b/group.go @@ -5,7 +5,7 @@ import ( ) type GroupExt struct { - Members map[NodeID]string + Members map[NodeID]string `gv:"0"` } func (ext *GroupExt) Type() ExtType { diff --git a/listener.go b/listener.go index 603585c..357c866 100644 --- a/listener.go +++ b/listener.go @@ -1,6 +1,7 @@ package graphvent import ( + "reflect" "encoding/json" ) @@ -31,7 +32,7 @@ func (listener *ListenerExt) Type() ExtType { // Send the signal to the channel, logging an overflow if it occurs func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { - ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, signal) + ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, reflect.TypeOf(signal)) select { case ext.Chan <- signal: default: diff --git a/lockable.go b/lockable.go index ccacfa8..cd0640c 100644 --- a/lockable.go +++ b/lockable.go @@ -14,11 +14,11 @@ const ( ) type LockableExt struct{ - State ReqState `ext:""` - ReqID *uuid.UUID `ext:""` - Owner *NodeID `ext:""` - PendingOwner *NodeID `ext:""` - Requirements map[NodeID]ReqState `ext:""` + State ReqState `gv:"0"` + ReqID *uuid.UUID `gv:"1"` + Owner *NodeID `gv:"2"` + PendingOwner *NodeID `gv:"3"` + Requirements map[NodeID]ReqState `gv:"4"` } func (ext *LockableExt) Type() ExtType { diff --git a/node.go b/node.go index aaa7c7c..e720f1d 100644 --- a/node.go +++ b/node.go @@ -7,13 +7,11 @@ import ( "github.com/google/uuid" badger "github.com/dgraph-io/badger/v3" "fmt" - "encoding/binary" "sync/atomic" "crypto" "crypto/ed25519" "crypto/sha512" "crypto/rand" - "crypto/x509" ) const ( @@ -90,25 +88,25 @@ type PendingSignal struct { // Default message channel size for nodes // Nodes represent a group of extensions that can be collectively addressed type Node struct { - Key ed25519.PrivateKey + Key ed25519.PrivateKey `gv:"0"` ID NodeID - Type NodeType - Extensions map[ExtType]Extension - Policies map[PolicyType]Policy + Type NodeType `gv:"1"` + Extensions map[ExtType]Extension `gv:"3"` + Policies map[PolicyType]Policy `gv:"4"` - PendingACLs map[uuid.UUID]PendingACL - PendingSignals map[uuid.UUID]PendingSignal + PendingACLs map[uuid.UUID]PendingACL `gv:"6"` + PendingSignals map[uuid.UUID]PendingSignal `gv:"7"` // Channel for this node to receive messages from the Context MsgChan chan *Message // Size of MsgChan - BufferSize uint32 + BufferSize uint32 `gv:"2"` // Channel for this node to process delayed signals TimeoutChan <-chan time.Time Active atomic.Bool - SignalQueue []QueuedSignal + SignalQueue []QueuedSignal `gv:"5"` NextSignal *QueuedSignal } @@ -251,12 +249,7 @@ func nodeLoop(ctx *Context, node *Node) error { select { case msg := <- node.MsgChan: ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal) - signal_type, exists := ctx.SignalTypes[reflect.TypeOf(msg.Signal).Elem()] - if exists == false { - ctx.Log.Logf("signal", "SIGNAL_NOT_REGISTERED: %+v", reflect.TypeOf(msg.Signal).Elem()) - } - - signal_ser, err := SerializeSignal(ctx, signal, signal_type) + signal_ser, err := SerializeValue(ctx, reflect.ValueOf(msg.Signal)) if err != nil { ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", err, msg.Signal) } @@ -280,6 +273,8 @@ func nodeLoop(ctx *Context, node *Node) error { sig_data = append(sig_data, ser...) validated := ed25519.Verify(msg.Principal, sig_data, msg.Signature) if validated == false { + println(fmt.Sprintf("SIGNAL: %s", msg.Signal)) + println(fmt.Sprintf("VERIFY_DIGEST: %+v", sig_data)) ctx.Log.Logf("signal", "SIGNAL_VERIFY_ERR: %s - %+v", node.ID, msg) continue } @@ -393,7 +388,7 @@ func nodeLoop(ctx *Context, node *Node) error { switch sig := signal.(type) { case *StopSignal: msgs := Messages{} - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.ID, "stopped"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewStatusSignal(node.ID, "stopped"), source) ctx.Send(msgs) node.Process(ctx, node.ID, NewStatusSignal(node.ID, "stopped")) run = false @@ -403,6 +398,8 @@ func nodeLoop(ctx *Context, node *Node) error { msgs = msgs.Add(ctx, node.ID, node.Key, NewReadResultSignal(sig.ID, node.ID, node.Type, result), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.ID, "read_done"), source) ctx.Send(msgs) + default: + println(fmt.Sprintf("NOT_SPECIAL_SIGNAL: %+v", reflect.TypeOf(sig))) } node.Process(ctx, source, signal) @@ -442,12 +439,7 @@ func (msgs Messages) Add(ctx *Context, source NodeID, principal ed25519.PrivateK } func NewMessage(ctx *Context, dest NodeID, source NodeID, principal ed25519.PrivateKey, signal Signal) (*Message, error) { - signal_type, exists := ctx.SignalTypes[reflect.TypeOf(signal)] - if exists == false { - return nil, fmt.Errorf("Cannot put %+v in a message, not a known signal type", reflect.TypeOf(signal)) - } - - signal_ser, err := SerializeSignal(ctx, signal, signal_type) + signal_ser, err := SerializeValue(ctx, reflect.ValueOf(signal)) if err != nil { return nil, err } @@ -527,61 +519,12 @@ func GetExt[T Extension](node *Node, ext_type ExtType) (T, error) { return ret, nil } -func (node *Node) Serialize(ctx *Context) (SerializedValue, error) { - if node == nil { - return SerializedValue{}, fmt.Errorf("Cannot serialize nil Node") - } - - node_bytes := make([]byte, 8 * 3) - binary.BigEndian.PutUint64(node_bytes[0:8], uint64(len(node.Extensions))) - binary.BigEndian.PutUint64(node_bytes[8:16], uint64(len(node.Policies))) - binary.BigEndian.PutUint64(node_bytes[16:24], uint64(len(node.SignalQueue))) - - key_bytes, err := x509.MarshalPKCS8PrivateKey(node.Key) - if err != nil { - return SerializedValue{}, err - } - - key_val := SerializedValue{ - TypeStack: []uint64{uint64(NodeKeyType)}, - Data: key_bytes, - } - key_ser, err := key_val.MarshalBinary() - if err != nil { - return SerializedValue{}, err - } - node_bytes = append(node_bytes, key_ser...) - - for ext_type, ext := range(node.Extensions) { - ctx.Log.Logf("serialize", "SERIALIZING_EXTENSION: %+v", ext) - ext_ser, err := SerializeExtension(ctx, ext, ext_type) - if err != nil { - return SerializedValue{}, err - } - ext_bytes, err := ext_ser.MarshalBinary() - if err != nil { - return SerializedValue{}, err - } - ctx.Log.Logf("serialize", "SERIALIZED_EXTENSION: %+v", ext_bytes) - - node_bytes = append(node_bytes, ext_bytes...) - } - - node_value := SerializedValue{ - TypeStack: []uint64{uint64(node.Type)}, - Data: node_bytes, - } - - return node_value, nil -} - func KeyID(pub ed25519.PublicKey) NodeID { id := uuid.NewHash(sha512.New(), ZeroUUID, pub, 3) return NodeID(id) } // Create a new node in memory and start it's event loop -// TODO: Change panics to errors func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_size uint32, policies map[PolicyType]Policy, extensions ...Extension) (*Node, error) { var err error var public ed25519.PublicKey @@ -647,7 +590,6 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si Type: node_type, Extensions: ext_map, Policies: policies, - //TODO serialize/deserialize these PendingACLs: map[uuid.UUID]PendingACL{}, PendingSignals: map[uuid.UUID]PendingSignal{}, MsgChan: make(chan *Message, buffer_size), @@ -672,7 +614,7 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si func WriteNode(ctx *Context, node *Node) error { ctx.Log.Logf("db", "DB_WRITE: %s", node.ID) - node_serialized, err := node.Serialize(ctx) + node_serialized, err := SerializeValue(ctx, reflect.ValueOf(node)) if err != nil { return err } @@ -694,7 +636,6 @@ func WriteNode(ctx *Context, node *Node) error { }) } -//TODO: fix after capnp func LoadNode(ctx * Context, id NodeID) (*Node, error) { ctx.Log.Logf("db", "LOADING_NODE: %s", id) var bytes []byte @@ -720,18 +661,27 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return nil, err } - num_extensions := binary.BigEndian.Uint64(bytes[0:8]) - num_policies := binary.BigEndian.Uint64(bytes[8:16]) - num_signals := binary.BigEndian.Uint64(bytes[16:24]) - print(num_extensions) - print(num_policies) - print(num_signals) + node_value, err := ParseSerializedValue(ctx, bytes) + if err != nil { + return nil, err + } + node_if, remaining, err := DeserializeValue(ctx, node_value, 1) + if err != nil { + return nil, err + } + + if remaining != nil { + return nil, fmt.Errorf("%d bytes left after desrializing *Node", len(remaining)) + } - /* - ctx.AddNode(id, node) + node, ok := node_if[0].(*Node) + if ok == false { + return nil, fmt.Errorf("Deserialized %+v when expecting *Node", reflect.TypeOf(node_if).Elem()) + } + + ctx.AddNode(id, node) ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id) go runNode(ctx, node) - */ return nil, nil } diff --git a/node_test.go b/node_test.go index 2e1f222..20b734c 100644 --- a/node_test.go +++ b/node_test.go @@ -8,12 +8,27 @@ import ( ) func TestNodeDB(t *testing.T) { - ctx := logTestContext(t, []string{"signal", "node", "db", "db_data", "serialize"}) + ctx := logTestContext(t, []string{"signal", "node", "db", "db_data", "serialize", "listener"}) node_type := NewNodeType("test") err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) fatalErr(t, err) - node, err := NewNode(ctx, nil, node_type, 10, nil, NewGroupExt(nil), NewLockableExt(nil)) + node_listener := NewListenerExt(10) + node, err := NewNode(ctx, nil, node_type, 10, nil, NewGroupExt(nil), NewLockableExt(nil), node_listener) + fatalErr(t, err) + + _, err = WaitForSignal(node_listener.Chan, 10*time.Millisecond, func(sig *StatusSignal) bool { + return sig.Status == "started" && sig.Source == node.ID + }) + + msgs := Messages{} + msgs = msgs.Add(ctx, node.ID, node.Key, NewStopSignal(), node.ID) + err = ctx.Send(msgs) + fatalErr(t, err) + + _, err = WaitForSignal(node_listener.Chan, 10*time.Millisecond, func(sig *StatusSignal) bool { + return sig.Status == "stopped" && sig.Source == node.ID + }) fatalErr(t, err) ctx.nodeMap = map[NodeID]*Node{} diff --git a/policy.go b/policy.go index 9bbdf79..6ae2cf9 100644 --- a/policy.go +++ b/policy.go @@ -64,12 +64,12 @@ func (policy *RequirementOfPolicy) ContinueAllows(ctx *Context, current PendingA return Deny } - reqs_if, err := DeserializeValue(ctx, reqs_ser) + reqs_if, _, err := DeserializeValue(ctx, reqs_ser, 1) if err != nil { return Deny } - requirements, ok := reqs_if.(map[NodeID]ReqState) + requirements, ok := reqs_if[0].(map[NodeID]ReqState) if ok == false { return Deny } @@ -113,17 +113,17 @@ func (policy *MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, s return Deny } - members_if, err := DeserializeValue(ctx, members_ser) + members_if, _, err := DeserializeValue(ctx, members_ser, 1) if err != nil { return Deny } - members, ok := members_if.(map[NodeID]string) + members, ok := members_if[0].(map[NodeID]string) if ok == false { return Deny } - for member, _ := range(members) { + for member := range(members) { if member == current.Principal { return policy.NodeRules[sig.NodeID].Allows(current.Action) } @@ -139,7 +139,7 @@ func (policy *MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action T if id == node.ID { ext, err := GetExt[*GroupExt](node, GroupExtType) if err == nil { - for member, _ := range(ext.Members) { + for member := range(ext.Members) { if member == principal_id { if rule.Allows(action) == Allow { return nil, Allow diff --git a/signal.go b/signal.go index c4a266e..2705607 100644 --- a/signal.go +++ b/signal.go @@ -17,9 +17,9 @@ const ( ) type SignalHeader struct { - Direction SignalDirection - ID uuid.UUID - ReqID uuid.UUID + Direction SignalDirection `gv:"0"` + ID uuid.UUID `gv:"1"` + ReqID uuid.UUID `gv:"2"` } type Signal interface {