diff --git a/cmd/graphiql/main.go b/cmd/graphiql/main.go index 9447f7f..c604bdb 100644 --- a/cmd/graphiql/main.go +++ b/cmd/graphiql/main.go @@ -17,7 +17,7 @@ func main() { db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true)) check(err) - ctx, err := gv.NewContext(db, gv.NewConsoleLogger([]string{"test", "gql"})) + ctx, err := gv.NewContext(db, gv.NewConsoleLogger([]string{"test", "signal"})) check(err) gql_ext, err := gv.NewGQLExt(ctx, ":8080", nil, nil) @@ -25,7 +25,13 @@ func main() { listener_ext := gv.NewListenerExt(1000) - _, err = gv.NewNode(ctx, nil, "Lockable", 1000, gql_ext, listener_ext, gv.NewLockableExt(nil)) + n1, err := gv.NewNode(ctx, nil, "Lockable", 1000, gv.NewLockableExt(nil)) + check(err) + + n2, err := gv.NewNode(ctx, nil, "Lockable", 1000, gv.NewLockableExt([]gv.NodeID{n1.ID})) + check(err) + + _, err = gv.NewNode(ctx, nil, "Lockable", 1000, gql_ext, listener_ext, gv.NewLockableExt([]gv.NodeID{n2.ID})) check(err) for true { diff --git a/context.go b/context.go index 09cf5a6..73e9409 100644 --- a/context.go +++ b/context.go @@ -48,7 +48,7 @@ type TypeInfo struct { Serialize SerializeFn Deserialize DeserializeFn - Resolve func(interface{})(interface{}, error) + Resolve func(interface{},graphql.ResolveParams)(interface{},error) } type ExtensionInfo struct { @@ -65,6 +65,11 @@ type NodeInfo struct { Fields map[string]ExtType } +type InterfaceInfo struct { + Serialized SerializedType + Reflect reflect.Type +} + // A Context stores all the data to run a graphvent process type Context struct { @@ -81,6 +86,9 @@ type Context struct { Extensions map[ExtType]*ExtensionInfo ExtensionTypes map[reflect.Type]*ExtensionInfo + Interfaces map[SerializedType]*InterfaceInfo + InterfaceTypes map[reflect.Type]*InterfaceInfo + // Map between database type hashes and the registered info Nodes map[NodeType]*NodeInfo NodeTypes map[string]*NodeInfo @@ -153,11 +161,15 @@ func RegisterMap(ctx *Context, reflect_type reflect.Type, node_type string) erro return err } + key_resolve := ctx.GQLResolve(reflect_type.Key(), node_types[0]) + val_type, err := ctx.GQLType(reflect_type.Elem(), node_types[1]) if err != nil { return err } + val_resolve := ctx.GQLResolve(reflect_type.Elem(), node_types[1]) + gql_name := strings.ReplaceAll(reflect_type.String(), ".", "_") gql_name = strings.ReplaceAll(gql_name, "[", "_") gql_name = strings.ReplaceAll(gql_name, "]", "_") @@ -174,7 +186,11 @@ func RegisterMap(ctx *Context, reflect_type reflect.Type, node_type string) erro return nil, fmt.Errorf("%+v is not Pair", source) } - return source.Key, nil + if key_resolve == nil { + return source.Key, nil + } else { + return key_resolve(source.Key, p) + } }, }, "Value": &graphql.Field{ @@ -185,7 +201,11 @@ func RegisterMap(ctx *Context, reflect_type reflect.Type, node_type string) erro return nil, fmt.Errorf("%+v is not Pair", source) } - return source.Val, nil + if val_resolve == nil { + return source.Val, nil + } else { + return val_resolve(source.Val, p) + } }, }, }, @@ -199,7 +219,7 @@ func RegisterMap(ctx *Context, reflect_type reflect.Type, node_type string) erro Serialized: serialized_type, Reflect: reflect_type, Type: gql_map, - Resolve: func(v interface{}) (interface{}, error) { + Resolve: func(v interface{},p graphql.ResolveParams) (interface{}, error) { val := reflect.ValueOf(v) if val.Type() != (reflect_type) { return nil, fmt.Errorf("%s is not %s", val.Type(), reflect_type) @@ -241,16 +261,44 @@ func BuildSchema(ctx *Context, query, mutation *graphql.Object) (graphql.Schema, subscription := graphql.NewObject(graphql.ObjectConfig{ Name: "Subscription", - Fields: graphql.Fields{ - "Test": &graphql.Field{ - Type: graphql.String, - Resolve: func(p graphql.ResolveParams) (interface{}, error) { - return "TEST", nil - }, - }, - }, + Fields: graphql.Fields{}, }) + for query_name, query := range(query.Fields()) { + args := graphql.FieldConfigArgument{} + for _, arg := range(query.Args) { + args[arg.Name()] = &graphql.ArgumentConfig{ + Type: arg.Type, + DefaultValue: arg.DefaultValue, + Description: arg.Description(), + } + } + subscription.AddFieldConfig(query_name, &graphql.Field{ + Type: query.Type, + Args: args, + Subscribe: func(p graphql.ResolveParams) (interface{}, error) { + ctx, err := PrepResolve(p) + if err != nil { + return nil, err + } + + c, err := ctx.Ext.AddSubscription(ctx.ID) + if err != nil { + return nil, err + } + + first_result, err := query.Resolve(p) + if err != nil { + return nil, err + } + + c <- first_result + return c, nil + }, + Resolve: query.Resolve, + }) + } + return graphql.NewSchema(graphql.SchemaConfig{ Types: types, Query: query, @@ -409,11 +457,7 @@ func RegisterNodeType(ctx *Context, name string, extensions []ExtType) error { return err } - gql_resolve, err := ctx.GQLResolve(field_info.Type, field_info.NodeTag) - if err != nil { - return err - } - ctx.Log.Logf("gql", "Adding field %s[%+v] to %s with gql type %+v", field_name, field_info, name, gql_type) + gql_resolve := ctx.GQLResolve(field_info.Type, field_info.NodeTag) gql_interface.AddFieldConfig(field_name, &graphql.Field{ Type: gql_type, @@ -433,7 +477,7 @@ func RegisterNodeType(ctx *Context, name string, extensions []ExtType) error { } if gql_resolve != nil { - return gql_resolve(node.Data[node_info.Fields[field_name]][field_name]) + return gql_resolve(node.Data[node_info.Fields[field_name]][field_name], p) } else { return node.Data[node_info.Fields[field_name]][field_name], nil } @@ -511,6 +555,8 @@ func RegisterObject[T any](ctx *Context) error { return err } + gql_resolve := ctx.GQLResolve(field.Type, node_tag) + gql.AddFieldConfig(gv_tag, &graphql.Field{ Type: gql_type, Resolve: func(p graphql.ResolveParams) (interface{}, error) { @@ -524,7 +570,11 @@ func RegisterObject[T any](ctx *Context) error { return nil, err } - return value.Interface(), nil + if gql_resolve == nil { + return value.Interface(), nil + } else { + return gql_resolve(value.Interface(), p) + } }, }) } @@ -659,7 +709,7 @@ func astInt[T constraints.Integer](value ast.Value) interface{} { } } -func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, from_json func(interface{})interface{}, from_ast func(ast.Value)interface{}) error { +func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, from_json func(interface{})interface{}, from_ast func(ast.Value)interface{}, resolve func(interface{},graphql.ResolveParams)(interface{},error)) error { reflect_type := reflect.TypeFor[S]() serialized_type := SerializedTypeFor[S]() @@ -681,7 +731,7 @@ func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, f Serialized: serialized_type, Reflect: reflect_type, Type: gql, - Resolve: nil, + Resolve: resolve, } ctx.TypeTypes[reflect_type] = ctx.TypeMap[serialized_type] @@ -749,7 +799,7 @@ func (ctx *Context) getNode(id NodeID) (*Node, error) { // Route Messages to dest. Currently only local context routing is supported func (ctx *Context) Send(node *Node, messages []SendMsg) error { for _, msg := range(messages) { - ctx.Log.Logf("signal", "Sending %s -> %+v", msg.Dest, msg) + ctx.Log.Logf("signal", "Sending %s to %s", msg.Signal, msg.Dest) if msg.Dest == ZeroID { panic("Can't send to null ID") } @@ -757,7 +807,7 @@ func (ctx *Context) Send(node *Node, messages []SendMsg) error { if err == nil { select { case target.MsgChan <- RecvMsg{node.ID, msg.Signal}: - ctx.Log.Logf("signal", "Sent %s -> %+v", target.ID, msg) + ctx.Log.Logf("signal", "Sent %s to %s", msg.Signal, msg.Dest) default: buf := make([]byte, 4096) n := runtime.Stack(buf, false) @@ -774,19 +824,38 @@ func (ctx *Context) Send(node *Node, messages []SendMsg) error { return nil } -func (ctx *Context)GQLResolve(t reflect.Type, node_type string) (func(interface{})(interface{}, error), error) { +func (ctx *Context)GQLResolve(t reflect.Type, node_type string) (func(interface{},graphql.ResolveParams)(interface{},error)) { info, mapped := ctx.TypeTypes[t] if mapped { - return info.Resolve, nil + return info.Resolve } else { switch t.Kind() { //case reflect.Array: //case reflect.Slice: case reflect.Pointer: return ctx.GQLResolve(t.Elem(), node_type) + default: + return nil } } - return nil, fmt.Errorf("Cannot get resolver for %s", t) +} + +func RegisterInterface[T any](ctx *Context) error { + serialized_type := SerializeType(reflect.TypeFor[T]()) + reflect_type := reflect.TypeFor[T]() + + _, exists := ctx.Interfaces[serialized_type] + if exists == true { + return fmt.Errorf("Interface %+v already exists in context", reflect_type) + } + + ctx.Interfaces[serialized_type] = &InterfaceInfo{ + Serialized: serialized_type, + Reflect: reflect_type, + } + ctx.InterfaceTypes[reflect_type] = ctx.Interfaces[serialized_type] + + return nil } // Create a new Context with the base library content added @@ -801,6 +870,9 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { Extensions: map[ExtType]*ExtensionInfo{}, ExtensionTypes: map[reflect.Type]*ExtensionInfo{}, + Interfaces: map[SerializedType]*InterfaceInfo{}, + InterfaceTypes: map[reflect.Type]*InterfaceInfo{}, + Nodes: map[NodeType]*NodeInfo{}, NodeTypes: map[string]*NodeInfo{}, @@ -809,12 +881,29 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { var err error - err = RegisterScalar[NodeID](ctx, stringify, unstringify[NodeID], unstringifyAST[NodeID]) + err = RegisterScalar[NodeID](ctx, stringify, unstringify[NodeID], unstringifyAST[NodeID], func(v interface{}, p graphql.ResolveParams)(interface{}, error) { + id, ok := v.(NodeID) + if ok == false { + return nil, fmt.Errorf("%+v is not NodeID", v) + } + + node, err := ResolveNode(id, p) + if err != nil { + return nil, err + } + + return node, nil + }) + if err != nil { + return nil, err + } + + err = RegisterInterface[Extension](ctx) if err != nil { return nil, err } - err = RegisterScalar[NodeType](ctx, identity, coerce[NodeType], astInt[NodeType]) + err = RegisterScalar[NodeType](ctx, identity, coerce[NodeType], astInt[NodeType], nil) if err != nil { return nil, err } @@ -824,52 +913,52 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = RegisterScalar[bool](ctx, identity, coerce[bool], astBool[bool]) + err = RegisterScalar[bool](ctx, identity, coerce[bool], astBool[bool], nil) if err != nil { return nil, err } - err = RegisterScalar[int](ctx, identity, coerce[int], astInt[int]) + err = RegisterScalar[int](ctx, identity, coerce[int], astInt[int], nil) if err != nil { return nil, err } - err = RegisterScalar[uint32](ctx, identity, coerce[uint32], astInt[uint32]) + err = RegisterScalar[uint32](ctx, identity, coerce[uint32], astInt[uint32], nil) if err != nil { return nil, err } - err = RegisterScalar[uint8](ctx, identity, coerce[uint8], astInt[uint8]) + err = RegisterScalar[uint8](ctx, identity, coerce[uint8], astInt[uint8], nil) if err != nil { return nil, err } - err = RegisterScalar[time.Time](ctx, stringify, unstringify[time.Time], unstringifyAST[time.Time]) + err = RegisterScalar[time.Time](ctx, stringify, unstringify[time.Time], unstringifyAST[time.Time], nil) if err != nil { return nil, err } - err = RegisterScalar[string](ctx, identity, coerce[string], astString[string]) + err = RegisterScalar[string](ctx, identity, coerce[string], astString[string], nil) if err != nil { return nil, err } - err = RegisterScalar[EventState](ctx, identity, coerce[EventState], astString[EventState]) + err = RegisterScalar[EventState](ctx, identity, coerce[EventState], astString[EventState], nil) if err != nil { return nil, err } - err = RegisterScalar[ReqState](ctx, identity, coerce[ReqState], astInt[ReqState]) + err = RegisterScalar[ReqState](ctx, identity, coerce[ReqState], astInt[ReqState], nil) if err != nil { return nil, err } - err = RegisterScalar[uuid.UUID](ctx, stringify, unstringify[uuid.UUID], unstringifyAST[uuid.UUID]) + err = RegisterScalar[uuid.UUID](ctx, stringify, unstringify[uuid.UUID], unstringifyAST[uuid.UUID], nil) if err != nil { return nil, err } - err = RegisterScalar[WaitReason](ctx, identity, coerce[WaitReason], astString[WaitReason]) + err = RegisterScalar[WaitReason](ctx, identity, coerce[WaitReason], astString[WaitReason], nil) if err != nil { return nil, err } diff --git a/db.go b/db.go index ae59c21..89ba21e 100644 --- a/db.go +++ b/db.go @@ -1,42 +1,189 @@ package graphvent import ( + "encoding/binary" + "fmt" + badger "github.com/dgraph-io/badger/v3" - "fmt" ) func WriteNodeInit(ctx *Context, node *Node) error { + if node == nil { + return fmt.Errorf("Cannot serialize nil *Node") + } + return ctx.DB.Update(func(tx *badger.Txn) error { - _, err := node.ID.MarshalBinary() + // Get the base key bytes + id_ser, err := node.ID.MarshalBinary() + if err != nil { + return err + } + + // Write Node value + node_val, err := Serialize(ctx, node) + if err != nil { + return err + } + err = tx.Set(id_ser, node_val) + if err != nil { + return err + } + + // Write empty signal queue + sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...) + sigqueue_val, err := Serialize(ctx, node.SignalQueue) + if err != nil { + return err + } + err = tx.Set(sigqueue_id, sigqueue_val) if err != nil { return err } - // Write node private key - // Write node type - // Write Node buffer size // Write node extension list + ext_list := []ExtType{} + for ext_type := range(node.Extensions) { + ext_list = append(ext_list, ext_type) + } + ext_list_val, err := Serialize(ctx, ext_list) + if err != nil { + return err + } + ext_list_id := append(id_ser, []byte(" - EXTLIST")...) + err = tx.Set(ext_list_id, ext_list_val) + if err != nil { + return err + } + // For each extension: + for ext_type, ext := range(node.Extensions) { // Write each extension's current value + ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type)) + ext_val, err := Serialize(ctx, ext) + if err != nil { + return err + } + err = tx.Set(ext_id, ext_val) + } return nil }) } func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error { return ctx.DB.Update(func(tx *badger.Txn) error { + // Get the base key bytes + id_ser, err := node.ID.MarshalBinary() + if err != nil { + return err + } + // Write the signal queue if it needs to be written if node.writeSignalQueue { node.writeSignalQueue = false + + sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...) + sigqueue_val, err := Serialize(ctx, node.SignalQueue) + if err != nil { + return err + } + err = tx.Set(sigqueue_id, sigqueue_val) + if err != nil { + return err + } } // For each ext in changes - // Write each change + for ext_type := range(changes) { + // Write each ext + ext, exists := node.Extensions[ext_type] + if exists == false { + return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID) + } + ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type)) + ext_ser, err := Serialize(ctx, ext) + if err != nil { + return err + } + + err = tx.Set(ext_id, ext_ser) + if err != nil { + return err + } + } return nil }) } func LoadNode(ctx *Context, id NodeID) (*Node, error) { + var node *Node = nil err := ctx.DB.View(func(tx *badger.Txn) error { + // Get the base key bytes + id_ser, err := id.MarshalBinary() + if err != nil { + return err + } + + // Get the node value + node_item, err := tx.Get(id_ser) + if err != nil { + return err + } + + err = node_item.Value(func(val []byte) error { + node, err = Deserialize[*Node](ctx, val) + return err + }) + + if err != nil { + return nil + } + + // Get the signal queue + sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...) + sigqueue_item, err := tx.Get(sigqueue_id) + if err != nil { + return err + } + err = sigqueue_item.Value(func(val []byte) error { + node.SignalQueue, err = Deserialize[[]QueuedSignal](ctx, val) + return err + }) + if err != nil { + return err + } + + // Get the extension list + ext_list_id := append(id_ser, []byte(" - EXTLIST")...) + ext_list_item, err := tx.Get(ext_list_id) + if err != nil { + return err + } + + var ext_list []ExtType + ext_list_item.Value(func(val []byte) error { + ext_list, err = Deserialize[[]ExtType](ctx, val) + return err + }) + + // Get the extensions + for _, ext_type := range(ext_list) { + ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type)) + ext_item, err := tx.Get(ext_id) + if err != nil { + return err + } + + var ext Extension + err = ext_item.Value(func(val []byte) error { + ext, err = Deserialize[Extension](ctx, val) + return err + }) + if err != nil { + return err + } + node.Extensions[ext_type] = ext + } + return nil }) @@ -44,5 +191,5 @@ func LoadNode(ctx *Context, id NodeID) (*Node, error) { return nil, err } - return nil, fmt.Errorf("NOT_IMPLEMENTED") + return node, nil } diff --git a/event_test.go b/event_test.go index a7d5f63..b2008b6 100644 --- a/event_test.go +++ b/event_test.go @@ -11,6 +11,8 @@ func TestEvent(t *testing.T) { ctx := logTestContext(t, []string{"event", "listener", "listener_debug"}) err := RegisterExtension[TestEventExt](ctx, nil) fatalErr(t, err) + err = RegisterObject[TestEventExt](ctx) + fatalErr(t, err) event_public, event_private, err := ed25519.GenerateKey(rand.Reader) diff --git a/gql_node.go b/gql_node.go index 0f4c3de..9ea187e 100644 --- a/gql_node.go +++ b/gql_node.go @@ -89,15 +89,15 @@ func GetResolveFields(id NodeID, ctx *ResolveContext, p graphql.ResolveParams) ( return fields, nil } -func ResolveNode(id NodeID, p graphql.ResolveParams) (interface{}, error) { +func ResolveNode(id NodeID, p graphql.ResolveParams) (NodeResult, error) { ctx, err := PrepResolve(p) if err != nil { - return nil, err + return NodeResult{}, err } fields, err := GetResolveFields(id, ctx, p) if err != nil { - return nil, err + return NodeResult{}, err } ctx.Context.Log.Logf("gql", "Resolving fields %+v on node %s", fields, id) @@ -111,13 +111,13 @@ func ResolveNode(id NodeID, p graphql.ResolveParams) (interface{}, error) { }}) if err != nil { ctx.Ext.FreeResponseChannel(signal.ID()) - return nil, err + return NodeResult{}, err } response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID()) ctx.Ext.FreeResponseChannel(signal.ID()) if err != nil { - return nil, err + return NodeResult{}, err } switch response := response.(type) { @@ -147,6 +147,6 @@ func ResolveNode(id NodeID, p graphql.ResolveParams) (interface{}, error) { ctx.NodeCache[id] = cache return ctx.NodeCache[id], nil default: - return nil, fmt.Errorf("Bad read response: %+v", response) + return NodeResult{}, fmt.Errorf("Bad read response: %+v", response) } } diff --git a/graph_test.go b/graph_test.go index 51d32d2..5794351 100644 --- a/graph_test.go +++ b/graph_test.go @@ -29,7 +29,7 @@ func logTestContext(t * testing.T, components []string) *Context { ctx, err := NewContext(db, NewConsoleLogger(components)) fatalErr(t, err) - err = RegisterNodeType(ctx, "LockableListener", []ExtType{ExtTypeFor[ListenerExt](), ExtTypeFor[LockableExt]()}, map[string]FieldIndex{}) + err = RegisterNodeType(ctx, "LockableListener", []ExtType{ExtTypeFor[ListenerExt](), ExtTypeFor[LockableExt]()}) fatalErr(t, err) return ctx diff --git a/lockable_test.go b/lockable_test.go index dbcab30..986a9f3 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -5,17 +5,8 @@ import ( "time" ) -func lockableTestContext(t *testing.T, logs []string) *Context { - ctx := logTestContext(t, logs) - - err := RegisterNodeType(ctx, "Lockable", []ExtType{ExtTypeFor[LockableExt]()}, map[string]FieldIndex{}) - fatalErr(t, err) - - return ctx -} - func TestLink(t *testing.T) { - ctx := lockableTestContext(t, []string{"lockable", "listener"}) + ctx := logTestContext(t, []string{"lockable", "listener"}) l2_listener := NewListenerExt(10) @@ -52,7 +43,7 @@ func TestLink(t *testing.T) { } func Test1000Lock(t *testing.T) { - ctx := lockableTestContext(t, []string{"test"}) + ctx := logTestContext(t, []string{"test"}) NewLockable := func()(*Node) { l, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt(nil)) @@ -88,7 +79,7 @@ func Test1000Lock(t *testing.T) { } func TestLock(t *testing.T) { - ctx := lockableTestContext(t, []string{"test", "lockable"}) + ctx := logTestContext(t, []string{"test", "lockable"}) NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) { listener := NewListenerExt(1000) diff --git a/serialize.go b/serialize.go index e5a1feb..9276db6 100644 --- a/serialize.go +++ b/serialize.go @@ -80,10 +80,10 @@ func GetFieldTag(tag string) FieldTag { return FieldTag(Hash("GRAPHVENT_FIELD_TAG", tag)) } -func TypeStack(ctx *Context, t reflect.Type) ([]SerializedType, error) { +func TypeStack(ctx *Context, t reflect.Type) ([]byte, error) { info, registered := ctx.TypeTypes[t] if registered { - return []SerializedType{info.Serialized}, nil + return binary.BigEndian.AppendUint64(nil, uint64(info.Serialized)), nil } else { switch t.Kind() { case reflect.Map: @@ -97,45 +97,47 @@ func TypeStack(ctx *Context, t reflect.Type) ([]SerializedType, error) { return nil, err } - return append([]SerializedType{SerializeType(reflect.Map)}, append(key_stack, elem_stack...)...), nil + return append(binary.BigEndian.AppendUint64(nil, uint64(SerializeType(reflect.Map))), append(key_stack, elem_stack...)...), nil case reflect.Pointer: elem_stack, err := TypeStack(ctx, t.Elem()) if err != nil { return nil, err } - return append([]SerializedType{SerializeType(reflect.Pointer)}, elem_stack...), nil + return append(binary.BigEndian.AppendUint64(nil, uint64(SerializeType(reflect.Pointer))), elem_stack...), nil case reflect.Slice: elem_stack, err := TypeStack(ctx, t.Elem()) if err != nil { return nil, err } - return append([]SerializedType{SerializeType(reflect.Slice)}, elem_stack...), nil + return append(binary.BigEndian.AppendUint64(nil, uint64(SerializeType(reflect.Slice))), elem_stack...), nil case reflect.Array: elem_stack, err := TypeStack(ctx, t.Elem()) if err != nil { return nil, err } - return append([]SerializedType{SerializeType(reflect.Array), SerializedType(t.Len())}, elem_stack...), nil + stack := binary.BigEndian.AppendUint64(nil, uint64(SerializeType(reflect.Array))) + stack = binary.BigEndian.AppendUint64(stack, uint64(t.Len())) + return append(stack, elem_stack...), nil default: return nil, fmt.Errorf("Hit %s, which is not a registered type", t.String()) } } } -func UnwrapStack(ctx *Context, stack []SerializedType) (reflect.Type, []SerializedType, error) { - first := stack[0] - stack = stack[1:] +func UnwrapStack(ctx *Context, stack []byte) (reflect.Type, []byte, error) { + first_bytes, left := split(stack, 8) + first := SerializedType(binary.BigEndian.Uint64(first_bytes)) info, registered := ctx.TypeMap[first] if registered { - return info.Reflect, stack, nil + return info.Reflect, left, nil } else { switch first { case SerializeType(reflect.Map): - key_type, after_key, err := UnwrapStack(ctx, stack) + key_type, after_key, err := UnwrapStack(ctx, left) if err != nil { return nil, nil, err } @@ -147,23 +149,22 @@ func UnwrapStack(ctx *Context, stack []SerializedType) (reflect.Type, []Serializ return reflect.MapOf(key_type, elem_type), after_elem, nil case SerializeType(reflect.Pointer): - elem_type, rest, err := UnwrapStack(ctx, stack) + elem_type, rest, err := UnwrapStack(ctx, left) if err != nil { return nil, nil, err } return reflect.PointerTo(elem_type), rest, nil case SerializeType(reflect.Slice): - elem_type, rest, err := UnwrapStack(ctx, stack) + elem_type, rest, err := UnwrapStack(ctx, left) if err != nil { return nil, nil, err } return reflect.SliceOf(elem_type), rest, nil case SerializeType(reflect.Array): - length := int(stack[0]) + length_bytes, left := split(left, 8) + length := int(binary.BigEndian.Uint64(length_bytes)) - stack = stack[1:] - - elem_type, rest, err := UnwrapStack(ctx, stack) + elem_type, rest, err := UnwrapStack(ctx, left) if err != nil { return nil, nil, err } @@ -176,7 +177,7 @@ func UnwrapStack(ctx *Context, stack []SerializedType) (reflect.Type, []Serializ } func Serialize[T any](ctx *Context, value T) ([]byte, error) { - return serializeValue(ctx, reflect.ValueOf(value)) + return serializeValue(ctx, reflect.ValueOf(&value).Elem()) } func Deserialize[T any](ctx *Context, data []byte) (T, error) { @@ -328,6 +329,17 @@ func serializeValue(ctx *Context, value reflect.Value) ([]byte, error) { return data, nil } + case reflect.Interface: + data, err := TypeStack(ctx, value.Elem().Type()) + + val_data, err := serializeValue(ctx, value.Elem()) + if err != nil { + return nil, err + } + + data = append(data, val_data...) + + return data, nil default: return nil, fmt.Errorf("Don't know how to serialize %s", value.Type()) } @@ -526,6 +538,22 @@ func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, return reflect.Value{}, nil, fmt.Errorf("Cannot deserialize unregistered struct %s", t) } + case reflect.Interface: + elem_type, rest, err := UnwrapStack(ctx, data) + if err != nil { + return reflect.Value{}, nil, err + } + + elem_val, left, err := deserializeValue(ctx, rest, elem_type) + if err != nil { + return reflect.Value{}, nil, err + } + + val := reflect.New(t).Elem() + val.Set(elem_val) + + return val, left, nil + default: return reflect.Value{}, nil, fmt.Errorf("Don't know how to deserialize %s", t) } diff --git a/serialize_test.go b/serialize_test.go index bdd26bf..59e2392 100644 --- a/serialize_test.go +++ b/serialize_test.go @@ -111,6 +111,8 @@ func testSerialize[T any](t *testing.T, ctx *Context, value T) { func TestSerializeValues(t *testing.T) { ctx := logTestContext(t, []string{"test"}) + testSerialize(t, ctx, Extension(NewLockableExt(nil))) + testSerializeCompare[int8](t, ctx, -64) testSerializeCompare[int16](t, ctx, -64) testSerializeCompare[int32](t, ctx, -64)