From 799b6404ddb9feb2ecc96e3e3e991a64391456df Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sat, 2 Sep 2023 18:49:37 -0600 Subject: [PATCH] Added deserialize to pointer --- context.go | 199 +++++++++++++++++++++++++++++++++++++---------------- node.go | 8 +-- policy.go | 8 +-- 3 files changed, 148 insertions(+), 67 deletions(-) diff --git a/context.go b/context.go index 33f7865..3399e54 100644 --- a/context.go +++ b/context.go @@ -84,20 +84,7 @@ var ( PerNodePolicyType = NewPolicyType("PER_NODE") AllNodesPolicyType = NewPolicyType("ALL_NODES") - StructType = NewSerializedType("struct") - SliceType = NewSerializedType("slice") - ArrayType = NewSerializedType("array") - PointerType = NewSerializedType("pointer") - MapType = NewSerializedType("map") - ErrorType = NewSerializedType("error") - ExtensionType = NewSerializedType("extension") - - StringType = NewSerializedType("string") - IntType = NewSerializedType("int") - Uint8Type = NewSerializedType("uint8") - Uint32Type = NewSerializedType("uint32") - Uint64Type = NewSerializedType("uint64") - NodeKeyType = NewSerializedType("node_key") + ErrorType = NewSerializedType("ERROR") NodeNotFoundError = errors.New("Node not found in DB") ECDH = ecdh.X25519() @@ -113,7 +100,7 @@ type NodeInfo struct { } type TypeSerialize func(*Context,uint64,reflect.Type,*reflect.Value) (SerializedValue, error) -type TypeDeserialize func(*Context,SerializedValue) (interface{}, []byte, error) +type TypeDeserialize func(*Context,[]uint64,[]byte) (reflect.Type, *reflect.Value, []byte, error) type TypeInfo struct { Type reflect.Type Serialize TypeSerialize @@ -405,45 +392,64 @@ func ParseSerializedValue(ctx *Context, data []byte) (SerializedValue, error) { type_stack[i] = binary.BigEndian.Uint64(data[type_start:type_end]) } - types_end := 8*(num_types + 1) + types_end := 8*(num_types + 2) return SerializedValue{ type_stack, data[types_end:(types_end+data_size)], }, nil } -func DeserializeValue(ctx *Context, value SerializedValue, n int) ([]interface{}, []byte, error) { - ret := make([]interface{}, n) +func DeserializeValue(ctx *Context, type_stack []uint64, data []byte, n int) (reflect.Type, []reflect.Value, []byte, error) { + ret := make([]reflect.Value, n) var deserialize TypeDeserialize = nil + var reflect_type reflect.Type = nil + + ctx_type := type_stack[0] + type_stack = type_stack[1:] - ctx_type := value.TypeStack[0] type_info, exists := ctx.Types[SerializedType(ctx_type)] if exists == true { deserialize = type_info.Deserialize + reflect_type = type_info.Type } else { kind, exists := ctx.KindTypes[SerializedType(ctx_type)] if exists == false { - return nil, nil, fmt.Errorf("Cannot deserialize 0x%x: unknown type/kind", ctx_type) + return nil, nil, nil, fmt.Errorf("Cannot deserialize 0x%x: unknown type/kind", ctx_type) } kind_info := ctx.Kinds[kind] deserialize = kind_info.Deserialize } - remaining_data := value.Data - for i := 0; i < n; i += 1 { - var elem interface{} = nil - var err error = nil - elem, remaining_data, err = deserialize(ctx, value) + remaining_data := data + var value_type reflect.Type = nil + var err error = nil + if data == nil { + reflect_type, _, _, err = deserialize(ctx, type_stack, nil) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - if len(remaining_data) == 0 { - remaining_data = nil + } else { + for i := 0; i < n; i += 1 { + var elem *reflect.Value + var elem_type reflect.Type = nil + elem_type, elem, remaining_data, err = deserialize(ctx, type_stack, remaining_data) + if err != nil { + return nil, nil, nil, err + } + if value_type == nil { + value_type = elem_type + } + if len(remaining_data) == 0 { + remaining_data = nil + } + if elem == nil { + return nil, nil, nil, fmt.Errorf("root deserialize returned no value") + } + ret[i] = *elem } - ret[i] = elem } - return ret, remaining_data, nil + return reflect_type, ret, remaining_data, nil } // Create a new Context with the base library content added @@ -490,8 +496,40 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { append([]uint64{ctx_type}, elem.TypeStack...), data, }, nil - }, func(ctx *Context, value SerializedValue) (interface{}, []byte, error) { - return nil, nil, fmt.Errorf("deserialize pointer unimplemented") + }, func(ctx *Context, type_stack []uint64, data []byte) (reflect.Type, *reflect.Value, []byte, error) { + // TODO: figure out how to deal with the case where the pointer value is nil + // In that case need to figure out the type of the pointer to create the nil pointer, which involves continuing to recuse the type stack without parsing any data + if data == nil { + elem_type, _, _, err := DeserializeValue(ctx, type_stack, nil, 1) + if err != nil { + return nil, nil, nil, err + } + return elem_type, nil, nil, nil + } else { + if len(data) < 1 { + return nil, nil, nil, fmt.Errorf("Not enough data to deserialize pointer") + } + pointer_flags := data[0] + if pointer_flags == 0x00 { + _, elem_value, remaining_data, err := DeserializeValue(ctx, type_stack, data[1:], 1) + if err != nil { + return nil, nil, nil, err + } + pointer_value := elem_value[0].Addr() + return pointer_value.Type(), &pointer_value, remaining_data, nil + } else if pointer_flags == 0x01 { + elem_type, _, _, err := DeserializeValue(ctx, type_stack, nil, 1) + if err != nil { + return nil, nil, nil, err + } + + pointer_type := reflect.PointerTo(elem_type) + pointer_value := reflect.New(pointer_type).Elem() + return pointer_type, &pointer_value, data[1:], nil + } else { + return nil, nil, nil, fmt.Errorf("unknown pointer flags: %d", pointer_flags) + } + } }) if err != nil { return nil, err @@ -547,8 +585,8 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { []uint64{ctx_type}, list_serial.Data, }, nil - }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ - return nil, nil, fmt.Errorf("deserialize struct not implemented") + }, func(ctx *Context, type_stack []uint64, data []byte)(reflect.Type, *reflect.Value, []byte, error){ + return nil, nil, nil, fmt.Errorf("deserialize struct not implemented") }) if err != nil { return nil, err @@ -565,15 +603,16 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { []uint64{ctx_type}, data, }, nil - }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ - if len(value.Data) < 8 { - return reflect.Value{}, nil, fmt.Errorf("invalid length: %d/8", len(value.Data)) + }, func(ctx *Context, type_stack []uint64, data []byte)(reflect.Type, *reflect.Value, []byte, error){ + if len(data) < 8 { + return nil, nil, nil, fmt.Errorf("invalid length: %d/8", len(data)) } - remaining_data := value.Data[8:] + remaining_data := data[8:] if len(remaining_data) == 0 { remaining_data = nil } - return int(binary.BigEndian.Uint64(value.Data[0:8])), remaining_data, nil + int_val := reflect.ValueOf(binary.BigEndian.Uint64(data[0:8])) + return int_val.Type(), &int_val, remaining_data, nil }) if err != nil { return nil, err @@ -589,8 +628,8 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { []uint64{ctx_type}, data, }, nil - }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ - return nil, nil, fmt.Errorf("deserialize uint32 unimplemented") + }, func(ctx *Context, type_stack []uint64, data []byte)(reflect.Type, *reflect.Value, []byte, error){ + return nil, nil, nil, fmt.Errorf("deserialize uint32 unimplemented") }) if err != nil { return nil, err @@ -612,8 +651,8 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { []uint64{uint64(ctx_type)}, append(data, []byte(str)...), }, nil - }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ - return nil, nil, fmt.Errorf("deserialize string unimplemented") + }, func(ctx *Context, type_stack []uint64, data []byte)(reflect.Type, *reflect.Value, []byte, error){ + return nil, nil, nil, fmt.Errorf("deserialize string unimplemented") }) if err != nil { return nil, err @@ -654,8 +693,8 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { append([]uint64{ctx_type}, elem.TypeStack...), data, }, nil - }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ - return nil, nil, fmt.Errorf("deserialize array unimplemented") + }, func(ctx *Context, type_stack []uint64, data []byte)(reflect.Type, *reflect.Value, []byte, error){ + return nil, nil, nil, fmt.Errorf("deserialize array unimplemented") }) if err != nil { return nil, err @@ -682,8 +721,8 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { append([]uint64{ctx_type}, type_stack...), data, }, nil - }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ - return nil, nil, fmt.Errorf("deserialize interface unimplemented") + }, func(ctx *Context, type_stack []uint64, data []byte)(reflect.Type, *reflect.Value, []byte, error){ + return nil, nil, nil, fmt.Errorf("deserialize interface unimplemented") }) if err != nil { return nil, err @@ -757,8 +796,8 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { type_stack, data, }, nil - }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ - return nil, nil, fmt.Errorf("deserialize map unimplemented") + }, func(ctx *Context, type_stack []uint64, data []byte)(reflect.Type, *reflect.Value, []byte, error){ + return nil, nil, nil, fmt.Errorf("deserialize map unimplemented") }) if err != nil { return nil, err @@ -775,8 +814,8 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { []uint64{ctx_type}, data, }, nil - }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ - return nil, nil, fmt.Errorf("deserialize uint8 unimplemented") + }, func(ctx *Context, type_stack []uint64, data []byte)(reflect.Type, *reflect.Value, []byte, error){ + return nil, nil, nil, fmt.Errorf("deserialize uint8 unimplemented") }) if err != nil { return nil, err @@ -793,8 +832,8 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { []uint64{ctx_type}, data, }, nil - }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ - return nil, nil, fmt.Errorf("deserialize uint64 unimplemented") + }, func(ctx *Context, type_stack []uint64, data []byte)(reflect.Type, *reflect.Value, []byte, error){ + return nil, nil, nil, fmt.Errorf("deserialize uint64 unimplemented") }) if err != nil { return nil, err @@ -837,8 +876,50 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { elem.TypeStack, data, }, nil - }, func(ctx *Context, value SerializedValue)(interface{}, []byte, error){ - return nil, nil, fmt.Errorf("not implemented") + }, func(ctx *Context, type_stack []uint64, data []byte)(reflect.Type, *reflect.Value, []byte, error){ + return nil, nil, nil, fmt.Errorf("not implemented") + }) + if err != nil { + return nil, err + } + + err = ctx.RegisterType(reflect.TypeOf(StringError("")), ErrorType, + func(ctx *Context, ctx_type uint64, t reflect.Type, value *reflect.Value) (SerializedValue, error) { + if value == nil { + return SerializedValue{ + []uint64{ctx_type}, + nil, + }, nil + } + + data := make([]byte, 8) + err := value.Interface().(StringError) + str := string(err) + binary.BigEndian.PutUint64(data, uint64(len(str))) + return SerializedValue{ + []uint64{uint64(ctx_type)}, + append(data, []byte(str)...), + }, nil + }, func(ctx *Context, type_stack []uint64, data []byte)(reflect.Type, *reflect.Value, []byte, error){ + return nil, nil, nil, fmt.Errorf("unimplemented") + }) + + err = ctx.RegisterType(reflect.TypeOf(RandID()), NewSerializedType("NodeID"), + func(ctx *Context, ctx_type uint64, t reflect.Type, value *reflect.Value) (SerializedValue, error) { + var id_ser []byte = nil + if value != nil { + var err error = nil + id_ser, err = value.Interface().(NodeID).MarshalBinary() + if err != nil { + return SerializedValue{}, err + } + } + return SerializedValue{ + []uint64{ctx_type}, + id_ser, + }, nil + }, func(ctx *Context, type_stack []uint64, data []byte)(reflect.Type, *reflect.Value, []byte, error){ + return nil, nil, nil, fmt.Errorf("unimplemented") }) if err != nil { return nil, err @@ -855,8 +936,8 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { []uint64{ctx_type}, data, }, nil - }, func(ctx *Context, value SerializedValue) (interface{}, []byte, error) { - return nil, nil, fmt.Errorf("unimplemented") + }, func(ctx *Context, type_stack []uint64, data []byte)(reflect.Type, *reflect.Value, []byte, error){ + return reflect.TypeOf(Up), nil, nil, fmt.Errorf("unimplemented") }) if err != nil { return nil, err @@ -873,8 +954,8 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { []uint64{ctx_type}, data, }, nil - }, func(ctx *Context, value SerializedValue) (interface{}, []byte, error) { - return nil, nil, fmt.Errorf("unimplemented") + }, func(ctx *Context, type_stack []uint64, data []byte)(reflect.Type, *reflect.Value, []byte, error){ + return reflect.TypeOf(ReqState(0)), nil, nil, fmt.Errorf("unimplemented") }) if err != nil { return nil, err diff --git a/node.go b/node.go index e720f1d..5f88032 100644 --- a/node.go +++ b/node.go @@ -661,11 +661,11 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return nil, err } - node_value, err := ParseSerializedValue(ctx, bytes) + value, err := ParseSerializedValue(ctx, bytes) if err != nil { return nil, err } - node_if, remaining, err := DeserializeValue(ctx, node_value, 1) + _, node_val, remaining, err := DeserializeValue(ctx, value.TypeStack, value.Data, 1) if err != nil { return nil, err } @@ -674,9 +674,9 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return nil, fmt.Errorf("%d bytes left after desrializing *Node", len(remaining)) } - node, ok := node_if[0].(*Node) + node, ok := node_val[0].Interface().(*Node) if ok == false { - return nil, fmt.Errorf("Deserialized %+v when expecting *Node", reflect.TypeOf(node_if).Elem()) + return nil, fmt.Errorf("Deserialized %+v when expecting *Node", reflect.TypeOf(node_val).Elem()) } ctx.AddNode(id, node) diff --git a/policy.go b/policy.go index 6ae2cf9..bfd36b6 100644 --- a/policy.go +++ b/policy.go @@ -64,12 +64,12 @@ func (policy *RequirementOfPolicy) ContinueAllows(ctx *Context, current PendingA return Deny } - reqs_if, _, err := DeserializeValue(ctx, reqs_ser, 1) + _, reqs_if, _, err := DeserializeValue(ctx, reqs_ser.TypeStack, reqs_ser.Data, 1) if err != nil { return Deny } - requirements, ok := reqs_if[0].(map[NodeID]ReqState) + requirements, ok := reqs_if[0].Interface().(map[NodeID]ReqState) if ok == false { return Deny } @@ -113,12 +113,12 @@ func (policy *MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, s return Deny } - members_if, _, err := DeserializeValue(ctx, members_ser, 1) + _, members_if, _, err := DeserializeValue(ctx, members_ser.TypeStack, members_ser.Data, 1) if err != nil { return Deny } - members, ok := members_if[0].(map[NodeID]string) + members, ok := members_if[0].Interface().(map[NodeID]string) if ok == false { return Deny }