From 7e157068d70c1cf934f8a64557bff35c6cf4fd97 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sat, 30 Mar 2024 22:57:18 -0700 Subject: [PATCH] Removed database update every signal process(need to find a replacement still), updated graphiql cmd, and made lockable more efficient at high numbers of requirements --- cmd/graphiql/main.go | 13 ++- context.go | 92 ++++++++++++----- lockable.go | 53 ++++++---- lockable_test.go | 4 + node.go | 98 ++++++++++++++----- node_test.go | 1 - serialize.go | 228 ++++++++++++++++++++++++++++++++++++------- serialize_test.go | 8 ++ 8 files changed, 385 insertions(+), 112 deletions(-) diff --git a/cmd/graphiql/main.go b/cmd/graphiql/main.go index c604bdb..ce140a4 100644 --- a/cmd/graphiql/main.go +++ b/cmd/graphiql/main.go @@ -17,7 +17,9 @@ func main() { db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true)) check(err) - ctx, err := gv.NewContext(db, gv.NewConsoleLogger([]string{"test", "signal"})) + ctx, err := gv.NewContext(&gv.BadgerDB{ + DB: db, + }, gv.NewConsoleLogger([]string{"test"})) check(err) gql_ext, err := gv.NewGQLExt(ctx, ":8080", nil, nil) @@ -25,13 +27,16 @@ func main() { listener_ext := gv.NewListenerExt(1000) - n1, err := gv.NewNode(ctx, nil, "Lockable", 1000, gv.NewLockableExt(nil)) + n1, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt(nil)) check(err) - n2, err := gv.NewNode(ctx, nil, "Lockable", 1000, gv.NewLockableExt([]gv.NodeID{n1.ID})) + n2, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt([]gv.NodeID{n1.ID})) check(err) - _, err = gv.NewNode(ctx, nil, "Lockable", 1000, gql_ext, listener_ext, gv.NewLockableExt([]gv.NodeID{n2.ID})) + n3, err := gv.NewNode(ctx, nil, "LockableNode", 1000, gv.NewLockableExt(nil)) + check(err) + + _, err = gv.NewNode(ctx, nil, "LockableNode", 1000, gql_ext, listener_ext, gv.NewLockableExt([]gv.NodeID{n2.ID, n3.ID})) check(err) for true { diff --git a/context.go b/context.go index f4b6768..901276c 100644 --- a/context.go +++ b/context.go @@ -10,6 +10,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "golang.org/x/exp/constraints" @@ -25,6 +26,7 @@ var ( ) type SerializeFn func(ctx *Context, value reflect.Value, data []byte) (int, error) +type SerializedSizeFn func(ctx *Context, value reflect.Value) (int, error) type DeserializeFn func(ctx *Context, data []byte) (reflect.Value, []byte, error) type NodeFieldInfo struct { @@ -47,6 +49,7 @@ type TypeInfo struct { PostDeserializeIndex int Serialize SerializeFn + SerializedSize SerializedSizeFn Deserialize DeserializeFn } @@ -76,6 +79,12 @@ type InterfaceInfo struct { Fields map[string]graphql.Type } +type ContextNode struct { + Node *Node + Status chan string + Command chan string +} + // A Context stores all the data to run a graphvent process type Context struct { @@ -99,7 +108,9 @@ type Context struct { // Routing map to all the nodes local to this context nodesLock sync.RWMutex - nodes map[NodeID]*Node + nodes map[NodeID]ContextNode + + running atomic.Bool } func gqltype(ctx *Context, t reflect.Type, node_type string) graphql.Type { @@ -793,7 +804,7 @@ func RegisterEnum[E comparable](ctx *Context, str_map map[E]string) error { return nil } -func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, from_json func(interface{})interface{}, from_ast func(ast.Value)interface{}, serialize SerializeFn, deserialize DeserializeFn) error { +func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, from_json func(interface{})interface{}, from_ast func(ast.Value)interface{}, serialize SerializeFn, sizefn SerializedSizeFn, deserialize DeserializeFn) error { reflect_type := reflect.TypeFor[S]() serialized_type := SerializedTypeFor[S]() @@ -816,6 +827,7 @@ func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, f Type: gql, Serialize: serialize, + SerializedSize: sizefn, Deserialize: deserialize, } ctx.TypesReverse[serialized_type] = ctx.Types[reflect_type] @@ -823,9 +835,13 @@ func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, f return nil } -func (ctx *Context) AddNode(id NodeID, node *Node) { +func (ctx *Context) AddNode(id NodeID, node *Node, status chan string, command chan string) { ctx.nodesLock.Lock() - ctx.nodes[id] = node + ctx.nodes[id] = ContextNode{ + Node: node, + Status: status, + Command: command, + } ctx.nodesLock.Unlock() } @@ -833,7 +849,7 @@ func (ctx *Context) Node(id NodeID) (*Node, bool) { ctx.nodesLock.RLock() node, exists := ctx.nodes[id] ctx.nodesLock.RUnlock() - return node, exists + return node.Node, exists } func (ctx *Context) Delete(id NodeID) error { @@ -841,7 +857,7 @@ func (ctx *Context) Delete(id NodeID) error { if err != nil { return err } - // TODO: also delete any associated data + // TODO: also delete any associated data from DB return nil } @@ -853,7 +869,14 @@ func (ctx *Context) Unload(id NodeID) error { return fmt.Errorf("%s is not a node in ctx", id) } - err := node.Unload(ctx) + node.Command <- "stop" + returned := <- node.Status + + if returned != "stopped" { + return fmt.Errorf(returned) + } + + err := node.Node.Unload(ctx) delete(ctx.nodes, id) return err } @@ -861,8 +884,18 @@ func (ctx *Context) Unload(id NodeID) error { func (ctx *Context) Stop() { ctx.nodesLock.Lock() for id, node := range(ctx.nodes) { - node.Unload(ctx) - delete(ctx.nodes, id) + node.Command <- "stop" + returned := <- node.Status + + if returned != "stopped" { + ctx.Log.Logf("node", "Node returned %s when commanded to stop", returned) + } else { + err := node.Node.Unload(ctx) + if err != nil { + ctx.Log.Logf("node", "Error unloading %s: %s", id, err) + } + delete(ctx.nodes, id) + } } ctx.nodesLock.Unlock() } @@ -873,14 +906,17 @@ func (ctx *Context) Load(id NodeID) (*Node, error) { return nil, err } - ctx.AddNode(id, node) - started := make(chan error, 1) - go runNode(ctx, node, started) - err = <- started - if err != nil { - return nil, err + status := make(chan string, 0) + command := make(chan string, 0) + go runNode(ctx, node, status, command) + + returned := <- status + if returned != "active" { + return nil, fmt.Errorf(returned) } + ctx.AddNode(id, node, status, command) + return node, nil } @@ -927,7 +963,7 @@ func (ctx *Context) Send(node *Node, messages []SendMsg) error { } func resolveNodeID(val interface{}, p graphql.ResolveParams) (interface{}, error) { - id, ok := p.Source.(NodeID) + id, ok := val.(NodeID) if ok == false { return nil, fmt.Errorf("%+v is not NodeID", p.Source) } @@ -984,7 +1020,7 @@ func NewContext(db Database, log Logger) (*Context, error) { Interfaces: map[string]InterfaceInfo{}, NodeTypes: map[NodeType]NodeInfo{}, - nodes: map[NodeID]*Node{}, + nodes: map[NodeID]ContextNode{}, } var err error @@ -993,6 +1029,8 @@ func NewContext(db Database, log Logger) (*Context, error) { func(ctx *Context, value reflect.Value, data []byte) (int, error) { copy(data, value.Bytes()) return 16, nil + }, func(ctx *Context, value reflect.Value) (int, error) { + return 16, nil }, func(ctx *Context, data []byte) (reflect.Value, []byte, error) { if len(data) < 16 { return reflect.Value{}, nil, fmt.Errorf("Not enough bytes to decode NodeID(got %d, want 16)", len(data)) @@ -1014,6 +1052,8 @@ func NewContext(db Database, log Logger) (*Context, error) { func(ctx *Context, value reflect.Value, data []byte) (int, error) { copy(data, value.Bytes()) return 16, nil + }, func(ctx *Context, value reflect.Value) (int, error) { + return 16, nil }, func(ctx *Context, data []byte) (reflect.Value, []byte, error) { if len(data) < 16 { return reflect.Value{}, nil, fmt.Errorf("Not enough bytes to decode uuid.UUID(got %d, want 16)", len(data)) @@ -1031,12 +1071,12 @@ func NewContext(db Database, log Logger) (*Context, error) { return nil, fmt.Errorf("Failed to register uuid.UUID: %w", err) } - err = RegisterScalar[NodeType](ctx, identity, coerce[NodeType], astInt[NodeType], nil, nil) + err = RegisterScalar[NodeType](ctx, identity, coerce[NodeType], astInt[NodeType], nil, nil, nil) if err != nil { return nil, fmt.Errorf("Failed to register NodeType: %w", err) } - err = RegisterScalar[ExtType](ctx, identity, coerce[ExtType], astInt[ExtType], nil, nil) + err = RegisterScalar[ExtType](ctx, identity, coerce[ExtType], astInt[ExtType], nil, nil, nil) if err != nil { return nil, fmt.Errorf("Failed to register ExtType: %w", err) } @@ -1051,32 +1091,32 @@ func NewContext(db Database, log Logger) (*Context, error) { return nil, fmt.Errorf("Failed to register NodeType Node: %w", err) } - err = RegisterScalar[bool](ctx, identity, coerce[bool], astBool[bool], nil, nil) + err = RegisterScalar[bool](ctx, identity, coerce[bool], astBool[bool], nil, nil, nil) if err != nil { return nil, fmt.Errorf("Failed to register bool: %w", err) } - err = RegisterScalar[int](ctx, identity, coerce[int], astInt[int], nil, nil) + err = RegisterScalar[int](ctx, identity, coerce[int], astInt[int], nil, nil, nil) if err != nil { return nil, fmt.Errorf("Failed to register int: %w", err) } - err = RegisterScalar[uint32](ctx, identity, coerce[uint32], astInt[uint32], nil, nil) + err = RegisterScalar[uint32](ctx, identity, coerce[uint32], astInt[uint32], nil, nil, nil) if err != nil { return nil, fmt.Errorf("Failed to register uint32: %w", err) } - err = RegisterScalar[uint8](ctx, identity, coerce[uint8], astInt[uint8], nil, nil) + err = RegisterScalar[uint8](ctx, identity, coerce[uint8], astInt[uint8], nil, nil, nil) if err != nil { return nil, fmt.Errorf("Failed to register uint8: %w", err) } - err = RegisterScalar[time.Time](ctx, stringify, unstringify[time.Time], unstringifyAST[time.Time], nil, nil) + err = RegisterScalar[time.Time](ctx, stringify, unstringify[time.Time], unstringifyAST[time.Time], nil, nil, nil) if err != nil { return nil, fmt.Errorf("Failed to register time.Time: %w", err) } - err = RegisterScalar[string](ctx, identity, coerce[string], astString[string], nil, nil) + err = RegisterScalar[string](ctx, identity, coerce[string], astString[string], nil, nil, nil) if err != nil { return nil, fmt.Errorf("Failed to register string: %w", err) } @@ -1086,7 +1126,7 @@ func NewContext(db Database, log Logger) (*Context, error) { return nil, fmt.Errorf("Failed to register ReqState: %w", err) } - err = RegisterScalar[Tag](ctx, identity, coerce[Tag], astString[Tag], nil, nil) + err = RegisterScalar[Tag](ctx, identity, coerce[Tag], astString[Tag], nil, nil, nil) if err != nil { return nil, fmt.Errorf("Failed to register Tag: %w", err) } diff --git a/lockable.go b/lockable.go index 80c12d3..67836a3 100644 --- a/lockable.go +++ b/lockable.go @@ -30,21 +30,43 @@ func (state ReqState) String() string { } } + type LockableExt struct{ State ReqState `gv:"state"` ReqID *uuid.UUID `gv:"req_id"` Owner *NodeID `gv:"owner"` PendingOwner *NodeID `gv:"pending_owner"` Requirements map[NodeID]ReqState `gv:"requirements" node:"Lockable:"` + + Locked map[NodeID]any + Unlocked map[NodeID]any + Waiting WaitMap `gv:"waiting_locks" node:":Lockable"` } +func (ext *LockableExt) PostDeserialize(ctx *Context) error { + ext.Locked = map[NodeID]any{} + ext.Unlocked = map[NodeID]any{} + + for id, state := range(ext.Requirements) { + if state == Unlocked { + ext.Unlocked[id] = nil + } else if state == Locked { + ext.Locked[id] = nil + } + } + return nil +} + func NewLockableExt(requirements []NodeID) *LockableExt { var reqs map[NodeID]ReqState = nil + var unlocked map[NodeID]any = map[NodeID]any{} + if len(requirements) != 0 { reqs = map[NodeID]ReqState{} for _, req := range(requirements) { reqs[req] = Unlocked + unlocked[req] = nil } } @@ -54,6 +76,9 @@ func NewLockableExt(requirements []NodeID) *LockableExt { PendingOwner: nil, Requirements: reqs, Waiting: WaitMap{}, + + Locked: map[NodeID]any{}, + Unlocked: unlocked, } } @@ -295,16 +320,11 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod switch ext.State { case Locking: ext.Requirements[id] = Locked - locked := 0 - for _, req_state := range(ext.Requirements) { - switch req_state { - case Locked: - locked += 1 - } - } + ext.Locked[id] = nil + delete(ext.Unlocked, id) - if locked == len(ext.Requirements) { - ctx.Log.Logf("lockable", "%s FULL_LOCK: %d", node.ID, locked) + if len(ext.Locked) == len(ext.Requirements) { + ctx.Log.Logf("lockable", "%s FULL_LOCK: %d", node.ID, len(ext.Locked)) changes = append(changes, "state", "owner", "req_id") ext.State = Locked @@ -313,7 +333,7 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod messages = append(messages, SendMsg{*ext.Owner, NewSuccessSignal(*ext.ReqID)}) ext.ReqID = nil } else { - ctx.Log.Logf("lockable", "%s PARTIAL_LOCK: %d/%d", node.ID, locked, len(ext.Requirements)) + ctx.Log.Logf("lockable", "%s PARTIAL_LOCK: %d/%d", node.ID, len(ext.Locked), len(ext.Requirements)) } case AbortingLock: req_state := ext.Requirements[id] @@ -325,6 +345,8 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod messages = append(messages, SendMsg{id, unlock_signal}) case Unlocking: ext.Requirements[id] = Unlocked + ext.Unlocked[id] = nil + delete(ext.Locked, id) unlocked := 0 for _, req_state := range(ext.Requirements) { @@ -347,15 +369,10 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod case Unlocking: ext.Requirements[id] = Unlocked - unlocked := 0 - for _, req_state := range(ext.Requirements) { - switch req_state { - case Unlocked: - unlocked += 1 - } - } + ext.Unlocked[id] = Unlocked + delete(ext.Locked, id) - if unlocked == len(ext.Requirements) { + if len(ext.Unlocked) == len(ext.Requirements) { changes = append(changes, "state", "owner", "req_id") messages = append(messages, SendMsg{*ext.Owner, NewSuccessSignal(*ext.ReqID)}) diff --git a/lockable_test.go b/lockable_test.go index e495246..16ef676 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -46,6 +46,10 @@ func Test10Lock(t *testing.T) { testLockN(t, 10) } +func Test100Lock(t *testing.T) { + testLockN(t, 100) +} + func Test1000Lock(t *testing.T) { testLockN(t, 1000) } diff --git a/node.go b/node.go index 2227ff5..df82df3 100644 --- a/node.go +++ b/node.go @@ -8,6 +8,7 @@ import ( "reflect" "sync/atomic" "time" + "sync" _ "github.com/dgraph-io/badger/v3" "github.com/google/uuid" @@ -72,7 +73,37 @@ func (q QueuedSignal) String() string { type WaitMap map[uuid.UUID]NodeID -// Default message channel size for nodes +type Queue[T any] struct { + out chan T + in chan T + buffer []T + resize sync.Mutex +} + +func NewQueue[T any](initial int) *Queue[T] { + queue := Queue[T]{ + out: make(chan T, 0), + in: make(chan T, 0), + buffer: make([]T, 0, initial), + } + + go func(queue *Queue[T]) { + }(&queue) + + go func(queue *Queue[T]) { + }(&queue) + + return &queue +} + +func (queue *Queue[T]) Put(value T) error { + return nil +} + +func (queue *Queue[T]) Get(value T) error { + return nil +} + // Nodes represent a group of extensions that can be collectively addressed type Node struct { Key ed25519.PrivateKey `gv:"key"` @@ -155,9 +186,9 @@ func SoonestSignal(signals []QueuedSignal) (*QueuedSignal, <-chan time.Time) { } } -func runNode(ctx *Context, node *Node, started chan error) { +func runNode(ctx *Context, node *Node, status chan string, control chan string) { ctx.Log.Logf("node", "RUN_START: %s", node.ID) - err := nodeLoop(ctx, node, started) + err := nodeLoop(ctx, node, status, control) if err != nil { ctx.Log.Logf("node", "%s runNode err %s", node.ID, err) } @@ -195,7 +226,7 @@ func (node *Node) ReadFields(ctx *Context, fields []string)map[string]any { } // Main Loop for nodes -func nodeLoop(ctx *Context, node *Node, started chan error) error { +func nodeLoop(ctx *Context, node *Node, status chan string, control chan string) error { is_started := node.Active.CompareAndSwap(false, true) if is_started == false { return fmt.Errorf("%s is already started, will not start again", node.ID) @@ -219,13 +250,31 @@ func nodeLoop(ctx *Context, node *Node, started chan error) error { ctx.Log.Logf("node_ext", "Loaded extensions for %s", node.ID) - started <- nil + status <- "active" - run := true - for run == true { + running := true + for running { var signal Signal var source NodeID + + select { + case command := <-control: + switch command { + case "stop": + running = false + case "pause": + status <- "paused" + command := <- control + switch command { + case "resume": + status <- "resumed" + case "stop": + running = false + } + default: + ctx.Log.Logf("node", "Unknown control command %s", command) + } case <-node.TimeoutChan: signal = node.NextSignal.Signal source = node.ID @@ -282,18 +331,17 @@ func nodeLoop(ctx *Context, node *Node, started chan error) error { if stopped == false { panic("BAD_STATE: stopping already stopped node") } + + status <- "stopped" + return nil } func (node *Node) Unload(ctx *Context) error { - if node.Active.Load() { - for _, extension := range(node.Extensions) { - extension.Unload(ctx, node) - } - return nil - } else { - return fmt.Errorf("Node not active") + for _, extension := range(node.Extensions) { + extension.Unload(ctx, node) } + return nil } func (node *Node) QueueChanges(ctx *Context, changes map[ExtType]Changes) error { @@ -342,13 +390,6 @@ func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error { } } - if (len(changes) != 0) || node.writeSignalQueue { - write_err := ctx.DB.WriteNodeChanges(ctx, node, changes) - if write_err != nil { - return write_err - } - } - if len(changes) != 0 { status_err := node.QueueChanges(ctx, changes) if status_err != nil { @@ -460,13 +501,16 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size return nil, err } - ctx.AddNode(id, node) - started := make(chan error, 1) - go runNode(ctx, node, started) - err = <- started - if err != nil { - return nil, err + status := make(chan string, 0) + command := make(chan string, 0) + go runNode(ctx, node, status, command) + + returned := <- status + if returned != "active" { + return nil, fmt.Errorf(returned) } + ctx.AddNode(id, node, status, command) + return node, nil } diff --git a/node_test.go b/node_test.go index dd0e75d..cf63b13 100644 --- a/node_test.go +++ b/node_test.go @@ -22,7 +22,6 @@ func TestNodeDB(t *testing.T) { err = ctx.Unload(node.ID) fatalErr(t, err) - ctx.nodes = map[NodeID]*Node{} _, err = ctx.getNode(node.ID) fatalErr(t, err) } diff --git a/serialize.go b/serialize.go index 4973261..0e07f03 100644 --- a/serialize.go +++ b/serialize.go @@ -202,6 +202,152 @@ func Deserialize[T any](ctx *Context, data []byte) (T, error) { return value.Interface().(T), nil } +func SerializedSize(ctx *Context, value reflect.Value) (int, error) { + var sizefn SerializedSizeFn = nil + + info, registered := ctx.Types[value.Type()] + if registered { + sizefn = info.SerializedSize + } + + if sizefn == nil { + switch value.Type().Kind() { + case reflect.Bool: + return 1, nil + + case reflect.Int8: + return 1, nil + case reflect.Int16: + return 2, nil + case reflect.Int32: + return 4, nil + case reflect.Int64: + fallthrough + case reflect.Int: + return 8, nil + + case reflect.Uint8: + return 1, nil + case reflect.Uint16: + return 2, nil + case reflect.Uint32: + return 4, nil + case reflect.Uint64: + fallthrough + case reflect.Uint: + return 8, nil + + case reflect.Float32: + return 4, nil + case reflect.Float64: + return 8, nil + + case reflect.String: + return 8 + value.Len(), nil + + case reflect.Pointer: + if value.IsNil() { + return 1, nil + } else { + elem_len, err := SerializedSize(ctx, value.Elem()) + if err != nil { + return 0, err + } else { + return 1 + elem_len, nil + } + } + + case reflect.Slice: + if value.IsNil() { + return 1, nil + } else { + elem_total := 0 + for i := 0; i < value.Len(); i++ { + elem_len, err := SerializedSize(ctx, value.Index(i)) + if err != nil { + return 0, err + } + elem_total += elem_len + } + return 9 + elem_total, nil + } + + case reflect.Array: + total := 0 + for i := 0; i < value.Len(); i++ { + elem_len, err := SerializedSize(ctx, value.Index(i)) + if err != nil { + return 0, err + } + total += elem_len + } + return total, nil + + case reflect.Map: + if value.IsNil() { + return 1, nil + } else { + key := reflect.New(value.Type().Key()).Elem() + val := reflect.New(value.Type().Elem()).Elem() + iter := value.MapRange() + + total := 0 + for iter.Next() { + key.SetIterKey(iter) + k, err := SerializedSize(ctx, key) + if err != nil { + return 0, err + } + + total += k + + val.SetIterValue(iter) + v, err := SerializedSize(ctx, val) + if err != nil { + return 0, err + } + + total += v + } + + return 9 + total, nil + } + + case reflect.Struct: + if registered == false { + return 0, fmt.Errorf("Can't serialize unregistered struct %s", value.Type()) + } else { + field_total := 0 + for _, field_info := range(info.Fields) { + field_size, err := SerializedSize(ctx, value.FieldByIndex(field_info.Index)) + if err != nil { + return 0, err + } + + field_total += 8 + field_total += field_size + } + + return 8 + field_total, nil + } + + case reflect.Interface: + // TODO get size of TypeStack instead of just using 128 + elem_size, err := SerializedSize(ctx, value.Elem()) + if err != nil { + return 0, err + } + + return 128 + elem_size, nil + + default: + return 0, fmt.Errorf("Don't know how to serialize %s", value.Type()) + } + } else { + return sizefn(ctx, value) + } +} + func SerializeValue(ctx *Context, value reflect.Value, data []byte) (int, error) { var serialize SerializeFn = nil @@ -294,7 +440,6 @@ func SerializeValue(ctx *Context, value reflect.Value, data []byte) (int, error) } case reflect.Array: - data := []byte{} total_written := 0 for i := 0; i < value.Len(); i++ { written, err := SerializeValue(ctx, value.Index(i), data[total_written:]) @@ -306,29 +451,35 @@ func SerializeValue(ctx *Context, value reflect.Value, data []byte) (int, error) return total_written, nil case reflect.Map: - binary.BigEndian.PutUint64(data, uint64(value.Len())) + if value.IsNil() { + data[0] = 0x00 + return 1, nil + } else { + data[0] = 0x01 + binary.BigEndian.PutUint64(data[1:], uint64(value.Len())) - key := reflect.New(value.Type().Key()).Elem() - val := reflect.New(value.Type().Elem()).Elem() - iter := value.MapRange() - total_written := 0 - for iter.Next() { - key.SetIterKey(iter) - val.SetIterValue(iter) + key := reflect.New(value.Type().Key()).Elem() + val := reflect.New(value.Type().Elem()).Elem() + iter := value.MapRange() + total_written := 0 + for iter.Next() { + key.SetIterKey(iter) + val.SetIterValue(iter) - k, err := SerializeValue(ctx, key, data[8+total_written:]) - if err != nil { - return 0, err - } - total_written += k + k, err := SerializeValue(ctx, key, data[9+total_written:]) + if err != nil { + return 0, err + } + total_written += k - v, err := SerializeValue(ctx, val, data[8+total_written:]) - if err != nil { - return 0, err + v, err := SerializeValue(ctx, val, data[9+total_written:]) + if err != nil { + return 0, err + } + total_written += v } - total_written += v + return 9 + total_written, nil } - return 8 + total_written, nil case reflect.Struct: if registered == false { @@ -501,31 +652,36 @@ func DeserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, return value, left, nil case reflect.Map: - len_bytes, left := split(data, 8) - length := int(binary.BigEndian.Uint64(len_bytes)) + flags, after_flags := split(data, 1) + if flags[0] == 0x00 { + return reflect.New(t).Elem(), after_flags, nil + } else { + len_bytes, left := split(after_flags, 8) + length := int(binary.BigEndian.Uint64(len_bytes)) - value := reflect.MakeMapWithSize(t, length) + value := reflect.MakeMapWithSize(t, length) - for i := 0; i < length; i++ { - var key_value reflect.Value - var val_value reflect.Value - var err error + for i := 0; i < length; i++ { + var key_value reflect.Value + var val_value reflect.Value + var err error - key_value, left, err = DeserializeValue(ctx, left, t.Key()) - if err != nil { - return reflect.Value{}, nil, err - } + key_value, left, err = DeserializeValue(ctx, left, t.Key()) + if err != nil { + return reflect.Value{}, nil, err + } - val_value, left, err = DeserializeValue(ctx, left, t.Elem()) - if err != nil { - return reflect.Value{}, nil, err + val_value, left, err = DeserializeValue(ctx, left, t.Elem()) + if err != nil { + return reflect.Value{}, nil, err + } + + value.SetMapIndex(key_value, val_value) } - value.SetMapIndex(key_value, val_value) + return value, left, nil } - return value, left, nil - case reflect.Struct: info, mapped := ctx.Types[t] if mapped { diff --git a/serialize_test.go b/serialize_test.go index 7552f50..f35bda8 100644 --- a/serialize_test.go +++ b/serialize_test.go @@ -160,6 +160,14 @@ func TestSerializeValues(t *testing.T) { testSerializeCompare[*int](t, ctx, nil) testSerializeCompare(t, ctx, "string") + testSerialize(t, ctx, map[string]string{ + "Test": "Test", + "key": "String", + "": "", + }) + + testSerialize[map[string]string](t, ctx, nil) + testSerialize(t, ctx, NewListenerExt(10)) node, err := NewNode(ctx, nil, "Node", 100)