From 6942dc02dbff2bf087a41939174cabe594bdc6ae Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Mon, 4 Mar 2024 17:30:42 -0700 Subject: [PATCH] Major cleanup --- acl.go | 233 ---------- acl_test.go | 141 ------ context.go | 731 ++++++++---------------------- event.go | 70 +-- extension.go | 2 +- gql.go | 548 +---------------------- gql_node.go | 135 ------ gql_signal.go | 145 ------ graph_test.go | 14 + group.go | 296 ------------ group_test.go | 94 ---- listener.go | 8 +- lockable.go | 204 ++------- lockable_test.go | 56 +-- message.go | 113 +---- node.go | 496 ++------------------ node_test.go | 19 +- policy.go | 139 ------ serialize.go | 1093 +-------------------------------------------- serialize_test.go | 247 ---------- signal.go | 120 +---- 21 files changed, 330 insertions(+), 4574 deletions(-) delete mode 100644 acl.go delete mode 100644 acl_test.go delete mode 100644 gql_signal.go delete mode 100644 group.go delete mode 100644 group_test.go delete mode 100644 policy.go delete mode 100644 serialize_test.go diff --git a/acl.go b/acl.go deleted file mode 100644 index f5b81ff..0000000 --- a/acl.go +++ /dev/null @@ -1,233 +0,0 @@ -package graphvent - -import ( - "github.com/google/uuid" - "slices" - "time" -) - -type ACLSignal struct { - SignalHeader - Principal NodeID `gv:"principal"` - Action Tree `gv:"tree"` -} - -func NewACLSignal(principal NodeID, action Tree) *ACLSignal { - return &ACLSignal{ - SignalHeader: NewSignalHeader(Direct), - Principal: principal, - Action: action, - } -} - -var DefaultACLPolicy = NewAllNodesPolicy(Tree{ - SerializedType(SignalTypeFor[ACLSignal]()): nil, -}) - -func (signal ACLSignal) Permission() Tree { - return Tree{ - SerializedType(SignalTypeFor[ACLSignal]()): nil, - } -} - -type ACLExt struct { - Policies []Policy `gv:"policies"` - PendingACLs map[uuid.UUID]PendingACL `gv:"pending_acls"` - Pending map[uuid.UUID]PendingACLSignal `gv:"pending"` -} - -func NewACLExt(policies []Policy) *ACLExt { - return &ACLExt{ - Policies: policies, - PendingACLs: map[uuid.UUID]PendingACL{}, - Pending: map[uuid.UUID]PendingACLSignal{}, - } -} - -func (ext *ACLExt) Load(ctx *Context, node *Node) error { - return nil -} - -func (ext *ACLExt) Unload(ctx *Context, node *Node) { -} - -func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) { - response, is_response := signal.(ResponseSignal) - if is_response == true { - var messages Messages = nil - var changes = Changes{} - info, waiting := ext.Pending[response.ResponseID()] - if waiting == true { - changes.Add("pending") - delete(ext.Pending, response.ResponseID()) - if response.ID() != info.Timeout { - err := node.DequeueSignal(info.Timeout) - if err != nil { - ctx.Log.Logf("acl", "timeout dequeue error: %s", err) - } - } - - acl_info, found := ext.PendingACLs[info.ID] - if found == true { - acl_info.Counter -= 1 - acl_info.Responses = append(acl_info.Responses, response) - - policy_index := slices.IndexFunc(ext.Policies, func(policy Policy) bool { - return policy.ID() == info.Policy - }) - - if policy_index == -1 { - ctx.Log.Logf("acl", "pending signal for nonexistent policy") - delete(ext.PendingACLs, info.ID) - err := node.DequeueSignal(acl_info.TimeoutID) - if err != nil { - ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err) - } - } else { - if ext.Policies[policy_index].ContinueAllows(ctx, acl_info, response) == Allow { - changes.Add("pending_acls") - delete(ext.PendingACLs, info.ID) - ctx.Log.Logf("acl", "Request delayed allow") - messages = messages.Add(ctx, 
acl_info.Source, node, nil, NewSuccessSignal(info.ID)) - err := node.DequeueSignal(acl_info.TimeoutID) - if err != nil { - ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err) - } - } else if acl_info.Counter == 0 { - changes.Add("pending_acls") - delete(ext.PendingACLs, info.ID) - ctx.Log.Logf("acl", "Request delayed deny") - messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(info.ID, "acl_denied")) - err := node.DequeueSignal(acl_info.TimeoutID) - if err != nil { - ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err) - } - } else { - node.PendingACLs[info.ID] = acl_info - changes.Add("pending_acls") - } - } - } - } - return messages, changes - } - - var messages Messages = nil - var changes = Changes{} - - switch sig := signal.(type) { - case *ACLSignal: - var acl_messages map[uuid.UUID]Messages = nil - denied := true - for _, policy := range(ext.Policies) { - policy_messages, result := policy.Allows(ctx, sig.Principal, sig.Action, node) - if result == Allow { - denied = false - break - } else if result == Pending { - if len(policy_messages) == 0 { - ctx.Log.Logf("acl", "Pending result for %s with no messages returned", policy.ID()) - continue - } else if acl_messages == nil { - acl_messages = map[uuid.UUID]Messages{} - denied = false - } - - acl_messages[policy.ID()] = policy_messages - ctx.Log.Logf("acl", "Pending result for %s:%s - %+v", node.ID, policy.ID(), acl_messages) - } - } - - if denied == true { - ctx.Log.Logf("acl", "Request denied") - messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "acl_denied")) - } else if acl_messages != nil { - ctx.Log.Logf("acl", "Request pending") - changes.Add("pending") - total_messages := 0 - // TODO: reasonable timeout/configurable - timeout_time := time.Now().Add(time.Second) - for policy_id, policy_messages := range(acl_messages) { - total_messages += len(policy_messages) - for _, message := range(policy_messages) { - timeout_signal := NewTimeoutSignal(message.Signal.ID()) - ext.Pending[message.Signal.ID()] = PendingACLSignal{ - Policy: policy_id, - Timeout: timeout_signal.Id, - ID: sig.Id, - } - node.QueueSignal(timeout_time, timeout_signal) - messages = append(messages, message) - } - } - - acl_timeout := NewACLTimeoutSignal(sig.Id) - node.QueueSignal(timeout_time, acl_timeout) - ext.PendingACLs[sig.Id] = PendingACL{ - Counter: total_messages, - Responses: []ResponseSignal{}, - TimeoutID: acl_timeout.Id, - Action: sig.Action, - Principal: sig.Principal, - - Source: source, - Signal: signal, - } - } else { - ctx.Log.Logf("acl", "Request allowed") - messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) - } - // Test an action against the policy list, sending any intermediate signals necessary and seeting Pending and PendingACLs accordingly. 
Add a TimeoutSignal for every message awaiting a response, and an ACLTimeoutSignal for the overall request - case *ACLTimeoutSignal: - acl_info, exists := ext.PendingACLs[sig.ReqID] - if exists == true { - delete(ext.PendingACLs, sig.ReqID) - changes.Add("pending_acls") - ctx.Log.Logf("acl", "Request timeout deny") - messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(sig.ReqID, "acl_timeout")) - err := node.DequeueSignal(acl_info.TimeoutID) - if err != nil { - ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err) - } - } else { - ctx.Log.Logf("acl", "ACL_TIMEOUT_SIGNAL for passed acl") - } - // Delete from PendingACLs - } - - return messages, changes -} - -type ACLProxyPolicy struct { - PolicyHeader - Proxies []NodeID `gv:"proxies"` -} - -func NewACLProxyPolicy(proxies []NodeID) ACLProxyPolicy { - return ACLProxyPolicy{ - NewPolicyHeader(), - proxies, - } -} - -func (policy ACLProxyPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) { - if len(policy.Proxies) == 0 { - return nil, Deny - } - - messages := Messages{} - for _, proxy := range(policy.Proxies) { - messages = messages.Add(ctx, proxy, node, nil, NewACLSignal(principal_id, action)) - } - - return messages, Pending -} - -func (policy ACLProxyPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { - _, is_success := signal.(*SuccessSignal) - if is_success == true { - return Allow - } - return Deny -} - diff --git a/acl_test.go b/acl_test.go deleted file mode 100644 index 2050260..0000000 --- a/acl_test.go +++ /dev/null @@ -1,141 +0,0 @@ -package graphvent - -import ( - "testing" - "time" - "reflect" - "runtime/debug" -) - -func checkSignal[S Signal](t *testing.T, signal Signal, check func(S)){ - response_casted, cast_ok := signal.(S) - if cast_ok == false { - error_signal, is_error := signal.(*ErrorSignal) - if is_error { - t.Log(string(debug.Stack())) - t.Fatal(error_signal.Error) - } - t.Fatalf("Response of wrong type %s", reflect.TypeOf(signal)) - } - - check(response_casted) -} - -func testSendACL[S Signal](t *testing.T, ctx *Context, listener *Node, action Tree, policies []Policy, check func(S)){ - acl_node, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt(policies)) - fatalErr(t, err) - - acl_signal := NewACLSignal(listener.ID, action) - response, _ := testSend(t, ctx, acl_signal, listener, acl_node) - - checkSignal(t, response, check) -} - -func testErrorSignal(t *testing.T, error_string string) func(*ErrorSignal){ - return func(response *ErrorSignal) { - if response.Error != error_string { - t.Fatalf("Wrong error: %s", response.Error) - } - } -} - -func testSuccess(*SuccessSignal){} - -func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *Node) (ResponseSignal, []Signal) { - source_listener, err := GetExt[ListenerExt](source) - fatalErr(t, err) - - messages := Messages{} - messages = messages.Add(ctx, destination.ID, source, nil, signal) - fatalErr(t, ctx.Send(messages)) - - response, signals, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID()) - fatalErr(t, err) - - return response, signals -} - -func TestACLBasic(t *testing.T) { - ctx := logTestContext(t, []string{"test", "acl", "group", "read_field"}) - - listener, err := NewNode(ctx, nil, "Base", 100, nil, NewListenerExt(100)) - fatalErr(t, err) - - ctx.Log.Logf("test", "testing fail") - testSendACL(t, ctx, listener, nil, nil, testErrorSignal(t, "acl_denied")) - - ctx.Log.Logf("test", 
"testing allow all") - testSendACL(t, ctx, listener, nil, []Policy{NewAllNodesPolicy(nil)}, testSuccess) - - group, err := NewNode(ctx, nil, "Base", 100, []Policy{ - DefaultGroupPolicy, - NewPerNodePolicy(map[NodeID]Tree{ - listener.ID: { - SerializedType(SignalTypeFor[AddSubGroupSignal]()): nil, - SerializedType(SignalTypeFor[AddMemberSignal]()): nil, - }, - }), - }, NewGroupExt(nil)) - fatalErr(t, err) - - ctx.Log.Logf("test", "testing empty groups") - testSendACL(t, ctx, listener, nil, []Policy{ - NewMemberOfPolicy(map[NodeID]map[string]Tree{ - group.ID: { - "test_group": nil, - }, - }), - }, testErrorSignal(t, "acl_denied")) - - ctx.Log.Logf("test", "testing adding group") - add_subgroup_signal := NewAddSubGroupSignal("test_group") - add_subgroup_response, _ := testSend(t, ctx, add_subgroup_signal, listener, group) - checkSignal(t, add_subgroup_response, testSuccess) - - ctx.Log.Logf("test", "testing adding member") - add_member_signal := NewAddMemberSignal("test_group", listener.ID) - add_member_response, _ := testSend(t, ctx, add_member_signal, listener, group) - checkSignal(t, add_member_response, testSuccess) - - ctx.Log.Logf("test", "testing group membership") - testSendACL(t, ctx, listener, nil, []Policy{ - NewMemberOfPolicy(map[NodeID]map[string]Tree{ - group.ID: { - "test_group": nil, - }, - }), - }, testSuccess) - - testSendACL(t, ctx, listener, nil, []Policy{ - NewACLProxyPolicy(nil), - }, testErrorSignal(t, "acl_denied")) - - acl_proxy_1, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt(nil)) - fatalErr(t, err) - - testSendACL(t, ctx, listener, nil, []Policy{ - NewACLProxyPolicy([]NodeID{acl_proxy_1.ID}), - }, testErrorSignal(t, "acl_denied")) - - acl_proxy_2, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt([]Policy{NewAllNodesPolicy(nil)})) - fatalErr(t, err) - - testSendACL(t, ctx, listener, nil, []Policy{ - NewACLProxyPolicy([]NodeID{acl_proxy_2.ID}), - }, testSuccess) - - acl_proxy_3, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, - NewACLExt([]Policy{ - NewMemberOfPolicy(map[NodeID]map[string]Tree{ - group.ID: { - "test_group": nil, - }, - }), - }), - ) - fatalErr(t, err) - - testSendACL(t, ctx, listener, nil, []Policy{ - NewACLProxyPolicy([]NodeID{acl_proxy_3.ID}), - }, testSuccess) -} diff --git a/context.go b/context.go index 3ce389a..f747369 100644 --- a/context.go +++ b/context.go @@ -1,17 +1,17 @@ package graphvent import ( - "crypto/ecdh" - "errors" - "fmt" - "reflect" - "runtime" - "sync" - "time" - "github.com/google/uuid" - "github.com/graphql-go/graphql" - - badger "github.com/dgraph-io/badger/v3" + "crypto/ecdh" + "errors" + "fmt" + "reflect" + "runtime" + "sync" + + "github.com/graphql-go/graphql" + "github.com/graphql-go/graphql/language/ast" + + badger "github.com/dgraph-io/badger/v3" ) var ( @@ -19,161 +19,81 @@ var ( ECDH = ecdh.X25519() ) -type ExtensionInfo struct { - Reflect reflect.Type - Interface graphql.Interface +type TypeInfo struct { + Type graphql.Type +} +type ExtensionInfo struct { + Interface *graphql.Interface + Fields map[string][]int Data interface{} } +type SignalInfo struct { + Type graphql.Type +} + type FieldIndex struct { Extension ExtType Field string } type NodeInfo struct { + GQL *graphql.Object Extensions []ExtType - Policies []Policy Fields map[string]FieldIndex } -type GQLValueConverter func(*Context, interface{})(reflect.Value, error) - -type TypeInfo struct { - Reflect reflect.Type - GQL graphql.Type - - Type SerializedType - TypeSerialize 
TypeSerializeFn - Serialize SerializeFn - TypeDeserialize TypeDeserializeFn - Deserialize DeserializeFn - - GQLValue GQLValueConverter -} - -type KindInfo struct { - Reflect reflect.Kind - Base reflect.Type - Type SerializedType - TypeSerialize TypeSerializeFn - Serialize SerializeFn - TypeDeserialize TypeDeserializeFn - Deserialize DeserializeFn -} - // A Context stores all the data to run a graphvent process type Context struct { + // DB is the database connection used to load and write nodes DB * badger.DB // Logging interface Log Logger + + // Mapped types + TypeMap map[SerializedType]TypeInfo + TypeTypes map[reflect.Type]SerializedType + // Map between database extension hashes and the registered info Extensions map[ExtType]ExtensionInfo ExtensionTypes map[reflect.Type]ExtType - // Map between databse policy hashes and the registered info - Policies map[PolicyType]reflect.Type - PolicyTypes map[reflect.Type]PolicyType - // Map between serialized signal hashes and the registered info - Signals map[SignalType]reflect.Type - SignalTypes map[reflect.Type]SignalType + // Map between database type hashes and the registered info Nodes map[NodeType]NodeInfo NodeTypes map[string]NodeType - // Map between go types and registered info - Types map[SerializedType]*TypeInfo - TypeReflects map[reflect.Type]*TypeInfo - - Kinds map[reflect.Kind]*KindInfo - KindTypes map[SerializedType]*KindInfo - // Routing map to all the nodes local to this context nodeMapLock sync.RWMutex nodeMap map[NodeID]*Node } -// Register a NodeType to the context, with the list of extensions it requires -func RegisterNodeType(ctx *Context, name string, extensions []ExtType, mappings map[string]FieldIndex) error { - node_type := NodeTypeFor(name, extensions, mappings) - _, exists := ctx.Nodes[node_type] - if exists == true { - return fmt.Errorf("Cannot register node type %+v, type already exists in context", node_type) - } - - ext_found := map[ExtType]bool{} - for _, extension := range(extensions) { - _, in_ctx := ctx.Extensions[extension] - if in_ctx == false { - return fmt.Errorf("Cannot register node type %+v, required extension %+v not in context", node_type, extension) - } - - _, duplicate := ext_found[extension] - if duplicate == true { - return fmt.Errorf("Duplicate extension %+v found in extension list", extension) - } - - ext_found[extension] = true - } - - ctx.Nodes[node_type] = NodeInfo{ - Extensions: extensions, - Fields: mappings, - } - ctx.NodeTypes[name] = node_type - - return nil -} - -func RegisterPolicy[P Policy](ctx *Context) error { - reflect_type := reflect.TypeFor[P]() - policy_type := PolicyTypeFor[P]() +func BuildSchema(ctx *Context, query, mutation *graphql.Object) (graphql.Schema, error) { + types := []graphql.Type{} - _, exists := ctx.Policies[policy_type] - if exists == true { - return fmt.Errorf("Cannot register policy of type %+v, type already exists in context", policy_type) - } + subscription := graphql.NewObject(graphql.ObjectConfig{ - policy_info, err := GetStructInfo(ctx, reflect_type) - if err != nil { - return err - } + }) - err = RegisterType[P](ctx, nil, SerializeStruct(policy_info), nil, DeserializeStruct(policy_info)) - if err != nil { - return err - } - - ctx.Log.Logf("serialize_types", "Registered PolicyType: %+v - %+v", reflect_type, policy_type) - - ctx.Policies[policy_type] = reflect_type - ctx.PolicyTypes[reflect_type] = policy_type - return nil + return graphql.NewSchema(graphql.SchemaConfig{ + Types: types, + Query: query, + Subscription: subscription, + Mutation: mutation, + }) 
} func RegisterSignal[S Signal](ctx *Context) error { reflect_type := reflect.TypeFor[S]() signal_type := SignalTypeFor[S]() - _, exists := ctx.Signals[signal_type] - if exists == true { - return fmt.Errorf("Cannot register signal of type %+v, type already exists in context", signal_type) - } - - signal_info, err := GetStructInfo(ctx, reflect_type) - if err != nil { - return err - } - - err = RegisterType[S](ctx, nil, SerializeStruct(signal_info), nil, DeserializeStruct(signal_info)) + err := RegisterObject[S](ctx) if err != nil { return err } ctx.Log.Logf("serialize_types", "Registered SignalType: %+v - %+v", reflect_type, signal_type) - - ctx.Signals[signal_type] = reflect_type - ctx.SignalTypes[reflect_type] = signal_type return nil } @@ -185,13 +105,51 @@ func RegisterExtension[E any, T interface { *E; Extension}](ctx *Context, data i return fmt.Errorf("Cannot register extension %+v of type %+v, type already exists in context", reflect_type, ext_type) } - elem_type := reflect_type.Elem() - elem_info, err := GetStructInfo(ctx, elem_type) - if err != nil { - return err + gql_interface := graphql.NewInterface(graphql.InterfaceConfig{ + Name: reflect_type.String(), + ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { + ctx, ok := p.Context.Value("resolve").(*ResolveContext) + if ok == false { + return nil + } + + node, ok := p.Value.(NodeResult) + if ok == false { + return nil + } + + type_info, type_exists := ctx.Context.Nodes[node.NodeType] + if type_exists == false { + return ctx.Context.Nodes[ctx.Context.NodeTypes["Base"]].GQL + } + + return type_info.GQL + }, + Fields: graphql.Fields{ + "ID": &graphql.Field{ + Type: graphql.String, + }, + }, + }) + + fields := map[string][]int{} + for _, field := range reflect.VisibleFields(reflect.TypeFor[E]()) { + gv_tag, tagged_gv := field.Tag.Lookup("gv") + if tagged_gv { + fields[gv_tag] = field.Index + + type_ser, type_mapped := ctx.TypeTypes[field.Type] + if type_mapped == false { + return fmt.Errorf("Extension %s has field %s of unregistered type %s", reflect_type, gv_tag, field.Type) + } + + gql_interface.AddFieldConfig(gv_tag, &graphql.Field{ + Type: ctx.TypeMap[type_ser].Type, + }) + } } - err = RegisterType[E](ctx, nil, SerializeStruct(elem_info), nil, DeserializeStruct(elem_info)) + err := RegisterObject[E](ctx) if err != nil { return err } @@ -199,111 +157,121 @@ func RegisterExtension[E any, T interface { *E; Extension}](ctx *Context, data i ctx.Log.Logf("serialize_types", "Registered ExtType: %+v - %+v", reflect_type, ext_type) ctx.Extensions[ext_type] = ExtensionInfo{ - Reflect: reflect_type, + Interface: gql_interface, Data: data, + Fields: fields, } ctx.ExtensionTypes[reflect_type] = ext_type return nil } -func RegisterKind(ctx *Context, kind reflect.Kind, base reflect.Type, type_serialize TypeSerializeFn, serialize SerializeFn, type_deserialize TypeDeserializeFn, deserialize DeserializeFn) error { - ctx_type := SerializedKindFor(kind) - _, exists := ctx.Kinds[kind] - if exists == true { - return fmt.Errorf("Cannot register kind %+v, kind already exists in context", kind) - } - _, exists = ctx.KindTypes[ctx_type] +func RegisterNodeType(ctx *Context, name string, extensions []ExtType, mappings map[string]FieldIndex) error { + node_type := NodeTypeFor(name, extensions, mappings) + _, exists := ctx.Nodes[node_type] if exists == true { - return fmt.Errorf("0x%x is already registered, cannot use for %+v", ctx_type, kind) - } - if deserialize == nil { - return fmt.Errorf("Cannot register field without deserialize function") 
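// A minimal sketch, not part of the patch, of an extension as the reworked
// RegisterExtension above expects it: `gv`-tagged fields whose types are
// already registered become fields on the generated GraphQL interface.
// ExampleExt and its fields are hypothetical, this assumes string/int scalars
// were registered through RegisterScalar beforehand, and any Extension
// methods beyond Load/Unload/Process are elided.
type ExampleExt struct {
  Name  string `gv:"name"`  // surfaced as interface field "name"
  Count int    `gv:"count"` // surfaced as interface field "count"
}

func (ext *ExampleExt) Load(ctx *Context, node *Node) error { return nil }
func (ext *ExampleExt) Unload(ctx *Context, node *Node) {}
func (ext *ExampleExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) {
  return nil, nil
}

// Registered the same way as the built-in extensions:
//   err := RegisterExtension[ExampleExt](ctx, nil)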
- } - if serialize == nil { - return fmt.Errorf("Cannot register field without serialize function") + return fmt.Errorf("Cannot register node type %+v, type already exists in context", node_type) } - info := KindInfo{ - Reflect: kind, - Type: ctx_type, - Base: base, - TypeSerialize: type_serialize, - Serialize: serialize, - TypeDeserialize: type_deserialize, - Deserialize: deserialize, + ext_found := map[ExtType]bool{} + for _, extension := range(extensions) { + _, in_ctx := ctx.Extensions[extension] + if in_ctx == false { + return fmt.Errorf("Cannot register node type %+v, required extension %+v not in context", node_type, extension) + } + + _, duplicate := ext_found[extension] + if duplicate == true { + return fmt.Errorf("Duplicate extension %+v found in extension list", extension) + } + + ext_found[extension] = true } - ctx.KindTypes[ctx_type] = &info - ctx.Kinds[kind] = &info - ctx.Log.Logf("serialize_types", "Registered kind %+v, %+v", kind, ctx_type) + ctx.Nodes[node_type] = NodeInfo{ + Extensions: extensions, + Fields: mappings, + } + ctx.NodeTypes[name] = node_type return nil } -func RegisterType[T any](ctx *Context, type_serialize TypeSerializeFn, serialize SerializeFn, type_deserialize TypeDeserializeFn, deserialize DeserializeFn) error { +func RegisterObject[T any](ctx *Context) error { reflect_type := reflect.TypeFor[T]() - ctx_type := SerializedTypeFor[T]() - - _, exists := ctx.Types[ctx_type] - if exists == true { - return fmt.Errorf("Cannot register field of type %+v, type already exists in context", ctx_type) - } - _, exists = ctx.TypeReflects[reflect_type] - if exists == true { - return fmt.Errorf("Cannot register field with type %+v, type already registered in context", reflect_type) - } - - if type_serialize == nil || type_deserialize == nil { - kind_info, kind_registered := ctx.Kinds[reflect_type.Kind()] - if kind_registered == true { - if type_serialize == nil { - type_serialize = kind_info.TypeSerialize - } - if type_deserialize == nil { - type_deserialize = kind_info.TypeDeserialize + serialized_type := SerializedTypeFor[T]() + + _, exists := ctx.TypeTypes[reflect_type] + if exists { + return fmt.Errorf("%+v already registered in TypeMap", reflect_type) + } + + gql := graphql.NewObject(graphql.ObjectConfig{ + Name: reflect_type.String(), + IsTypeOf: func(p graphql.IsTypeOfParams) bool { + return reflect_type == reflect.TypeOf(p.Value) + }, + Fields: graphql.Fields{}, + }) + + for _, field := range(reflect.VisibleFields(reflect_type)) { + gv_tag, tagged_gv := field.Tag.Lookup("gv") + if tagged_gv { + field_type, mapped := ctx.TypeTypes[field.Type] + if mapped == false { + return fmt.Errorf("Object %+v has field %s of unknown type %+v", reflect_type, gv_tag, field_type) } + gql.AddFieldConfig(gv_tag, &graphql.Field{ + Type: ctx.TypeMap[field_type].Type, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + val, ok := p.Source.(T) + if ok == false { + return nil, fmt.Errorf("%s is not %s", reflect.TypeOf(p.Source), reflect_type) + } + + value, err := reflect.ValueOf(val).FieldByIndexErr(field.Index) + if err != nil { + return nil, err + } + + return value.Interface(), nil + }, + }) } } - if serialize == nil || deserialize == nil { - kind_info, kind_registered := ctx.Kinds[reflect_type.Kind()] - if kind_registered == false { - return fmt.Errorf("No serialize/deserialize passed and none registered for kind %+v", reflect_type.Kind()) - } else { - if serialize == nil { - serialize = kind_info.Serialize - } - if deserialize == nil { - deserialize = 
kind_info.Deserialize - } - } + ctx.TypeTypes[reflect_type] = serialized_type + ctx.TypeMap[serialized_type] = TypeInfo{ + Type: gql, } - type_info := TypeInfo{ - Reflect: reflect_type, - Type: ctx_type, - TypeSerialize: type_serialize, - Serialize: serialize, - TypeDeserialize: type_deserialize, - Deserialize: deserialize, + return nil +} + +func RegisterScalar[T any](ctx *Context, to_json func(interface{})interface{}, from_json func(interface{})interface{}, from_ast func(ast.Value)interface{}) error { + reflect_type := reflect.TypeFor[T]() + serialized_type := SerializedTypeFor[T]() + + _, exists := ctx.TypeTypes[reflect_type] + if exists { + return fmt.Errorf("%+v already registered in TypeMap", reflect_type) } - ctx.Types[ctx_type] = &type_info - ctx.TypeReflects[reflect_type] = &type_info + gql := graphql.NewScalar(graphql.ScalarConfig{ + Name: reflect_type.String(), + Serialize: to_json, + ParseValue: from_json, + ParseLiteral: from_ast, + }) - ctx.Log.Logf("serialize_types", "Registered Type: %+v - %+v", reflect_type, ctx_type) + ctx.TypeTypes[reflect_type] = serialized_type + ctx.TypeMap[serialized_type] = TypeInfo{ + Type: gql, + } return nil } -func RegisterStruct[T any](ctx *Context) error { - struct_info, err := GetStructInfo(ctx, reflect.TypeFor[T]()) - if err != nil { - return err - } - return RegisterType[T](ctx, nil, SerializeStruct(struct_info), nil, DeserializeStruct(struct_info)) -} func (ctx *Context) AddNode(id NodeID, node *Node) { ctx.nodeMapLock.Lock() @@ -364,7 +332,7 @@ func (ctx *Context) getNode(id NodeID) (*Node, error) { } // Route Messages to dest. Currently only local context routing is supported -func (ctx *Context) Send(messages Messages) error { +func (ctx *Context) Send(node *Node, messages []SendMsg) error { for _, msg := range(messages) { ctx.Log.Logf("signal", "Sending %s -> %+v", msg.Dest, msg) if msg.Dest == ZeroID { @@ -373,7 +341,7 @@ func (ctx *Context) Send(messages Messages) error { target, err := ctx.getNode(msg.Dest) if err == nil { select { - case target.MsgChan <- msg: + case target.MsgChan <- RecvMsg{node.ID, msg.Signal}: ctx.Log.Logf("signal", "Sent %s -> %+v", target.ID, msg) default: buf := make([]byte, 4096) @@ -396,255 +364,27 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { ctx := &Context{ DB: db, Log: log, - Policies: map[PolicyType]reflect.Type{}, - PolicyTypes: map[reflect.Type]PolicyType{}, + + TypeMap: map[SerializedType]TypeInfo{}, + TypeTypes: map[reflect.Type]SerializedType{}, + Extensions: map[ExtType]ExtensionInfo{}, ExtensionTypes: map[reflect.Type]ExtType{}, - Signals: map[SignalType]reflect.Type{}, - SignalTypes: map[reflect.Type]SignalType{}, + Nodes: map[NodeType]NodeInfo{}, NodeTypes: map[string]NodeType{}, - Types: map[SerializedType]*TypeInfo{}, - TypeReflects: map[reflect.Type]*TypeInfo{}, - Kinds: map[reflect.Kind]*KindInfo{}, - KindTypes: map[SerializedType]*KindInfo{}, nodeMap: map[NodeID]*Node{}, } var err error - err = RegisterKind(ctx, reflect.Pointer, nil, SerializeTypeElem, SerializePointer, DeserializeTypePointer, DeserializePointer) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Bool, reflect.TypeFor[bool](), nil, SerializeBool, nil, DeserializeBool[bool]) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.String, reflect.TypeFor[string](), nil, SerializeString, nil, DeserializeString[string]) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Float32, reflect.TypeFor[float32](), nil, SerializeFloat32, nil, 
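// A minimal usage sketch for the RegisterScalar helper added above; the Tag
// type and its conversions are hypothetical and only illustrate the three
// callbacks (serialize to JSON, parse from JSON, parse from a query literal).
type Tag string

func RegisterTagScalar(ctx *Context) error {
  return RegisterScalar[Tag](ctx,
    func(value interface{}) interface{} { // to_json
      tag, ok := value.(Tag)
      if ok == false {
        return nil
      }
      return string(tag)
    },
    func(value interface{}) interface{} { // from_json
      str, ok := value.(string)
      if ok == false {
        return nil
      }
      return Tag(str)
    },
    func(value ast.Value) interface{} { // from_ast
      str, ok := value.(*ast.StringValue)
      if ok == false {
        return nil
      }
      return Tag(str.Value)
    })
}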
DeserializeFloat32[float32]) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Float64, reflect.TypeFor[float64](), nil, SerializeFloat64, nil, DeserializeFloat64[float64]) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Uint, reflect.TypeFor[uint](), nil, SerializeUint32, nil, DeserializeUint32[uint]) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Uint8, reflect.TypeFor[uint8](), nil, SerializeUint8, nil, DeserializeUint8[uint8]) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Uint16, reflect.TypeFor[uint16](), nil, SerializeUint16, nil, DeserializeUint16[uint16]) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Uint32, reflect.TypeFor[uint32](), nil, SerializeUint32, nil, DeserializeUint32[uint32]) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Uint64, reflect.TypeFor[uint64](), nil, SerializeUint64, nil, DeserializeUint64[uint64]) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Int, reflect.TypeFor[int](), nil, SerializeInt32, nil, DeserializeUint32[int]) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Int8, reflect.TypeFor[int8](), nil, SerializeInt8, nil, DeserializeUint8[int8]) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Int16, reflect.TypeFor[int16](), nil, SerializeInt16, nil, DeserializeUint16[int16]) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Int32, reflect.TypeFor[int32](), nil, SerializeInt32, nil, DeserializeUint32[int32]) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Int64, reflect.TypeFor[int64](), nil, SerializeInt64, nil, DeserializeUint64[int64]) - if err != nil { - return nil, err - } - - err = RegisterType[WaitReason](ctx, nil, nil, nil, DeserializeString[WaitReason]) - if err != nil { - return nil, err - } - - err = RegisterType[EventCommand](ctx, nil, nil, nil, DeserializeString[EventCommand]) - if err != nil { - return nil, err - } - - err = RegisterType[EventState](ctx, nil, nil, nil, DeserializeString[EventState]) - if err != nil { - return nil, err - } - - err = RegisterStruct[WaitInfo](ctx) - if err != nil { - return nil, err - } - - err = RegisterType[time.Duration](ctx, nil, nil, nil, DeserializeUint64[time.Duration]) - if err != nil { - return nil, err - } - - err = RegisterType[time.Time](ctx, nil, SerializeGob, nil, DeserializeGob[time.Time]) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Map, nil, SerializeTypeMap, SerializeMap, DeserializeTypeMap, DeserializeMap) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Array, nil, SerializeTypeArray, SerializeArray, DeserializeTypeArray, DeserializeArray) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Slice, nil, SerializeTypeElem, SerializeSlice, DeserializeTypeSlice, DeserializeSlice) - if err != nil { - return nil, err - } - - err = RegisterKind(ctx, reflect.Interface, reflect.TypeFor[interface{}](), nil, SerializeInterface, nil, DeserializeInterface) - if err != nil { - return nil, err - } - - err = RegisterType[SerializedType](ctx, nil, SerializeUint64, nil, DeserializeUint64[SerializedType]) - if err != nil { - return nil, err - } - - err = RegisterType[Changes](ctx, SerializeTypeStub, SerializeMap, DeserializeTypeStub[Changes], DeserializeMap) - if err != nil { - return nil, err - } - - err = 
RegisterType[ExtType](ctx, nil, SerializeUint64, nil, DeserializeUint64[ExtType]) - if err != nil { - return nil, err - } - - err = RegisterType[NodeType](ctx, nil, SerializeUint64, nil, DeserializeUint64[NodeType]) - if err != nil { - return nil, err - } - - err = RegisterType[PolicyType](ctx, nil, SerializeUint64, nil, DeserializeUint64[PolicyType]) - if err != nil { - return nil, err - } - - err = RegisterType[NodeID](ctx, SerializeTypeStub, SerializeUUID, DeserializeTypeStub[NodeID], DeserializeUUID[NodeID]) - if err != nil { - return nil, err - } - - err = RegisterType[uuid.UUID](ctx, SerializeTypeStub, SerializeUUID, DeserializeTypeStub[uuid.UUID], DeserializeUUID[uuid.UUID]) - if err != nil { - return nil, err - } - - err = RegisterType[SignalDirection](ctx, nil, SerializeUint8, nil, DeserializeUint8[SignalDirection]) - if err != nil { - return nil, err - } - - err = RegisterType[ReqState](ctx, nil, SerializeUint8, nil, DeserializeUint8[ReqState]) - if err != nil { - return nil, err - } - - err = RegisterType[Tree](ctx, SerializeTypeStub, nil, DeserializeTypeStub[Tree], nil) - if err != nil { - return nil, err - } - - err = RegisterType[Extension](ctx, nil, SerializeInterface, nil, DeserializeInterface) - if err != nil { - return nil, err - } - - err = RegisterType[Policy](ctx, nil, SerializeInterface, nil, DeserializeInterface) - if err != nil { - return nil, err - } - - err = RegisterType[Signal](ctx, nil, SerializeInterface, nil, DeserializeInterface) - if err != nil { - return nil, err - } - - err = RegisterStruct[PendingACL](ctx) - if err != nil { - return nil, err - } - - err = RegisterStruct[PendingACLSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterStruct[QueuedSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterStruct[Node](ctx) - if err != nil { - return nil, err - } - - err = RegisterExtension[LockableExt](ctx, nil) - if err != nil { - return nil, err - } err = RegisterExtension[ListenerExt](ctx, nil) if err != nil { return nil, err } - err = RegisterExtension[GroupExt](ctx, nil) - if err != nil { - return nil, err - } - - gql_ctx := NewGQLExtContext() - err = RegisterExtension[GQLExt](ctx, gql_ctx) - if err != nil { - return nil, err - } - - err = RegisterExtension[ACLExt](ctx, nil) + err = RegisterExtension[LockableExt](ctx, nil) if err != nil { return nil, err } @@ -654,127 +394,10 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = RegisterPolicy[OwnerOfPolicy](ctx) - if err != nil { - return nil, err - } - - err = RegisterPolicy[ParentOfPolicy](ctx) - if err != nil { - return nil, err - } - - err = RegisterPolicy[MemberOfPolicy](ctx) - if err != nil { - return nil, err - } - - err = RegisterPolicy[AllNodesPolicy](ctx) + err = RegisterExtension[GQLExt](ctx, nil) if err != nil { return nil, err } - err = RegisterPolicy[PerNodePolicy](ctx) - if err != nil { - return nil, err - } - - err = RegisterPolicy[ACLProxyPolicy](ctx) - if err != nil { - return nil, err - } - - err = RegisterSignal[AddSubGroupSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterSignal[RemoveSubGroupSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterSignal[ACLTimeoutSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterSignal[ACLSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterSignal[RemoveMemberSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterSignal[AddMemberSignal](ctx) - if err != nil { - return nil, err - } - - err = 
RegisterSignal[StatusSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterSignal[ReadSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterSignal[LockSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterSignal[TimeoutSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterSignal[LinkSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterSignal[ErrorSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterSignal[SuccessSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterSignal[ReadResultSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterSignal[EventControlSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterSignal[EventStateSignal](ctx) - if err != nil { - return nil, err - } - - err = RegisterNodeType(ctx, "Base", []ExtType{}, map[string]FieldIndex{}) - if err != nil { - return nil, err - } - - schema, err := BuildSchema(gql_ctx) - if err != nil { - return nil, err - } - - gql_ctx.Schema = schema - return ctx, nil } diff --git a/event.go b/event.go index c8e7c12..6028658 100644 --- a/event.go +++ b/event.go @@ -8,40 +8,6 @@ import ( type EventCommand string type EventState string -type ParentOfPolicy struct { - PolicyHeader - Policy Tree -} - -func NewParentOfPolicy(policy Tree) *ParentOfPolicy { - return &ParentOfPolicy{ - PolicyHeader: NewPolicyHeader(), - Policy: policy, - } -} - -func (policy ParentOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { - event_ext, err := GetExt[EventExt](node) - if err != nil { - ctx.Log.Logf("event", "ParentOfPolicy, node not event %s", node.ID) - return nil, Deny - } - - if event_ext.Parent == principal_id { - return nil, policy.Policy.Allows(action) - } - - return nil, Deny -} - -func (policy ParentOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { - return Deny -} - -var DefaultEventPolicy = NewParentOfPolicy(Tree{ - SerializedType(SignalTypeFor[EventControlSignal]()): nil, -}) - type EventExt struct { Name string `gv:"name"` State EventState `gv:"state"` @@ -71,19 +37,13 @@ type EventStateSignal struct { Time time.Time `gv:"time"` } -func (signal EventStateSignal) Permission() Tree { - return Tree{ - SerializedType(SignalTypeFor[StatusSignal]()): nil, - } -} - func (signal EventStateSignal) String() string { return fmt.Sprintf("EventStateSignal(%s, %s, %s, %+v)", signal.SignalHeader, signal.Source, signal.State, signal.Time) } func NewEventStateSignal(source NodeID, state EventState, t time.Time) *EventStateSignal { return &EventStateSignal{ - SignalHeader: NewSignalHeader(Up), + SignalHeader: NewSignalHeader(), Source: source, State: state, Time: t, @@ -101,19 +61,11 @@ func (signal EventControlSignal) String() string { func NewEventControlSignal(command EventCommand) *EventControlSignal { return &EventControlSignal{ - NewSignalHeader(Direct), + NewSignalHeader(), command, } } -func (signal EventControlSignal) Permission() Tree { - return Tree{ - SerializedType(SignalTypeFor[EventControlSignal]()): { - Hash("command", string(signal.Command)): nil, - }, - } -} - func (ext *EventExt) UpdateState(node *Node, changes Changes, state EventState, state_start time.Time) { if ext.State != state { ext.StateStart = state_start @@ -123,14 +75,10 @@ func (ext *EventExt) UpdateState(node *Node, changes Changes, state EventState, } } -func (ext *EventExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) 
(Messages, Changes) { - var messages Messages = nil +func (ext *EventExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) { + var messages []SendMsg = nil var changes = Changes{} - if signal.Direction() == Up && ext.Parent != node.ID { - messages = messages.Add(ctx, ext.Parent, node, nil, signal) - } - return messages, changes } @@ -165,27 +113,27 @@ var test_event_commands = EventCommandMap{ } -func (ext *TestEventExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) { - var messages Messages = nil +func (ext *TestEventExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) { + var messages []SendMsg = nil var changes = Changes{} switch sig := signal.(type) { case *EventControlSignal: event_ext, err := GetExt[EventExt](node) if err != nil { - messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_event")) + messages = append(messages, SendMsg{source, NewErrorSignal(sig.Id, "not_event")}) } else { ctx.Log.Logf("event", "%s got %s EventControlSignal while in %s", node.ID, sig.Command, event_ext.State) new_state, error_signal := event_ext.ValidateEventCommand(sig, test_event_commands) if error_signal != nil { - messages = messages.Add(ctx, source, node, nil, error_signal) + messages = append(messages, SendMsg{source, error_signal}) } else { switch sig.Command { case "start": node.QueueSignal(time.Now().Add(ext.Length), NewEventControlSignal("finish")) } event_ext.UpdateState(node, changes, new_state, time.Now()) - messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) + messages = append(messages, SendMsg{source, NewSuccessSignal(sig.Id)}) } } } diff --git a/extension.go b/extension.go index 3cef47c..5747e67 100644 --- a/extension.go +++ b/extension.go @@ -7,7 +7,7 @@ import ( // Extensions are data attached to nodes that process signals type Extension interface { // Called to process incoming signals, returning changes and messages to send - Process(*Context, *Node, NodeID, Signal) (Messages, Changes) + Process(*Context, *Node, NodeID, Signal) ([]SendMsg, Changes) // Called when the node is loaded into a context(creation or move), so extension data can be initialized Load(*Context, *Node) error diff --git a/gql.go b/gql.go index db5328b..423bced 100644 --- a/gql.go +++ b/gql.go @@ -73,45 +73,6 @@ func NodeInterfaceDefaultIsType(required_extensions []ExtType) func(graphql.IsTy } } -func NodeInterfaceResolveType(required_extensions []ExtType, default_type **graphql.Object)func(graphql.ResolveTypeParams) *graphql.Object { - return func(p graphql.ResolveTypeParams) *graphql.Object { - ctx, ok := p.Context.Value("resolve").(*ResolveContext) - if ok == false { - return nil - } - - node, ok := p.Value.(NodeResult) - if ok == false { - return nil - } - - gql_type, exists := ctx.GQLContext.NodeTypes[node.NodeType] - ctx.Context.Log.Logf("gql", "GQL_INTERFACE_RESOLVE_TYPE(%+v): %+v - %t - %+v - %+v", node, gql_type, exists, required_extensions, *default_type) - if exists == false { - node_type_def, exists := ctx.Context.Nodes[node.NodeType] - if exists == false { - return nil - } else { - for _, ext := range(required_extensions) { - found := false - for _, e := range(node_type_def.Extensions) { - if e == ext { - found = true - break - } - } - if found == false { - return nil - } - } - } - return *default_type - } - - return gql_type - } -} - func PrepResolve(p graphql.ResolveParams) (*ResolveContext, error) { resolve_context, ok := 
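// A minimal sketch of the new Process contract introduced in this patch, which
// returns []SendMsg instead of Messages. The positional SendMsg literal
// (destination NodeID, then Signal) follows the SendMsg{source, ...} usage in
// the hunks above; EchoExt is hypothetical and any Extension methods beyond
// Load/Unload/Process are elided.
type EchoExt struct{}

func (ext *EchoExt) Load(ctx *Context, node *Node) error { return nil }
func (ext *EchoExt) Unload(ctx *Context, node *Node) {}

func (ext *EchoExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) {
  var messages []SendMsg = nil
  var changes Changes = nil

  // Acknowledge every signal directly back to its sender, purely as an example.
  messages = append(messages, SendMsg{source, NewSuccessSignal(signal.ID())})

  return messages, changes
}

// The same []SendMsg slice is what the reworked Context.Send accepts outside
// of Process, e.g. (assuming `self` and `dest` are *Node values):
//   err := ctx.Send(self, []SendMsg{{Dest: dest.ID, Signal: NewSuccessSignal(uuid.New())}})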
p.Context.Value("resolve").(*ResolveContext) if ok == false { @@ -315,9 +276,6 @@ type ResolveContext struct { // Graph Context this resolver is running under Context *Context - // GQL Extension context this resolver is running under - GQLContext *GQLExtContext - // Pointer to the node that's currently processing this request Server *Node @@ -326,9 +284,6 @@ type ResolveContext struct { // Cache of resolved nodes NodeCache map[NodeID]NodeResult - - // Authorization from the user that started this request - Authorization *ClientAuthorization } func AuthB64(client_key ed25519.PrivateKey, server_pubkey ed25519.PublicKey) (string, error) { @@ -409,141 +364,14 @@ func AuthB64(client_key ed25519.PrivateKey, server_pubkey ed25519.PublicKey) (st return base64.StdEncoding.EncodeToString([]byte(strings.Join([]string{id_b64, iv_b64, key_b64, encrypted_b64, start_b64, sig_b64}, ":"))), nil } -func ParseAuthB64(auth_base64 string, server_id ed25519.PrivateKey) (*ClientAuthorization, error) { - joined_b64, err := base64.StdEncoding.DecodeString(auth_base64) - if err != nil { - return nil, err - } - - auth_parts := strings.Split(string(joined_b64), ":") - if len(auth_parts) != 6 { - return nil, fmt.Errorf("Wrong number of delimited elements %d", len(auth_parts)) - } - - id_bytes, err := base64.StdEncoding.DecodeString(auth_parts[0]) - if err != nil { - return nil, err - } - - iv, err := base64.StdEncoding.DecodeString(auth_parts[1]) - if err != nil { - return nil, err - } - - public_key, err := base64.StdEncoding.DecodeString(auth_parts[2]) - if err != nil { - return nil, err - } - - key_encrypted, err := base64.StdEncoding.DecodeString(auth_parts[3]) - if err != nil { - return nil, err - } - - start_bytes, err := base64.StdEncoding.DecodeString(auth_parts[4]) - if err != nil { - return nil, err - } - - signature, err := base64.StdEncoding.DecodeString(auth_parts[5]) - if err != nil { - return nil, err - } - - var start time.Time - err = start.UnmarshalBinary(start_bytes) - if err != nil { - return nil, err - } - - client_id := ed25519.PublicKey(id_bytes) - if err != nil { - return nil, err - } - - client_point, err := (&edwards25519.Point{}).SetBytes(public_key) - if err != nil { - return nil, err - } - - ecdh_client, err := ECDH.NewPublicKey(client_point.BytesMontgomery()) - if err != nil { - return nil, err - } - - h := sha512.Sum512(server_id.Seed()) - ecdh_server, err := ECDH.NewPrivateKey(h[:32]) - if err != nil { - return nil, err - } - - secret, err := ecdh_server.ECDH(ecdh_client) - if err != nil { - return nil, err - } else if len(secret) != 32 { - return nil, fmt.Errorf("Secret wrong length: %d/32", len(secret)) - } - - block, err := aes.NewCipher(secret) - if err != nil { - return nil, err - } - - encrypted_reader := bytes.NewReader(key_encrypted) - stream := cipher.NewOFB(block, iv) - reader := cipher.StreamReader{S: stream, R: encrypted_reader} - var decrypted_key bytes.Buffer - _, err = io.Copy(&decrypted_key, reader) - if err != nil { - return nil, err - } - - session_key := ed25519.NewKeyFromSeed(decrypted_key.Bytes()) - digest := append(session_key.Public().(ed25519.PublicKey), start_bytes...) 
- if ed25519.Verify(client_id, digest, signature) == false { - return nil, fmt.Errorf("Failed to verify digest/signature against client_id") - } - - return &ClientAuthorization{ - AuthInfo: AuthInfo{ - Identity: client_id, - Start: start, - Signature: signature, - }, - Key: session_key, - }, nil -} - -func ValidateAuthorization(auth Authorization, valid time.Duration) error { - // Check that the time + valid < now - // Check that Signature is public_key + start signed with client_id - if auth.Start.Add(valid).Compare(time.Now()) != 1 { - return fmt.Errorf("authorization expired") - } - - time_bytes, err := auth.Start.MarshalBinary() - if err != nil { - return err - } - - digest := append(auth.Key, time_bytes...) - if ed25519.Verify(auth.Identity, digest, auth.Signature) != true { - return fmt.Errorf("verification failed") - } - - return nil -} - func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt) (*ResolveContext, error) { return &ResolveContext{ ID: uuid.New(), Ext: gql_ext, Chans: map[uuid.UUID]chan Signal{}, Context: ctx, - GQLContext: ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(*GQLExtContext), NodeCache: map[NodeID]NodeResult{}, Server: server, - Authorization: nil, }, nil } @@ -557,13 +385,6 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW } ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") - auth, err := ParseAuthB64(r.Header.Get("Authorization"), server.Key) - if err != nil { - ctx.Log.Logf("gql", "GQL_AUTH_ID_PARSE_ERROR: %s", err) - json.NewEncoder(w).Encode(GQLUnauthorized("")) - return - } - resolve_context, err := NewResolveContext(ctx, server, gql_ext) if err != nil { ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) @@ -571,8 +392,6 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW return } - resolve_context.Authorization = auth - req_ctx := context.Background() req_ctx = context.WithValue(req_ctx, "resolve", resolve_context) @@ -585,10 +404,10 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW query := GQLPayload{} json.Unmarshal(str, &query) - gql_context := ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(*GQLExtContext) + schema := ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(graphql.Schema) params := graphql.Params{ - Schema: gql_context.Schema, + Schema: schema, Context: req_ctx, RequestString: query.Query, } @@ -716,14 +535,6 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon break } - authorization, err := ParseAuthB64(connection_params.Payload.Token, server.Key) - if err != nil { - ctx.Log.Logf("gqlws", "WS_AUTH_PARSE_ERR: %s", err) - break - } - - resolve_context.Authorization = authorization - conn_state = "ready" err = wsutil.WriteServerMessage(conn, 1, []byte("{\"type\": \"connection_ack\"}")) if err != nil { @@ -739,9 +550,9 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon } } else if msg.Type == "subscribe" { ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload) - gql_context := ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(*GQLExtContext) + schema := ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(graphql.Schema) params := graphql.Params{ - Schema: gql_context.Schema, + Schema: schema, Context: req_ctx, RequestString: msg.Payload.Query, } @@ -829,165 +640,10 @@ type Field struct { Field *graphql.Field } -// GQL Specific Context information -type GQLExtContext struct { - // Generated GQL schema - Schema graphql.Schema - - // Custom graphql types, mapped to NodeTypes - NodeTypes 
map[NodeType]*graphql.Object - Interfaces map[string]*Interface - Fields map[string]Field - - // Schema parameters - Types []graphql.Type - Query *graphql.Object - Mutation *graphql.Object -} - -func (ctx *GQLExtContext) GetACLFields(obj_name string, names []string) (map[ExtType][]string, error) { - ext_fields := map[ExtType][]string{} - for _, name := range(names) { - switch name { - case "ID": - case "TypeHash": - default: - field, exists := ctx.Fields[name] - if exists == false { - continue - } - - ext, exists := ext_fields[field.Ext] - if exists == false { - ext = []string{} - } - ext = append(ext, field.Name) - ext_fields[field.Ext] = ext - } - } - - return ext_fields, nil -} - -func BuildSchema(ctx *GQLExtContext) (graphql.Schema, error) { - schemaConfig := graphql.SchemaConfig{ - Types: ctx.Types, - Query: ctx.Query, - Mutation: ctx.Mutation, - } - - return graphql.NewSchema(schemaConfig) -} - -func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string, ext_type ExtType, gv_tag string, resolve_fn func(graphql.ResolveParams, *ResolveContext, reflect.Value)(interface{}, error)) error { - if ctx == nil { - return fmt.Errorf("ctx is nil") - } - - if resolve_fn == nil { - return fmt.Errorf("resolve_fn cannot be nil") - } - - _, exists := ctx.Fields[gql_name] - if exists == true { - return fmt.Errorf("%s is already a field in the context, cannot add again", gql_name) - } - - // Resolver has p.Source.(NodeResult) = read result of current node - resolver := func(p graphql.ResolveParams)(interface{}, error) { - ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - node, ok := p.Source.(NodeResult) - if ok == false { - return nil, fmt.Errorf("p.Value is not NodeResult") - } - - ext, ext_exists := node.Data[ext_type] - if ext_exists == false { - return nil, fmt.Errorf("%+v is not in the extensions of the result: %+v", ext_type, node.Data) - } - - val_ser, field_exists := ext[gv_tag] - if field_exists == false { - return nil, fmt.Errorf("%s is not in the fields of %+v in the result for %s - %+v", gv_tag, ext_type, gql_name, node) - } - - if val_ser.TypeStack[0] == SerializedTypeFor[error]() { - return nil, fmt.Errorf(string(val_ser.Data)) - } - - field_type, _, err := DeserializeType(ctx.Context, val_ser.TypeStack) - if err != nil { - return nil, err - } - - field_value, _, err := DeserializeValue(ctx.Context, field_type, val_ser.Data) - if err != nil { - return nil, err - } - - ctx.Context.Log.Logf("gql", "Resolving %+v", field_value) - - return resolve_fn(p, ctx, field_value) - } - - ctx.Fields[gql_name] = Field{ext_type, gv_tag, &graphql.Field{ - Type: gql_type, - Resolve: resolver, - }} - return nil -} - -func GQLInterfaces(ctx *GQLExtContext, interface_names []string) ([]*graphql.Interface, error) { - ret := make([]*graphql.Interface, len(interface_names)) - for i, in := range(interface_names) { - ctx_interface, exists := ctx.Interfaces[in] - if exists == false { - return nil, fmt.Errorf("%s is not in GQLExtContext.Interfaces", in) - } - ret[i] = ctx_interface.Interface - } - - return ret, nil -} - -func GQLFields(ctx *GQLExtContext, field_names []string) (graphql.Fields, []ExtType, error) { - fields := graphql.Fields{ - "ID": &graphql.Field{ - Type: graphql.String, - Resolve: ResolveNodeID, - }, - "TypeHash": &graphql.Field{ - Type: graphql.String, - Resolve: ResolveNodeTypeHash, - }, - } - - exts := map[ExtType]ExtType{} - ext_list := []ExtType{} - for _, name := range(field_names) { - field, exists := ctx.Fields[name] - if exists == false { - return 
nil, nil, fmt.Errorf("%s is not in GQLExtContext.Fields", name) - } - fields[name] = field.Field - _, exists = exts[field.Ext] - if exists == false { - ext_list = append(ext_list, field.Ext) - exts[field.Ext] = field.Ext - } - } - - return fields, ext_list, nil -} - type NodeResult struct { NodeID NodeID NodeType NodeType - Data map[ExtType]map[string]SerializedValue + Data map[ExtType]map[string]interface{} } type ListField struct { @@ -1002,193 +658,6 @@ type SelfField struct { ResolveFn func(graphql.ResolveParams, *ResolveContext, reflect.Value) (*NodeID, error) } -func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, interfaces []string, fields []string, self_fields map[string]SelfField, list_fields map[string]ListField) error { - if interfaces == nil { - return fmt.Errorf("interfaces is nil") - } - - if fields == nil { - return fmt.Errorf("fields is nil") - } - - _, exists := ctx.Interfaces[name] - if exists == true { - return fmt.Errorf("%s is already an interface in ctx", name) - } - - node_interfaces, err := GQLInterfaces(ctx, interfaces) - if err != nil { - return err - } - - node_fields, node_exts, err := GQLFields(ctx, fields) - if err != nil { - return err - } - - ctx_interface := Interface{} - - ctx_interface.Interface = graphql.NewInterface(graphql.InterfaceConfig{ - Name: name, - ResolveType: NodeInterfaceResolveType(node_exts, &ctx_interface.Default), - Fields: node_fields, - }) - ctx_interface.List = graphql.NewList(ctx_interface.Interface) - - for field_name, field := range(self_fields) { - self_field := field - err := ctx.RegisterField(ctx_interface.Interface, field_name, self_field.Extension, self_field.ACLName, - func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) { - id, err := self_field.ResolveFn(p, ctx, value) - if err != nil { - return nil, err - } - - if id != nil { - nodes, err := ResolveNodes(ctx, p, []NodeID{*id}) - if err != nil { - return nil, err - } else if len(nodes) != 1 { - return nil, fmt.Errorf("wrong length of nodes returned") - } - return nodes[0], nil - } else { - return nil, nil - } - }) - if err != nil { - return err - } - - ctx_interface.Interface.AddFieldConfig(field_name, ctx.Fields[field_name].Field) - node_fields[field_name] = ctx.Fields[field_name].Field - } - - for field_name, field := range(list_fields) { - list_field := field - resolve_fn := func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) { - var zero NodeID - ids, err := list_field.ResolveFn(p, ctx, value) - if err != nil { - return zero, err - } - - nodes, err := ResolveNodes(ctx, p, ids) - if err != nil { - return nil, err - } else if len(nodes) != len(ids) { - return nil, fmt.Errorf("wrong length of nodes returned") - } - return nodes, nil - } - - err := ctx.RegisterField(ctx_interface.List, field_name, list_field.Extension, list_field.ACLName, resolve_fn) - if err != nil { - return err - } - ctx_interface.Interface.AddFieldConfig(field_name, ctx.Fields[field_name].Field) - node_fields[field_name] = ctx.Fields[field_name].Field - } - - ctx_interface.Default = graphql.NewObject(graphql.ObjectConfig{ - Name: default_name, - Interfaces: append(node_interfaces, ctx_interface.Interface), - IsTypeOf: NodeInterfaceDefaultIsType(node_exts), - Fields: node_fields, - }) - - ctx.Interfaces[name] = &ctx_interface - ctx.Types = append(ctx.Types, ctx_interface.Default) - - return nil -} - -func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, name string, interface_names []string, 
field_names []string) error { - if field_names == nil { - return fmt.Errorf("fields is nil") - } - - _, exists := ctx.NodeTypes[node_type] - if exists == true { - return fmt.Errorf("%+v already in GQLExtContext.NodeTypes", node_type) - } - - node_interfaces, err := GQLInterfaces(ctx, interface_names) - if err != nil { - return err - } - - gql_fields, _, err := GQLFields(ctx, field_names) - if err != nil { - return err - } - - gql_type := graphql.NewObject(graphql.ObjectConfig{ - Name: name, - Interfaces: node_interfaces, - IsTypeOf: func(p graphql.IsTypeOfParams) bool { - node, ok := p.Value.(NodeResult) - if ok == false { - return false - } - - return node.NodeType == node_type - }, - Fields: gql_fields, - }) - - ctx.NodeTypes[node_type] = gql_type - ctx.Types = append(ctx.Types, gql_type) - - return nil -} - -func NewGQLExtContext() *GQLExtContext { - query := graphql.NewObject(graphql.ObjectConfig{ - Name: "Query", - Fields: graphql.Fields{ - "Test": &graphql.Field{ - Type: graphql.String, - Resolve: func(p graphql.ResolveParams) (interface{}, error) { - return "Test Data", nil - }, - }, - }, - }) - - mutation := graphql.NewObject(graphql.ObjectConfig{ - Name: "Mutation", - Fields: graphql.Fields{ - "Test": &graphql.Field{ - Type: graphql.String, - Resolve: func(p graphql.ResolveParams) (interface{}, error) { - return "Test Mutation Data", nil - }, - }, - }, - }) - - - context := GQLExtContext{ - Schema: graphql.Schema{}, - Types: []graphql.Type{}, - Query: query, - Mutation: mutation, - NodeTypes: map[NodeType]*graphql.Object{}, - Interfaces: map[string]*Interface{}, - Fields: map[string]Field{}, - } - - schema, err := BuildSchema(&context) - if err != nil { - panic(err) - } - - context.Schema = schema - - return &context -} - type SubscriptionInfo struct { ID uuid.UUID Channel chan interface{} @@ -1295,9 +764,10 @@ func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal { return response_chan } -func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) { +func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) { // Process ReadResultSignalType by forwarding it to the waiting resolver - var changes = Changes{} + var changes Changes = nil + var messages []SendMsg = nil switch sig := signal.(type) { case *SuccessSignal: @@ -1355,7 +825,7 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa ext.subscriptions_lock.RUnlock() } - return nil, changes + return messages, changes } var ecdsa_curves = map[uint8]elliptic.Curve{ diff --git a/gql_node.go b/gql_node.go index d4b8783..58a9045 100644 --- a/gql_node.go +++ b/gql_node.go @@ -1,11 +1,9 @@ package graphvent import ( - "time" "reflect" "fmt" "github.com/graphql-go/graphql" "github.com/graphql-go/graphql/language/ast" - "github.com/google/uuid" ) func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) { @@ -54,136 +52,3 @@ func GetResolveFields(ctx *Context, p graphql.ResolveParams) []string { return names } - -func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([]NodeResult, error) { - fields := GetResolveFields(ctx.Context, p) - ctx.Context.Log.Logf("gql_resolve_node", "RESOLVE_NODES(%+v): %+v", ids, fields) - - resp_channels := map[uuid.UUID]chan Signal{} - indices := map[uuid.UUID]int{} - - // Get a list of fields that will be written - ext_fields, err := ctx.GQLContext.GetACLFields(p.Info.FieldName, fields) - if err != nil { - return nil, err - } - 
ctx.Context.Log.Logf("gql_resolve_node", "ACL Fields from request: %+v", ext_fields) - - responses := make([]NodeResult, len(ids)) - - for i, id := range(ids) { - var read_signal *ReadSignal = nil - - node, cached := ctx.NodeCache[id] - if cached == true { - resolve := false - missing_exts := map[ExtType][]string{} - for ext_type, fields := range(ext_fields) { - cached_ext, exists := node.Data[ext_type] - if exists == true { - missing_fields := []string{} - for _, field_name := range(fields) { - _, found := cached_ext[field_name] - if found == false { - missing_fields = append(missing_fields, field_name) - } - } - if len(missing_fields) > 0 { - missing_exts[ext_type] = missing_fields - resolve = true - } - } else { - missing_exts[ext_type] = fields - resolve = true - } - } - - if resolve == true { - read_signal = NewReadSignal(missing_exts) - ctx.Context.Log.Logf("gql_resolve_node", "sending read for %+v because of missing fields %+v", id, missing_exts) - } else { - ctx.Context.Log.Logf("gql_resolve_node", "Using cached response for %+v(%d)", id, i) - responses[i] = node - continue - } - } else { - ctx.Context.Log.Logf("gql_resolve_node", "sending read for %+v", id) - read_signal = NewReadSignal(ext_fields) - } - // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error - msgs := Messages{} - msgs = msgs.Add(ctx.Context, id, ctx.Server, ctx.Authorization, read_signal) - - response_chan := ctx.Ext.GetResponseChannel(read_signal.ID()) - resp_channels[read_signal.ID()] = response_chan - indices[read_signal.ID()] = i - - // TODO: Send all at once instead of creating Messages for each - err = ctx.Context.Send(msgs) - if err != nil { - ctx.Ext.FreeResponseChannel(read_signal.ID()) - return nil, err - } - } - - errors := "" - for sig_id, response_chan := range(resp_channels) { - // Wait for the response, returning an error on timeout - response, other, err := WaitForResponse(response_chan, time.Millisecond*100, sig_id) - if err != nil { - return nil, err - } - ctx.Context.Log.Logf("gql_resolve_node", "GQL node response: %+v", response) - ctx.Context.Log.Logf("gql_resolve_node", "GQL node other messages: %+v", other) - - // for now, just put signals we didn't want back into the 'queue' - for _, other_signal := range(other) { - response_chan <- other_signal - } - - error_signal, is_error := response.(*ErrorSignal) - if is_error { - errors = fmt.Sprintf("%s, %s", errors, error_signal.Error) - continue - } - - read_response, is_read_response := response.(*ReadResultSignal) - if is_read_response == false { - errors = fmt.Sprintf("%s, wrong response type %+v", errors, reflect.TypeOf(response)) - continue - } - - idx := indices[sig_id] - responses[idx] = NodeResult{ - read_response.NodeID, - read_response.NodeType, - read_response.Extensions, - } - - cache, exists := ctx.NodeCache[read_response.NodeID] - if exists == true { - ctx.Context.Log.Logf("gql_resolve_node", "Merging new response with cached: %s, %+v - %+v", read_response.NodeID, cache, read_response.Extensions) - for ext_type, fields := range(read_response.Extensions) { - cached_fields, exists := cache.Data[ext_type] - if exists == false { - cached_fields = map[string]SerializedValue{} - cache.Data[ext_type] = cached_fields - } - for field_name, field_value := range(fields) { - cached_fields[field_name] = field_value - } - } - responses[idx] = cache - } else { - ctx.Context.Log.Logf("gql_resolve_node", "Adding new response to node cache: %s, %+v", read_response.NodeID, 
read_response.Extensions) - ctx.NodeCache[read_response.NodeID] = responses[idx] - } - } - - if errors != "" { - return nil, fmt.Errorf(errors) - } - - ctx.Context.Log.Logf("gql_resolve_node", "RESOLVED_NODES %+v - %+v", ids, responses) - return responses, nil -} diff --git a/gql_signal.go b/gql_signal.go deleted file mode 100644 index 1f5382c..0000000 --- a/gql_signal.go +++ /dev/null @@ -1,145 +0,0 @@ -package graphvent - -import ( - graphql "github.com/graphql-go/graphql" - "github.com/google/uuid" - "reflect" - "fmt" - "time" -) - -type StructFieldInfo struct { - Name string - Type *TypeInfo - Index []int -} - -func ArgumentInfo(ctx *Context, field reflect.StructField, gv_tag string) (StructFieldInfo, error) { - type_info, mapped := ctx.TypeReflects[field.Type] - if mapped == false { - return StructFieldInfo{}, fmt.Errorf("field %+v is of unregistered type %+v ", field.Name, field.Type) - } - - return StructFieldInfo{ - Name: gv_tag, - Type: type_info, - Index: field.Index, - }, nil -} - -func SignalFromArgs(ctx *Context, signal_type reflect.Type, fields []StructFieldInfo, args map[string]interface{}, id_index, direction_index []int) (Signal, error) { - fmt.Printf("FIELD: %+v\n", fields) - signal_value := reflect.New(signal_type) - - id_field := signal_value.Elem().FieldByIndex(id_index) - id_field.Set(reflect.ValueOf(uuid.New())) - - direction_field := signal_value.Elem().FieldByIndex(direction_index) - direction_field.Set(reflect.ValueOf(Direct)) - - for _, field := range(fields) { - arg, arg_exists := args[field.Name] - if arg_exists == false { - return nil, fmt.Errorf("No arg provided named %s", field.Name) - } - field_value := signal_value.Elem().FieldByIndex(field.Index) - if field_value.CanConvert(field.Type.Reflect) == false { - return nil, fmt.Errorf("Arg %s wrong type %s/%s", field.Name, field_value.Type(), field.Type.Reflect) - } - value, err := field.Type.GQLValue(ctx, arg) - if err != nil { - return nil, err - } - fmt.Printf("Setting %s to %+v of type %+v\n", field.Name, value, value.Type()) - field_value.Set(value) - } - return signal_value.Interface().(Signal), nil -} - -func NewSignalMutation(ctx *Context, name string, send_id_key string, signal_type reflect.Type) (*graphql.Field, error) { - args := graphql.FieldConfigArgument{} - arg_info := []StructFieldInfo{} - var id_index []int = nil - var direction_index []int = nil - - for _, field := range(reflect.VisibleFields(signal_type)) { - gv_tag, tagged_gv := field.Tag.Lookup("gv") - if tagged_gv { - if gv_tag == "id" { - id_index = field.Index - } else if gv_tag == "direction" { - direction_index = field.Index - } else { - _, exists := args[gv_tag] - if exists == true { - return nil, fmt.Errorf("Signal has repeated tag %s", gv_tag) - } else { - info, err := ArgumentInfo(ctx, field, gv_tag) - if err != nil { - return nil, err - } - args[gv_tag] = &graphql.ArgumentConfig{ - } - arg_info = append(arg_info, info) - } - } - } - } - - _, send_exists := args[send_id_key] - if send_exists == false { - args[send_id_key] = &graphql.ArgumentConfig{ - Type: graphql.String, - } - } - - resolve_signal := func(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - send_id, err := ExtractID(p, send_id_key) - if err != nil { - return nil, err - } - - signal, err := SignalFromArgs(ctx.Context, signal_type, arg_info, p.Args, id_index, direction_index) - if err != nil { - return nil, err - } - msgs := Messages{} - msgs = msgs.Add(ctx.Context, send_id, ctx.Server, 
ctx.Authorization, signal) - - response_chan := ctx.Ext.GetResponseChannel(signal.ID()) - err = ctx.Context.Send(msgs) - if err != nil { - ctx.Ext.FreeResponseChannel(signal.ID()) - return nil, err - } - - response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID()) - if err != nil { - ctx.Ext.FreeResponseChannel(signal.ID()) - return nil, err - } - - _, is_success := response.(*SuccessSignal) - if is_success == true { - return "success", nil - } - - error_response, is_error := response.(*ErrorSignal) - if is_error == true { - return "error", fmt.Errorf(error_response.Error) - } - - return nil, fmt.Errorf("response of unhandled type %s", reflect.TypeOf(response)) - } - - return &graphql.Field{ - Type: graphql.String, - Args: args, - Resolve: resolve_signal, - }, nil -} diff --git a/graph_test.go b/graph_test.go index 499f9d4..51d32d2 100644 --- a/graph_test.go +++ b/graph_test.go @@ -3,6 +3,7 @@ package graphvent import ( "testing" "runtime/debug" + "time" badger "github.com/dgraph-io/badger/v3" ) @@ -44,3 +45,16 @@ func fatalErr(t * testing.T, err error) { t.Fatal(err) } } + +func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *Node) (ResponseSignal, []Signal) { + source_listener, err := GetExt[ListenerExt](source) + fatalErr(t, err) + + messages := []SendMsg{{destination.ID, signal}} + fatalErr(t, ctx.Send(source, messages)) + + response, signals, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID()) + fatalErr(t, err) + + return response, signals +} diff --git a/group.go b/group.go deleted file mode 100644 index 84121d8..0000000 --- a/group.go +++ /dev/null @@ -1,296 +0,0 @@ -package graphvent - -import ( - "slices" -) - -type AddSubGroupSignal struct { - SignalHeader - Name string `gv:"name"` -} - -func NewAddSubGroupSignal(name string) *AddSubGroupSignal { - return &AddSubGroupSignal{ - NewSignalHeader(Direct), - name, - } -} - -func (signal AddSubGroupSignal) Permission() Tree { - return Tree{ - SerializedType(SignalTypeFor[AddSubGroupSignal]()): { - Hash("name", signal.Name): nil, - }, - } -} - -type RemoveSubGroupSignal struct { - SignalHeader - Name string `gv:"name"` -} - -func NewRemoveSubGroupSignal(name string) *RemoveSubGroupSignal { - return &RemoveSubGroupSignal{ - NewSignalHeader(Direct), - name, - } -} - -func (signal RemoveSubGroupSignal) Permission() Tree { - return Tree{ - SerializedType(SignalTypeFor[RemoveSubGroupSignal]()): { - Hash("command", signal.Name): nil, - }, - } -} - -type AddMemberSignal struct { - SignalHeader - SubGroup string `gv:"sub_group"` - MemberID NodeID `gv:"member_id"` -} - -type SubGroupGQL struct { - Name string - Members []NodeID -} - -func (signal AddMemberSignal) Permission() Tree { - return Tree{ - SerializedType(SignalTypeFor[AddMemberSignal]()): { - Hash("sub_group", signal.SubGroup): nil, - }, - } -} - -func NewAddMemberSignal(sub_group string, member_id NodeID) *AddMemberSignal { - return &AddMemberSignal{ - NewSignalHeader(Direct), - sub_group, - member_id, - } -} - -type RemoveMemberSignal struct { - SignalHeader - SubGroup string `gv:"sub_group"` - MemberID NodeID `gv:"member_id"` -} - -func (signal RemoveMemberSignal) Permission() Tree { - return Tree{ - SerializedType(SignalTypeFor[RemoveMemberSignal]()): { - Hash("sub_group", signal.SubGroup): nil, - }, - } -} - -func NewRemoveMemberSignal(sub_group string, member_id NodeID) *RemoveMemberSignal { - return &RemoveMemberSignal{ - NewSignalHeader(Direct), - sub_group, - member_id, - } -} - -var 
DefaultGroupPolicy = NewAllNodesPolicy(Tree{ - SerializedType(SignalTypeFor[ReadSignal]()): { - SerializedType(ExtTypeFor[GroupExt]()): { - SerializedType(GetFieldTag("sub_groups")): nil, - }, - }, -}) - -type SubGroup struct { - Members []NodeID - Permissions Tree -} - -type MemberOfPolicy struct { - PolicyHeader - Groups map[NodeID]map[string]Tree -} - -func NewMemberOfPolicy(groups map[NodeID]map[string]Tree) MemberOfPolicy { - return MemberOfPolicy{ - PolicyHeader: NewPolicyHeader(), - Groups: groups, - } -} - -func (policy MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { - sig, ok := signal.(*ReadResultSignal) - if ok == false { - return Deny - } - ctx.Log.Logf("group", "member_of_read_result: %+v", sig.Extensions) - - group_ext_data, ok := sig.Extensions[ExtTypeFor[GroupExt]()] - if ok == false { - return Deny - } - - sub_groups_ser, ok := group_ext_data["sub_groups"] - if ok == false { - return Deny - } - - sub_groups_type, _, err := DeserializeType(ctx, sub_groups_ser.TypeStack) - if err != nil { - ctx.Log.Logf("group", "Type deserialize error: %s", err) - return Deny - } - - sub_groups_if, _, err := DeserializeValue(ctx, sub_groups_type, sub_groups_ser.Data) - if err != nil { - ctx.Log.Logf("group", "Value deserialize error: %s", err) - return Deny - } - - ext_sub_groups, ok := sub_groups_if.Interface().(map[string][]NodeID) - if ok == false { - return Deny - } - - group, exists := policy.Groups[sig.NodeID] - if exists == false { - return Deny - } - - for sub_group_name, permissions := range(group) { - ext_sub_group, exists := ext_sub_groups[sub_group_name] - if exists == true { - for _, member_id := range(ext_sub_group) { - if member_id == current.Principal { - if permissions.Allows(current.Action) == Allow { - return Allow - } - } - } - } - } - - return Deny -} - -// Send a read signal to Group to check if principal_id is a member of it -func (policy MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) { - var messages Messages = nil - for group_id, sub_groups := range(policy.Groups) { - if group_id == node.ID { - ext, err := GetExt[GroupExt](node) - if err != nil { - ctx.Log.Logf("group", "MemberOfPolicy with self ID error: %s", err) - } else { - for sub_group_name, permission := range(sub_groups) { - ext_sub_group, exists := ext.SubGroups[sub_group_name] - if exists == true { - for _, member := range(ext_sub_group) { - if member == principal_id { - if permission.Allows(action) == Allow { - return nil, Allow - } - break - } - } - } - } - } - } else { - // Send the read request to the group so that ContinueAllows can parse the response to check membership - messages = messages.Add(ctx, group_id, node, nil, NewReadSignal(map[ExtType][]string{ - ExtTypeFor[GroupExt](): {"sub_groups"}, - })) - } - } - if len(messages) > 0 { - return messages, Pending - } else { - return nil, Deny - } -} - -type GroupExt struct { - SubGroups map[string][]NodeID `gv:"sub_groups"` -} - -func NewGroupExt(sub_groups map[string][]NodeID) *GroupExt { - if sub_groups == nil { - sub_groups = map[string][]NodeID{} - } - return &GroupExt{ - SubGroups: sub_groups, - } -} - -func (ext *GroupExt) Load(ctx *Context, node *Node) error { - return nil -} - -func (ext *GroupExt) Unload(ctx *Context, node *Node) { - -} - -func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) { - var messages Messages = nil - var changes = Changes{} - - switch sig := signal.(type) { - case 
*AddMemberSignal: - sub_group, exists := ext.SubGroups[sig.SubGroup] - if exists == false { - messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup")) - } else { - if slices.Contains(sub_group, sig.MemberID) == true { - messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "already_member")) - } else { - sub_group = append(sub_group, sig.MemberID) - ext.SubGroups[sig.SubGroup] = sub_group - - messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) - changes.Add("sub_groups") - } - } - - case *RemoveMemberSignal: - sub_group, exists := ext.SubGroups[sig.SubGroup] - if exists == false { - messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup")) - } else { - idx := slices.Index(sub_group, sig.MemberID) - if idx == -1 { - messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_member")) - } else { - sub_group = slices.Delete(sub_group, idx, idx+1) - ext.SubGroups[sig.SubGroup] = sub_group - - messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) - changes.Add("sub_groups") - } - } - - case *AddSubGroupSignal: - _, exists := ext.SubGroups[sig.Name] - if exists == true { - messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "already_subgroup")) - } else { - ext.SubGroups[sig.Name] = []NodeID{} - - changes.Add("sub_groups") - messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) - } - case *RemoveSubGroupSignal: - _, exists := ext.SubGroups[sig.Name] - if exists == false { - messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup")) - } else { - delete(ext.SubGroups, sig.Name) - - changes.Add("sub_groups") - messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) - } - } - - return messages, changes -} - diff --git a/group_test.go b/group_test.go deleted file mode 100644 index 209f7fb..0000000 --- a/group_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package graphvent - -import ( - "testing" - "time" -) - -func TestGroupAdd(t *testing.T) { - ctx := logTestContext(t, []string{"listener", "test"}) - - group_listener := NewListenerExt(10) - group, err := NewNode(ctx, nil, "Base", 10, nil, group_listener, NewGroupExt(nil)) - fatalErr(t, err) - - add_subgroup_signal := NewAddSubGroupSignal("test_group") - messages := Messages{} - messages = messages.Add(ctx, group.ID, group, nil, add_subgroup_signal) - fatalErr(t, ctx.Send(messages)) - - resp_1, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, add_subgroup_signal.Id) - fatalErr(t, err) - - error_1, is_error := resp_1.(*ErrorSignal) - if is_error { - t.Fatalf("Error returned: %s", error_1.Error) - } - - user_id := RandID() - add_member_signal := NewAddMemberSignal("test_group", user_id) - - messages = Messages{} - messages = messages.Add(ctx, group.ID, group, nil, add_member_signal) - fatalErr(t, ctx.Send(messages)) - - resp_2, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, add_member_signal.Id) - fatalErr(t, err) - - error_2, is_error := resp_2.(*ErrorSignal) - if is_error { - t.Fatalf("Error returned: %s", error_2.Error) - } - - read_signal := NewReadSignal(map[ExtType][]string{ - ExtTypeFor[GroupExt](): {"sub_groups"}, - }) - - messages = Messages{} - messages = messages.Add(ctx, group.ID, group, nil, read_signal) - fatalErr(t, ctx.Send(messages)) - - response, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, read_signal.Id) - fatalErr(t, err) - - read_response := 
response.(*ReadResultSignal) - - sub_groups_serialized := read_response.Extensions[ExtTypeFor[GroupExt]()]["sub_groups"] - - sub_groups_type, remaining_types, err := DeserializeType(ctx, sub_groups_serialized.TypeStack) - fatalErr(t, err) - if len(remaining_types) > 0 { - t.Fatalf("Types remaining after deserializing subgroups: %d", len(remaining_types)) - } - - sub_groups_value, remaining, err := DeserializeValue(ctx, sub_groups_type, sub_groups_serialized.Data) - fatalErr(t, err) - if len(remaining) > 0 { - t.Fatalf("Data remaining after deserializing subgroups: %d", len(remaining_types)) - } - - sub_groups, ok := sub_groups_value.Interface().(map[string][]NodeID) - - if ok != true { - t.Fatalf("sub_groups wrong type %s", sub_groups_value.Type()) - } - - if len(sub_groups) != 1 { - t.Fatalf("sub_groups wrong length %d", len(sub_groups)) - } - - test_subgroup, exists := sub_groups["test_group"] - if exists == false { - t.Fatal("test_group not in subgroups") - } - - if len(test_subgroup) != 1 { - t.Fatalf("test_group wrong size %d/1", len(test_subgroup)) - } - - if test_subgroup[0] != user_id { - t.Fatalf("sub_groups wrong value %s", test_subgroup[0]) - } - - ctx.Log.Logf("test", "Read Response: %+v", read_response) -} diff --git a/listener.go b/listener.go index ebeab5f..d2b7633 100644 --- a/listener.go +++ b/listener.go @@ -11,17 +11,13 @@ type ListenerExt struct { } func (ext *ListenerExt) Load(ctx *Context, node *Node) error { + ext.Chan = make(chan Signal, ext.Buffer) return nil } func (ext *ListenerExt) Unload(ctx *Context, node *Node) { } -func (ext *ListenerExt) PostDeserialize(ctx *Context) error { - ext.Chan = make(chan Signal, ext.Buffer) - return nil -} - // Create a new listener extension with a given buffer size func NewListenerExt(buffer int) *ListenerExt { return &ListenerExt{ @@ -31,7 +27,7 @@ func NewListenerExt(buffer int) *ListenerExt { } // Send the signal to the channel, logging an overflow if it occurs -func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) { +func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) { ctx.Log.Logf("listener", "%s - %+v", node.ID, reflect.TypeOf(signal)) ctx.Log.Logf("listener_debug", "%s->%s - %+v", source, node.ID, signal) select { diff --git a/lockable.go b/lockable.go index 30e10a4..a7c05f1 100644 --- a/lockable.go +++ b/lockable.go @@ -5,18 +5,6 @@ import ( "time" ) -var AllowParentUnlockPolicy = NewOwnerOfPolicy(Tree{ - SerializedType(SignalTypeFor[LockSignal]()): { - Hash(LockStateBase, "unlock"): nil, - }, -}) - -var AllowAnyLockPolicy = NewAllNodesPolicy(Tree{ - SerializedType(SignalTypeFor[LockSignal]()): { - Hash(LockStateBase, "lock"): nil, - }, -}) - type ReqState byte const ( Unlocked = ReqState(0) @@ -62,17 +50,15 @@ func NewLockableExt(requirements []NodeID) *LockableExt { } func UnlockLockable(ctx *Context, node *Node) (uuid.UUID, error) { - messages := Messages{} signal := NewLockSignal("unlock") - messages = messages.Add(ctx, node.ID, node, nil, signal) - return signal.ID(), ctx.Send(messages) + messages := []SendMsg{{node.ID, signal}} + return signal.ID(), ctx.Send(node, messages) } func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) { - messages := Messages{} signal := NewLockSignal("lock") - messages = messages.Add(ctx, node.ID, node, nil, signal) - return signal.ID(), ctx.Send(messages) + messages := []SendMsg{{node.ID, signal}} + return signal.ID(), ctx.Send(node, messages) } func (ext 
*LockableExt) Load(ctx *Context, node *Node) error { @@ -82,8 +68,8 @@ func (ext *LockableExt) Load(ctx *Context, node *Node) error { func (ext *LockableExt) Unload(ctx *Context, node *Node) { } -func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) (Messages, Changes) { - var messages Messages = nil +func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) ([]SendMsg, Changes) { + var messages []SendMsg = nil var changes Changes = nil info, info_found := node.ProcessResponse(ext.WaitInfos, signal) @@ -126,7 +112,7 @@ func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeI ext.Requirements[id] = Unlocking lock_signal := NewLockSignal("unlock") ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", id, lock_signal, 100*time.Millisecond) - messages = messages.Add(ctx, id, node, nil, lock_signal) + messages = append(messages, SendMsg{id, lock_signal}) ctx.Log.Logf("lockable", "sent abort unlock to %s from %s", id, node.ID) } } @@ -153,43 +139,43 @@ func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeI return messages, changes } -func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID, signal *LinkSignal) (Messages, Changes) { - var messages Messages = nil +func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID, signal *LinkSignal) ([]SendMsg, Changes) { + var messages []SendMsg = nil var changes = Changes{} if ext.State == Unlocked { switch signal.Action { case "add": _, exists := ext.Requirements[signal.NodeID] if exists == true { - messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "already_requirement")) + messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "already_requirement")}) } else { if ext.Requirements == nil { ext.Requirements = map[NodeID]ReqState{} } ext.Requirements[signal.NodeID] = Unlocked changes.Add("requirements") - messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(signal.ID())) + messages = append(messages, SendMsg{source, NewSuccessSignal(signal.ID())}) } case "remove": _, exists := ext.Requirements[signal.NodeID] if exists == false { - messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "can't link: not_requirement")) + messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "can't link: not_requirement")}) } else { delete(ext.Requirements, signal.NodeID) changes.Add("requirements") - messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(signal.ID())) + messages = append(messages, SendMsg{source, NewSuccessSignal(signal.ID())}) } default: - messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "unknown_action")) + messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "unknown_action")}) } } else { - messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_unlocked")) + messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "not_unlocked")}) } return messages, changes } -func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source NodeID, signal *SuccessSignal) (Messages, Changes) { - var messages Messages = nil +func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source NodeID, signal *SuccessSignal) ([]SendMsg, Changes) { + var messages []SendMsg = nil var changes = Changes{} if source == node.ID { return messages, changes @@ -218,7 +204,7 @@ 
func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod ext.State = Locked ext.Owner = ext.PendingOwner changes.Add("state", "owner", "requirements") - messages = messages.Add(ctx, *ext.Owner, node, nil, NewSuccessSignal(ext.PendingID)) + messages = append(messages, SendMsg{*ext.Owner, NewSuccessSignal(ext.PendingID)}) } else { changes.Add("requirements") ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, len(ext.Requirements)) @@ -228,7 +214,7 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod lock_signal := NewLockSignal("unlock") ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", info.Destination, lock_signal, 100*time.Millisecond) - messages = messages.Add(ctx, info.Destination, node, nil, lock_signal) + messages = append(messages, SendMsg{info.Destination, lock_signal}) ctx.Log.Logf("lockable", "sending abort_lock to %s for %s", info.Destination, node.ID) } @@ -254,10 +240,10 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod ext.Owner = ext.PendingOwner ext.ReqID = nil changes.Add("state", "owner", "req_id") - messages = messages.Add(ctx, previous_owner, node, nil, NewSuccessSignal(ext.PendingID)) + messages = append(messages, SendMsg{previous_owner, NewSuccessSignal(ext.PendingID)}) } else if old_state == AbortingLock { changes.Add("state", "pending_owner") - messages = messages.Add(ctx, *ext.PendingOwner, node, nil, NewErrorSignal(*ext.ReqID, "not_unlocked")) + messages = append(messages, SendMsg{*ext.PendingOwner, NewErrorSignal(*ext.ReqID, "not_unlocked")}) ext.PendingOwner = ext.Owner } } else { @@ -272,8 +258,8 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod } // Handle a LockSignal and update the extensions owner/requirement states -func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) (Messages, Changes) { - var messages Messages = nil +func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) ([]SendMsg, Changes) { + var messages []SendMsg = nil var changes = Changes{} switch signal.State { @@ -286,7 +272,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID ext.PendingOwner = &new_owner ext.Owner = &new_owner changes.Add("state", "pending_owner", "owner") - messages = messages.Add(ctx, new_owner, node, nil, NewSuccessSignal(signal.ID())) + messages = append(messages, SendMsg{new_owner, NewSuccessSignal(signal.ID())}) } else { ext.State = Locking id := signal.ID() @@ -304,11 +290,11 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("lock", id, lock_signal, 500*time.Millisecond) ext.Requirements[id] = Locking - messages = messages.Add(ctx, id, node, nil, lock_signal) + messages = append(messages, SendMsg{id, lock_signal}) } } default: - messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_unlocked")) + messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "not_unlocked")}) ctx.Log.Logf("lockable", "Tried to lock %s while %s", node.ID, ext.State) } case "unlock": @@ -319,7 +305,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID ext.PendingOwner = nil ext.Owner = nil changes.Add("state", "pending_owner", "owner") - messages = messages.Add(ctx, new_owner, node, nil, NewSuccessSignal(signal.ID())) + messages = append(messages, 
SendMsg{new_owner, NewSuccessSignal(signal.ID())}) } else if source == *ext.Owner { ext.State = Unlocking id := signal.ID() @@ -336,11 +322,11 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", id, lock_signal, 100*time.Millisecond) ext.Requirements[id] = Unlocking - messages = messages.Add(ctx, id, node, nil, lock_signal) + messages = append(messages, SendMsg{id, lock_signal}) } } } else { - messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_locked")) + messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "not_locked")}) } default: ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", signal.State) @@ -348,8 +334,8 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID return messages, changes } -func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source NodeID, signal *TimeoutSignal) (Messages, Changes) { - var messages Messages = nil +func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source NodeID, signal *TimeoutSignal) ([]SendMsg, Changes) { + var messages []SendMsg = nil var changes = Changes{} wait_info, found := node.ProcessResponse(ext.WaitInfos, signal) @@ -380,7 +366,7 @@ func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source Nod ext.Requirements[id] = Unlocking lock_signal := NewLockSignal("unlock") ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", id, lock_signal, 100*time.Millisecond) - messages = messages.Add(ctx, id, node, nil, lock_signal) + messages = append(messages, SendMsg{id, lock_signal}) ctx.Log.Logf("lockable", "sent abort unlock to %s from %s", id, node.ID) } } @@ -405,124 +391,32 @@ func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source Nod return messages, changes } -// LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes +// LockableExts process status signals by forwarding them to it's owner // LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state -func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) { - var messages Messages = nil +func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) { + var messages []SendMsg = nil var changes = Changes{} - switch signal.Direction() { - case Up: + switch sig := signal.(type) { + case *StatusSignal: if ext.Owner != nil { if *ext.Owner != node.ID { - messages = messages.Add(ctx, *ext.Owner, node, nil, signal) + messages = append(messages, SendMsg{*ext.Owner, signal}) } } - - case Down: - for requirement := range(ext.Requirements) { - messages = messages.Add(ctx, requirement, node, nil, signal) - } - - case Direct: - switch sig := signal.(type) { - case *LinkSignal: - messages, changes = ext.HandleLinkSignal(ctx, node, source, sig) - case *LockSignal: - messages, changes = ext.HandleLockSignal(ctx, node, source, sig) - case *ErrorSignal: - messages, changes = ext.HandleErrorSignal(ctx, node, source, sig) - case *SuccessSignal: - messages, changes = ext.HandleSuccessSignal(ctx, node, source, sig) - case *TimeoutSignal: - messages, changes = ext.HandleTimeoutSignal(ctx, node, source, sig) - default: - } + case *LinkSignal: + messages, changes = ext.HandleLinkSignal(ctx, node, source, sig) + case *LockSignal: + messages, changes = ext.HandleLockSignal(ctx, node, source, 
sig) + case *ErrorSignal: + messages, changes = ext.HandleErrorSignal(ctx, node, source, sig) + case *SuccessSignal: + messages, changes = ext.HandleSuccessSignal(ctx, node, source, sig) + case *TimeoutSignal: + messages, changes = ext.HandleTimeoutSignal(ctx, node, source, sig) default: } - return messages, changes -} - -type OwnerOfPolicy struct { - PolicyHeader - Rules Tree `gv:"rules"` -} - -func NewOwnerOfPolicy(rules Tree) OwnerOfPolicy { - return OwnerOfPolicy{ - PolicyHeader: NewPolicyHeader(), - Rules: rules, - } -} - -func (policy OwnerOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { - return Deny -} - -func (policy OwnerOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { - l_ext, err := GetExt[LockableExt](node) - if err != nil { - ctx.Log.Logf("lockable", "OwnerOfPolicy.Allows called on node without LockableExt") - return nil, Deny - } - if l_ext.Owner == nil { - return nil, Deny - } - - if principal_id == *l_ext.Owner { - return nil, Allow - } - - return nil, Deny -} - -type RequirementOfPolicy struct { - PerNodePolicy -} - -func NewRequirementOfPolicy(dep_rules map[NodeID]Tree) RequirementOfPolicy { - return RequirementOfPolicy { - PerNodePolicy: NewPerNodePolicy(dep_rules), - } + return messages, changes } -func (policy RequirementOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { - sig, ok := signal.(*ReadResultSignal) - if ok == false { - return Deny - } - - ext, ok := sig.Extensions[ExtTypeFor[LockableExt]()] - if ok == false { - return Deny - } - - reqs_ser, ok := ext["requirements"] - if ok == false { - return Deny - } - - reqs_type, _, err := DeserializeType(ctx, reqs_ser.TypeStack) - if err != nil { - return Deny - } - - reqs_if, _, err := DeserializeValue(ctx, reqs_type, reqs_ser.Data) - if err != nil { - return Deny - } - - requirements, ok := reqs_if.Interface().(map[NodeID]ReqState) - if ok == false { - return Deny - } - - for req := range(requirements) { - if req == current.Principal { - return policy.NodeRules[sig.NodeID].Allows(current.Action) - } - } - - return Deny -} diff --git a/lockable_test.go b/lockable_test.go index fe90655..8febf74 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -3,8 +3,6 @@ package graphvent import ( "testing" "time" - "crypto/ed25519" - "crypto/rand" ) func lockableTestContext(t *testing.T, logs []string) *Context { @@ -19,32 +17,19 @@ func lockableTestContext(t *testing.T, logs []string) *Context { func TestLink(t *testing.T) { ctx := lockableTestContext(t, []string{"lockable", "listener"}) - l1_pub, l1_key, err := ed25519.GenerateKey(rand.Reader) - fatalErr(t, err) - l1_id := KeyID(l1_pub) - policy := NewPerNodePolicy(map[NodeID]Tree{ - l1_id: nil, - }) l2_listener := NewListenerExt(10) - l2, err := NewNode(ctx, nil, "Lockable", 10, []Policy{policy}, - l2_listener, - NewLockableExt(nil), - ) + l2, err := NewNode(ctx, nil, "Lockable", 10, l2_listener, NewLockableExt(nil)) fatalErr(t, err) l1_lockable := NewLockableExt(nil) l1_listener := NewListenerExt(10) - l1, err := NewNode(ctx, l1_key, "Lockable", 10, nil, - l1_listener, - l1_lockable, - ) + l1, err := NewNode(ctx, nil, "Lockable", 10, l1_listener, l1_lockable) fatalErr(t, err) - msgs := Messages{} link_signal := NewLinkSignal("add", l2.ID) - msgs = msgs.Add(ctx, l1.ID, l1, nil, link_signal) - err = ctx.Send(msgs) + msgs := []SendMsg{{l1.ID, link_signal}} + err = ctx.Send(l1, msgs) fatalErr(t, err) _, _, err = WaitForResponse(l1_listener.Chan, 
time.Millisecond*10, link_signal.ID()) @@ -57,10 +42,9 @@ func TestLink(t *testing.T) { t.Fatalf("l2 in bad requirement state in l1: %+v", state) } - msgs = Messages{} unlink_signal := NewLinkSignal("remove", l2.ID) - msgs = msgs.Add(ctx, l1.ID, l1, nil, unlink_signal) - err = ctx.Send(msgs) + msgs = []SendMsg{{l1.ID, unlink_signal}} + err = ctx.Send(l1, msgs) fatalErr(t, err) _, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, unlink_signal.ID()) @@ -70,18 +54,8 @@ func TestLink(t *testing.T) { func Test1000Lock(t *testing.T) { ctx := lockableTestContext(t, []string{"test", "lockable"}) - l_pub, listener_key, err := ed25519.GenerateKey(rand.Reader) - fatalErr(t, err) - listener_id := KeyID(l_pub) - child_policy := NewPerNodePolicy(map[NodeID]Tree{ - listener_id: { - SerializedType(SignalTypeFor[LockSignal]()): nil, - }, - }) NewLockable := func()(*Node) { - l, err := NewNode(ctx, nil, "Lockable", 10, []Policy{child_policy}, - NewLockableExt(nil), - ) + l, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt(nil)) fatalErr(t, err) return l } @@ -93,15 +67,8 @@ func Test1000Lock(t *testing.T) { } ctx.Log.Logf("test", "CREATED_1000") - l_policy := NewAllNodesPolicy(Tree{ - SerializedType(SignalTypeFor[LockSignal]()): nil, - }) - listener := NewListenerExt(5000) - node, err := NewNode(ctx, listener_key, "Lockable", 5000, []Policy{l_policy}, - listener, - NewLockableExt(reqs), - ) + node, err := NewNode(ctx, nil, "Lockable", 5000, listener, NewLockableExt(reqs)) fatalErr(t, err) ctx.Log.Logf("test", "CREATED_LISTENER") @@ -123,14 +90,9 @@ func Test1000Lock(t *testing.T) { func TestLock(t *testing.T) { ctx := lockableTestContext(t, []string{"test", "lockable"}) - policy := NewAllNodesPolicy(nil) - NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) { listener := NewListenerExt(1000) - l, err := NewNode(ctx, nil, "Lockable", 10, []Policy{policy}, - listener, - NewLockableExt(reqs), - ) + l, err := NewNode(ctx, nil, "Lockable", 10, listener, NewLockableExt(reqs)) fatalErr(t, err) return l, listener } diff --git a/message.go b/message.go index 7c87178..bf5fc1e 100644 --- a/message.go +++ b/message.go @@ -1,114 +1,11 @@ package graphvent -import ( - "time" - "crypto/ed25519" - "crypto/rand" - "crypto" -) - -type AuthInfo struct { - // The Node that issued the authorization - Identity ed25519.PublicKey - - // Time the authorization was generated - Start time.Time - - // Signature of Start + Principal with Identity private key - Signature []byte -} - -type AuthorizationToken struct { - AuthInfo - - // The private key generated by the client, encrypted with the servers public key - KeyEncrypted []byte -} - -type ClientAuthorization struct { - AuthInfo - - // The private key generated by the client - Key ed25519.PrivateKey -} - -// Authorization structs can be passed in a message that originated from a different node than the sender -type Authorization struct { - AuthInfo - - // The public key generated for this authorization - Key ed25519.PublicKey -} - -type Message struct { - Dest NodeID - Source ed25519.PublicKey - - Authorization *Authorization - +type SendMsg struct { + Dest NodeID Signal Signal - Signature []byte -} - -type Messages []*Message -func (msgs Messages) Add(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) Messages { - msg, err := NewMessage(ctx, dest, source, authorization, signal) - if err != nil { - panic(err) - } else { - msgs = append(msgs, msg) - } - return msgs -} - -func NewMessages(ctx *Context, dest NodeID, 
source *Node, authorization *ClientAuthorization, signals... Signal) Messages { - messages := Messages{} - for _, signal := range(signals) { - messages = messages.Add(ctx, dest, source, authorization, signal) - } - return messages } -func NewMessage(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) (*Message, error) { - signal_ser, err := SerializeAny(ctx, signal) - if err != nil { - return nil, err - } - - signal_chunks, err := signal_ser.Chunks() - if err != nil { - return nil, err - } - - dest_ser, err := dest.MarshalBinary() - if err != nil { - return nil, err - } - source_ser, err := source.ID.MarshalBinary() - if err != nil { - return nil, err - } - sig_data := append(dest_ser, source_ser...) - sig_data = append(sig_data, signal_chunks.Slice()...) - var message_auth *Authorization = nil - if authorization != nil { - sig_data = append(sig_data, authorization.Signature...) - message_auth = &Authorization{ - authorization.AuthInfo, - authorization.Key.Public().(ed25519.PublicKey), - } - } - - sig, err := source.Key.Sign(rand.Reader, sig_data, crypto.Hash(0)) - if err != nil { - return nil, err - } - - return &Message{ - Dest: dest, - Source: source.Key.Public().(ed25519.PublicKey), - Authorization: message_auth, - Signal: signal, - Signature: sig, - }, nil +type RecvMsg struct { + Source NodeID + Signal Signal } diff --git a/node.go b/node.go index c6ba523..c9a764e 100644 --- a/node.go +++ b/node.go @@ -10,7 +10,7 @@ import ( "sync/atomic" "time" - badger "github.com/dgraph-io/badger/v3" + _ "github.com/dgraph-io/badger/v3" "github.com/google/uuid" ) @@ -57,24 +57,6 @@ func (q QueuedSignal) String() string { return fmt.Sprintf("%+v@%s", reflect.TypeOf(q.Signal), q.Time) } -type PendingACL struct { - Counter int - Responses []ResponseSignal - - TimeoutID uuid.UUID - Action Tree - Principal NodeID - - Signal Signal - Source NodeID -} - -type PendingACLSignal struct { - Policy uuid.UUID - Timeout uuid.UUID - ID uuid.UUID -} - // Default message channel size for nodes // Nodes represent a group of extensions that can be collectively addressed type Node struct { @@ -84,13 +66,8 @@ type Node struct { // TODO: move each extension to it's own db key, and extend changes to notify which extension was changed Extensions map[ExtType]Extension - Policies []Policy `gv:"policies"` - - PendingACLs map[uuid.UUID]PendingACL `gv:"pending_acls"` - PendingACLSignals map[uuid.UUID]PendingACLSignal `gv:"pending_signal"` - // Channel for this node to receive messages from the Context - MsgChan chan *Message + MsgChan chan RecvMsg // Size of MsgChan BufferSize uint32 `gv:"buffer_size"` // Channel for this node to process delayed signals @@ -110,34 +87,11 @@ func (node *Node) PostDeserialize(ctx *Context) error { public := node.Key.Public().(ed25519.PublicKey) node.ID = KeyID(public) - node.MsgChan = make(chan *Message, node.BufferSize) + node.MsgChan = make(chan RecvMsg, node.BufferSize) return nil } -type RuleResult int -const ( - Allow RuleResult = iota - Deny - Pending -) - -func (node *Node) Allows(ctx *Context, principal_id NodeID, action Tree)(map[uuid.UUID]Messages, RuleResult) { - pends := map[uuid.UUID]Messages{} - for _, policy := range(node.Policies) { - msgs, resp := policy.Allows(ctx, principal_id, action, node) - if resp == Allow { - return nil, Allow - } else if resp == Pending { - pends[policy.ID()] = msgs - } - } - if len(pends) != 0 { - return pends, Pending - } - return nil, Deny -} - type WaitReason string type WaitInfo struct { Destination NodeID 
`gv:"destination"` @@ -250,37 +204,23 @@ func (err StringError) MarshalBinary() ([]byte, error) { return []byte(string(err)), nil } -func NewErrorField(fstring string, args ...interface{}) SerializedValue { - str := StringError(fmt.Sprintf(fstring, args...)) - str_ser, err := str.MarshalBinary() - if err != nil { - panic(err) - } - return SerializedValue{ - TypeStack: []SerializedType{SerializedTypeFor[error]()}, - Data: str_ser, - } -} - -func (node *Node) ReadFields(ctx *Context, reqs map[ExtType][]string)map[ExtType]map[string]SerializedValue { +func (node *Node) ReadFields(ctx *Context, reqs map[ExtType][]string)map[ExtType]map[string]any { ctx.Log.Logf("read_field", "Reading %+v on %+v", reqs, node.ID) - exts := map[ExtType]map[string]SerializedValue{} + exts := map[ExtType]map[string]any{} for ext_type, field_reqs := range(reqs) { - fields := map[string]SerializedValue{} - for _, req := range(field_reqs) { - ext, exists := node.Extensions[ext_type] - if exists == false { - fields[req] = NewErrorField("%+v does not have %+v extension", node.ID, ext_type) - } else { - f, err := SerializeField(ctx, ext, req) - if err != nil { - fields[req] = NewErrorField(err.Error()) + ext_info, ext_known := ctx.Extensions[ext_type] + if ext_known { + fields := map[string]any{} + for _, req := range(field_reqs) { + ext, exists := node.Extensions[ext_type] + if exists == false { + fields[req] = fmt.Errorf("%+v does not have %+v extension", node.ID, ext_type) } else { - fields[req] = f + fields[req] = reflect.ValueOf(ext).FieldByIndex(ext_info.Fields[req]).Interface() } } + exts[ext_type] = fields } - exts[ext_type] = fields } return exts } @@ -306,92 +246,8 @@ func nodeLoop(ctx *Context, node *Node) error { var source NodeID select { case msg := <- node.MsgChan: - ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal) - signal_ser, err := SerializeAny(ctx, msg.Signal) - if err != nil { - ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", err, msg.Signal) - } - chunks, err := signal_ser.Chunks() - if err != nil { - ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", err, signal_ser) - continue - } - - dst_id_ser, err := msg.Dest.MarshalBinary() - if err != nil { - ctx.Log.Logf("signal", "SIGNAL_DEST_ID_SER_ERR: %e", err) - continue - } - src_id_ser, err := KeyID(msg.Source).MarshalBinary() - if err != nil { - ctx.Log.Logf("signal", "SIGNAL_SRC_ID_SER_ERR: %e", err) - continue - } - sig_data := append(dst_id_ser, src_id_ser...) - sig_data = append(sig_data, chunks.Slice()...) - if msg.Authorization != nil { - sig_data = append(sig_data, msg.Authorization.Signature...) 
- } - validated := ed25519.Verify(msg.Source, sig_data, msg.Signature) - if validated == false { - ctx.Log.Logf("signal_verify", "SIGNAL_VERIFY_ERR: %s - %s", node.ID, reflect.TypeOf(msg.Signal)) - continue - } - - var princ_id NodeID - if msg.Authorization == nil { - princ_id = KeyID(msg.Source) - } else { - err := ValidateAuthorization(*msg.Authorization, time.Hour) - if err != nil { - ctx.Log.Logf("node", "Authorization validation failed: %s", err) - continue - } - princ_id = KeyID(msg.Authorization.Identity) - } - if princ_id != node.ID { - pends, resp := node.Allows(ctx, princ_id, msg.Signal.Permission()) - if resp == Deny { - ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %+v(%+s)", princ_id, node.ID, reflect.TypeOf(msg.Signal), msg.Signal) - ctx.Log.Logf("policy", "SIGNAL_POLICY_SOURCE: %s", msg.Source) - msgs := Messages{} - msgs = msgs.Add(ctx, KeyID(msg.Source), node, nil, NewErrorSignal(msg.Signal.ID(), "acl denied")) - ctx.Send(msgs) - continue - } else if resp == Pending { - ctx.Log.Logf("policy", "SIGNAL_POLICY_PENDING: %s->%s - %s - %+v", princ_id, node.ID, msg.Signal.Permission(), pends) - timeout_signal := NewACLTimeoutSignal(msg.Signal.ID()) - node.QueueSignal(time.Now().Add(100*time.Millisecond), timeout_signal) - msgs := Messages{} - for policy_type, sigs := range(pends) { - for _, m := range(sigs) { - msgs = append(msgs, m) - timeout_signal := NewTimeoutSignal(m.Signal.ID()) - node.QueueSignal(time.Now().Add(time.Second), timeout_signal) - node.PendingACLSignals[m.Signal.ID()] = PendingACLSignal{policy_type, timeout_signal.Id, msg.Signal.ID()} - } - } - node.PendingACLs[msg.Signal.ID()] = PendingACL{ - Counter: len(msgs), - TimeoutID: timeout_signal.ID(), - Action: msg.Signal.Permission(), - Principal: princ_id, - Responses: []ResponseSignal{}, - Signal: msg.Signal, - Source: KeyID(msg.Source), - } - ctx.Log.Logf("policy", "Sending signals for pending ACL: %+v", msgs) - ctx.Send(msgs) - continue - } else if resp == Allow { - ctx.Log.Logf("policy", "SIGNAL_POLICY_ALLOW: %s->%s - %s", princ_id, node.ID, reflect.TypeOf(msg.Signal)) - } - } else { - ctx.Log.Logf("policy", "SIGNAL_POLICY_SELF: %s - %s", node.ID, reflect.TypeOf(msg.Signal)) - } - signal = msg.Signal - source = KeyID(msg.Source) + source = msg.Source case <-node.TimeoutChan: signal = node.NextSignal.Signal @@ -425,68 +281,12 @@ func nodeLoop(ctx *Context, node *Node) error { ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) - response, ok := signal.(ResponseSignal) - if ok == true { - info, waiting := node.PendingACLSignals[response.ResponseID()] - if waiting == true { - delete(node.PendingACLSignals, response.ResponseID()) - ctx.Log.Logf("pending", "FOUND_PENDING_SIGNAL: %s - %s", node.ID, signal) - - req_info, exists := node.PendingACLs[info.ID] - if exists == true { - req_info.Counter -= 1 - req_info.Responses = append(req_info.Responses, response) - - idx := -1 - for i, p := range(node.Policies) { - if p.ID() == info.Policy { - idx = i - break - } - } - if idx == -1 { - ctx.Log.Logf("policy", "PENDING_FOR_NONEXISTENT_POLICY: %s - %s", node.ID, info.Policy) - delete(node.PendingACLs, info.ID) - } else { - allowed := node.Policies[idx].ContinueAllows(ctx, req_info, signal) - if allowed == Allow { - ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal) - signal = req_info.Signal - source = req_info.Source - err := node.DequeueSignal(req_info.TimeoutID) - if err != nil { - ctx.Log.Logf("node", "dequeue error: %s", err) - } - 
delete(node.PendingACLs, info.ID) - } else if req_info.Counter == 0 { - ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal) - // Send the denied response - msgs := Messages{} - msgs = msgs.Add(ctx, req_info.Source, node, nil, NewErrorSignal(req_info.Signal.ID(), "acl_denied")) - err := ctx.Send(msgs) - if err != nil { - ctx.Log.Logf("signal", "SEND_ERR: %s", err) - } - err = node.DequeueSignal(req_info.TimeoutID) - if err != nil { - ctx.Log.Logf("node", "ACL_DEQUEUE_ERROR: timeout signal not in queue when trying to clear after counter hit 0 %s, %s - %s", err, signal.ID(), req_info.TimeoutID) - } - delete(node.PendingACLs, info.ID) - } else { - node.PendingACLs[info.ID] = req_info - continue - } - } - } - } - } - switch sig := signal.(type) { case *ReadSignal: result := node.ReadFields(ctx, sig.Extensions) - msgs := Messages{} - msgs = msgs.Add(ctx, source, node, nil, NewReadResultSignal(sig.ID(), node.ID, node.Type, result)) - ctx.Send(msgs) + msgs := []SendMsg{} + msgs = append(msgs, SendMsg{source, NewReadResultSignal(sig.ID(), node.ID, node.Type, result)}) + ctx.Send(node, msgs) default: err := node.Process(ctx, source, signal) @@ -522,7 +322,7 @@ func (node *Node) QueueChanges(ctx *Context, changes map[ExtType]Changes) error func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error { ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, signal) - messages := Messages{} + messages := []SendMsg{} changes := map[ExtType]Changes{} for ext_type, ext := range(node.Extensions) { ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type) @@ -537,7 +337,7 @@ func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error { ctx.Log.Logf("changes", "Changes for %s after %+v - %+v", node.ID, reflect.TypeOf(signal), changes) if len(messages) != 0 { - send_err := ctx.Send(messages) + send_err := ctx.Send(node, messages) if send_err != nil { return send_err } @@ -596,7 +396,7 @@ func KeyID(pub ed25519.PublicKey) NodeID { } // Create a new node in memory and start it's event loop -func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size uint32, policies []Policy, extensions ...Extension) (*Node, error) { +func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size uint32, extensions ...Extension) (*Node, error) { node_type, known_type := ctx.NodeTypes[type_name] if known_type == false { return nil, fmt.Errorf("%s is not a known node type", type_name) @@ -643,17 +443,12 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size } } - policies = append(policies, DefaultPolicy) - node := &Node{ Key: key, ID: id, Type: node_type, Extensions: ext_map, - Policies: policies, - PendingACLs: map[uuid.UUID]PendingACL{}, - PendingACLSignals: map[uuid.UUID]PendingACLSignal{}, - MsgChan: make(chan *Message, buffer_size), + MsgChan: make(chan RecvMsg, buffer_size), BufferSize: buffer_size, SignalQueue: []QueuedSignal{}, } @@ -685,256 +480,17 @@ func ExtTypeSuffix(ext_type ExtType) []byte { } func WriteNodeExtList(ctx *Context, node *Node) error { - ext_list := make([]ExtType, len(node.Extensions)) - i := 0 - for ext_type := range(node.Extensions) { - ext_list[i] = ext_type - i += 1 - } - - ctx.Log.Logf("db", "Writing ext_list for %s - %+v", node.ID, ext_list) - - id_bytes, err := node.ID.MarshalBinary() - if err != nil { - return err - } - - ext_list_serialized, err := SerializeAny(ctx, ext_list) - if err != nil { - return err - } - - return 
ctx.DB.Update(func(txn *badger.Txn) error { - return txn.Set(append(id_bytes, extension_suffix...), ext_list_serialized.Data) - }) + return fmt.Errorf("TODO: write node list") } func WriteNodeInit(ctx *Context, node *Node) error { - ctx.Log.Logf("db", "Writing initial entry for %s - %+v", node.ID, node) - - ext_serialized := map[ExtType]SerializedValue{} - for ext_type, ext := range(node.Extensions) { - serialized_ext, err := SerializeAny(ctx, ext) - if err != nil { - return err - } - ext_serialized[ext_type] = serialized_ext - } - - sq_serialized, err := SerializeAny(ctx, node.SignalQueue) - if err != nil { - return err - } - - node_serialized, err := SerializeAny(ctx, node) - if err != nil { - return err - } - - id_bytes, err := node.ID.MarshalBinary() - - return ctx.DB.Update(func(txn *badger.Txn) error { - err := txn.Set(id_bytes, node_serialized.Data) - if err != nil { - return nil - } - - err = txn.Set(append(id_bytes, signal_queue_suffix...), sq_serialized.Data) - if err != nil { - return err - } - - for ext_type, data := range(ext_serialized) { - err := txn.Set(append(id_bytes, ExtTypeSuffix(ext_type)...), data.Data) - if err != nil { - return err - } - } - - return nil - }) - + return fmt.Errorf("TODO: write initial node entry") } func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error { - ctx.Log.Logf("db", "Writing changes for %s - %+v", node.ID, changes) - - ext_serialized := map[ExtType]SerializedValue{} - for ext_type := range(changes) { - ext, ext_exists := node.Extensions[ext_type] - if ext_exists == false { - ctx.Log.Logf("db", "extension 0x%x does not exist for %s", ext_type, node.ID) - } else { - serialized_ext, err := SerializeAny(ctx, ext) - if err != nil { - return err - } - ext_serialized[ext_type] = serialized_ext - } - } - - var sq_serialized *SerializedValue = nil - if node.writeSignalQueue == true { - node.writeSignalQueue = false - ser, err := SerializeAny(ctx, node.SignalQueue) - if err != nil { - return err - } - sq_serialized = &ser - } - - node_serialized, err := SerializeAny(ctx, node) - if err != nil { - return err - } - - id_bytes, err := node.ID.MarshalBinary() - return ctx.DB.Update(func(txn *badger.Txn) error { - err := txn.Set(id_bytes, node_serialized.Data) - if err != nil { - return err - } - if sq_serialized != nil { - err := txn.Set(append(id_bytes, signal_queue_suffix...), sq_serialized.Data) - if err != nil { - return err - } - } - for ext_type, data := range(ext_serialized) { - err := txn.Set(append(id_bytes, ExtTypeSuffix(ext_type)...), data.Data) - if err != nil { - return err - } - } - return nil - }) + return fmt.Errorf("TODO: write changes to node(and any signal queue changes)") } func LoadNode(ctx *Context, id NodeID) (*Node, error) { - ctx.Log.Logf("db", "LOADING_NODE: %s", id) - var node_bytes []byte = nil - var sq_bytes []byte = nil - var ext_bytes = map[ExtType][]byte{} - - err := ctx.DB.View(func(txn *badger.Txn) error { - id_bytes, err := id.MarshalBinary() - if err != nil { - return err - } - - node_item, err := txn.Get(id_bytes) - if err != nil { - ctx.Log.Logf("db", "node key not found") - return err - } - - node_bytes, err = node_item.ValueCopy(nil) - if err != nil { - return err - } - - sq_item, err := txn.Get(append(id_bytes, signal_queue_suffix...)) - if err != nil { - ctx.Log.Logf("db", "sq key not found") - return err - } - sq_bytes, err = sq_item.ValueCopy(nil) - if err != nil { - return err - } - - ext_list_item, err := txn.Get(append(id_bytes, extension_suffix...)) - if err != nil { - 
ctx.Log.Logf("db", "ext_list key not found") - return err - } - - ext_list_bytes, err := ext_list_item.ValueCopy(nil) - if err != nil { - return err - } - - ext_list_value, remaining, err := DeserializeValue(ctx, reflect.TypeOf([]ExtType{}), ext_list_bytes) - if err != nil { - return err - } else if len(remaining) > 0 { - return fmt.Errorf("Data remaining after ext_list deserialize %d", len(remaining)) - } - ext_list, ok := ext_list_value.Interface().([]ExtType) - if ok == false { - return fmt.Errorf("deserialize returned wrong type %s", ext_list_value.Type()) - } - - for _, ext_type := range(ext_list) { - ext_item, err := txn.Get(append(id_bytes, ExtTypeSuffix(ext_type)...)) - if err != nil { - ctx.Log.Logf("db", "ext %s key not found", ext_type) - return err - } - - ext_bytes[ext_type], err = ext_item.ValueCopy(nil) - if err != nil { - return err - } - } - return nil - }) - if err != nil { - return nil, err - } - - node_value, remaining, err := DeserializeValue(ctx, reflect.TypeOf((*Node)(nil)), node_bytes) - if err != nil { - return nil, err - } else if len(remaining) != 0 { - return nil, fmt.Errorf("data left after deserializing node %d", len(remaining)) - } - - node, node_ok := node_value.Interface().(*Node) - if node_ok == false { - return nil, fmt.Errorf("node wrong type %s", node_value.Type()) - } - - ctx.Log.Logf("db", "Deserialized node bytes %+v", node) - - signal_queue_value, remaining, err := DeserializeValue(ctx, reflect.TypeOf([]QueuedSignal{}), sq_bytes) - if err != nil { - return nil, err - } else if len(remaining) != 0 { - return nil, fmt.Errorf("data left after deserializing signal_queue %d", len(remaining)) - } - - signal_queue, sq_ok := signal_queue_value.Interface().([]QueuedSignal) - if sq_ok == false { - return nil, fmt.Errorf("signal queue wrong type %s", signal_queue_value.Type()) - } - - for ext_type, data := range(ext_bytes) { - ext_info, exists := ctx.Extensions[ext_type] - if exists == false { - return nil, fmt.Errorf("0x%0x is not a known extension type", ext_type) - } - - ext_value, remaining, err := DeserializeValue(ctx, ext_info.Reflect, data) - if err != nil { - return nil, err - } else if len(remaining) > 0 { - return nil, fmt.Errorf("data left after deserializing ext(0x%x) %d", ext_type, len(remaining)) - } - ext, ext_ok := ext_value.Interface().(Extension) - if ext_ok == false { - return nil, fmt.Errorf("extension wrong type %s", ext_value.Type()) - } - - node.Extensions[ext_type] = ext - } - - node.SignalQueue = signal_queue - node.NextSignal, node.TimeoutChan = SoonestSignal(signal_queue) - - ctx.AddNode(id, node) - ctx.Log.Logf("db", "loaded %+v", node) - go runNode(ctx, node) - - return node, nil + return nil, fmt.Errorf("TODO: load node + extensions from DB") } diff --git a/node_test.go b/node_test.go index f78f67b..da4cd89 100644 --- a/node_test.go +++ b/node_test.go @@ -12,7 +12,7 @@ func TestNodeDB(t *testing.T) { ctx := logTestContext(t, []string{"node", "db"}) node_listener := NewListenerExt(10) - node, err := NewNode(ctx, nil, "Base", 10, nil, NewGroupExt(nil), NewLockableExt(nil), node_listener) + node, err := NewNode(ctx, nil, "Base", 10, nil, NewLockableExt(nil), node_listener) fatalErr(t, err) _, err = WaitForSignal(node_listener.Chan, 10*time.Millisecond, func(sig *StatusSignal) bool { @@ -45,25 +45,18 @@ func TestNodeRead(t *testing.T) { ctx.Log.Logf("test", "N1: %s", n1_id) ctx.Log.Logf("test", "N2: %s", n2_id) - n1_policy := NewPerNodePolicy(map[NodeID]Tree{ - n2_id: { - SerializedType(SignalTypeFor[ReadSignal]()): nil, - }, - }) - 
n2_listener := NewListenerExt(10) - n2, err := NewNode(ctx, n2_key, "Base", 10, nil, NewGroupExt(nil), n2_listener) + n2, err := NewNode(ctx, n2_key, "Base", 10, n2_listener) fatalErr(t, err) - n1, err := NewNode(ctx, n1_key, "Base", 10, []Policy{n1_policy}, NewGroupExt(nil)) + n1, err := NewNode(ctx, n1_key, "Base", 10, NewListenerExt(10)) fatalErr(t, err) read_sig := NewReadSignal(map[ExtType][]string{ - ExtTypeFor[GroupExt](): {"members"}, + ExtTypeFor[ListenerExt](): {"buffer"}, }) - msgs := Messages{} - msgs = msgs.Add(ctx, n1.ID, n2, nil, read_sig) - err = ctx.Send(msgs) + msgs := []SendMsg{{n1.ID, read_sig}} + err = ctx.Send(n2, msgs) fatalErr(t, err) res, err := WaitForSignal(n2_listener.Chan, 10*time.Millisecond, func(sig *ReadResultSignal) bool { diff --git a/policy.go b/policy.go deleted file mode 100644 index 2ac65b3..0000000 --- a/policy.go +++ /dev/null @@ -1,139 +0,0 @@ -package graphvent - -import ( - "github.com/google/uuid" -) - -type Policy interface { - Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) - ContinueAllows(ctx *Context, current PendingACL, signal Signal)RuleResult - ID() uuid.UUID -} - -type PolicyHeader struct { - UUID uuid.UUID `gv:"uuid"` -} - -func (header PolicyHeader) ID() uuid.UUID { - return header.UUID -} - -func (policy AllNodesPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { - return nil, policy.Rules.Allows(action) -} - -func (policy AllNodesPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { - return Deny -} - -func (policy PerNodePolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { - for id, actions := range(policy.NodeRules) { - if id != principal_id { - continue - } - return nil, actions.Allows(action) - } - return nil, Deny -} - -func (policy PerNodePolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { - return Deny -} - -func CopyTree(tree Tree) Tree { - if tree == nil { - return nil - } - - ret := Tree{} - for name, sub := range(tree) { - ret[name] = CopyTree(sub) - } - - return ret -} - -func MergeTrees(first Tree, second Tree) Tree { - if first == nil || second == nil { - return nil - } - - ret := CopyTree(first) - for name, sub := range(second) { - current, exists := ret[name] - if exists == true { - ret[name] = MergeTrees(current, sub) - } else { - ret[name] = CopyTree(sub) - } - } - return ret -} - -type Tree map[SerializedType]Tree - -func (rule Tree) Allows(action Tree) RuleResult { - // If the current rule is nil, it's a wildcard and any action being processed is allowed - if rule == nil { - return Allow - // If the rule isn't "allow all" but the action is "request all", deny - } else if action == nil { - return Deny - // If the current action has no children, it's allowed - } else if len(action) == 0 { - return Allow - // If the current rule has no children but the action goes further, it's not allowed - } else if len(rule) == 0 { - return Deny - // If the current rule and action have children, all the children of action must be allowed by rule - } else { - for sub, subtree := range(action) { - subrule, exists := rule[sub] - if exists == false { - return Deny - } else if subrule.Allows(subtree) == Deny { - return Deny - } - } - return Allow - } -} - -func NewPolicyHeader() PolicyHeader { - return PolicyHeader{ - UUID: uuid.New(), - } -} - -func NewPerNodePolicy(node_actions map[NodeID]Tree) PerNodePolicy { - if node_actions == nil { - 
node_actions = map[NodeID]Tree{} - } - - return PerNodePolicy{ - PolicyHeader: NewPolicyHeader(), - NodeRules: node_actions, - } -} - -type PerNodePolicy struct { - PolicyHeader - NodeRules map[NodeID]Tree `gv:"node_rules"` -} - -func NewAllNodesPolicy(rules Tree) AllNodesPolicy { - return AllNodesPolicy{ - PolicyHeader: NewPolicyHeader(), - Rules: rules, - } -} - -type AllNodesPolicy struct { - PolicyHeader - Rules Tree `gv:"rules"` -} - -var DefaultPolicy = NewAllNodesPolicy(Tree{ - SerializedType(SignalTypeFor[ResponseSignal]()): nil, - SerializedType(SignalTypeFor[StatusSignal]()): nil, -}) diff --git a/serialize.go b/serialize.go index 1b56187..1f7fa3b 100644 --- a/serialize.go +++ b/serialize.go @@ -1,16 +1,11 @@ package graphvent import ( - "bytes" "crypto/sha512" - "encoding" "encoding/binary" - "encoding/gob" "fmt" - "math" "reflect" "slices" - "sort" ) type SerializedType uint64 @@ -37,134 +32,12 @@ func (t SignalType) String() string { return fmt.Sprintf("0x%x", uint64(t)) } -type PolicyType SerializedType - -func (t PolicyType) String() string { - return fmt.Sprintf("0x%x", uint64(t)) -} - type FieldTag SerializedType func (t FieldTag) String() string { return fmt.Sprintf("0x%x", uint64(t)) } -type Chunk struct { - Data []byte - Next *Chunk -} - -type Chunks struct { - First *Chunk - Last *Chunk -} - -func (chunks Chunks) String() string { - cur := chunks.First - str := fmt.Sprintf("Chunks(") - for cur != nil { - str = fmt.Sprintf("%s%+v, ", str, cur) - cur = cur.Next - } - - return fmt.Sprintf("%s)", str) -} - -func NewChunks(datas ...[]byte) Chunks { - var first *Chunk = nil - var last *Chunk = nil - - if len(datas) >= 1 { - first = &Chunk{ - Data: datas[0], - Next: nil, - } - last = first - - for _, data := range(datas[1:]) { - last.Next = &Chunk{ - Data: data, - Next: nil, - } - last = last.Next - } - } - - if (first == nil || last == nil) && (first != last) { - panic(fmt.Sprintf("Attempted to construct invalid Chunks with NewChunks %+v - %+v", first, last)) - } - return Chunks{ - First: first, - Last: last, - } -} - -func (chunks Chunks) AddDataToEnd(datas ...[]byte) Chunks { - if chunks.First == nil && chunks.Last == nil { - return NewChunks(datas...) 
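A note on the policy removal above: the deleted Tree.Allows treats a nil subtree as a wildcard and an empty subtree as a leaf, so a rule only matches actions whose children it explicitly covers. A small worked example of the matching behaviour this patch removes (the rule and action values are illustrative only):

    // Rule: allow anything under ReadSignal, nothing else.
    rule := Tree{
        SerializedType(SignalTypeFor[ReadSignal]()): nil, // nil = wildcard below this key
    }

    // Action: a ReadSignal requesting one tagged field.
    action := Tree{
        SerializedType(SignalTypeFor[ReadSignal]()): {
            SerializedType(GetFieldTag("buffer")): Tree{}, // leaf
        },
    }

    _ = rule.Allows(action) // Allow: the wildcard matches any children of ReadSignal
    _ = rule.Allows(Tree{
        SerializedType(SignalTypeFor[LockSignal]()): nil,
    }) // Deny: LockSignal is not covered by the rule
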
- } else if chunks.First == nil || chunks.Last == nil { - panic(fmt.Sprintf("Invalid chunks %+v", chunks)) - } - - for _, data := range(datas) { - chunks.Last.Next = &Chunk{ - Data: data, - Next: nil, - } - chunks.Last = chunks.Last.Next - } - - return chunks -} - -func (chunks Chunks) AddChunksToEnd(new_chunks Chunks) Chunks { - if chunks.Last == nil && chunks.First == nil { - return new_chunks - } else if chunks.Last == nil || chunks.First == nil { - panic(fmt.Sprintf("Invalid chunks %+v", chunks)) - } else if new_chunks.Last == nil && new_chunks.First == nil { - return chunks - } else if new_chunks.Last == nil || new_chunks.First == nil { - panic(fmt.Sprintf("Invalid new_chunks %+v", new_chunks)) - } else { - chunks.Last.Next = new_chunks.First - chunks.Last = new_chunks.Last - return chunks - } -} - -func (chunks Chunks) GetSerializedSize() int { - total_size := 0 - cur := chunks.First - - for cur != nil { - total_size += len(cur.Data) - cur = cur.Next - } - return total_size -} - -func (chunks Chunks) Slice() []byte { - total_size := chunks.GetSerializedSize() - data := make([]byte, total_size) - data_ptr := 0 - - cur := chunks.First - for cur != nil { - copy(data[data_ptr:], cur.Data) - data_ptr += len(cur.Data) - cur = cur.Next - } - - return data -} - -type TypeSerializeFn func(*Context, reflect.Type) ([]SerializedType, error) -type SerializeFn func(*Context, reflect.Value) (Chunks, error) -type TypeDeserializeFn func(*Context, []SerializedType) (reflect.Type, []SerializedType, error) -type DeserializeFn func(*Context, reflect.Type, []byte) (reflect.Value, []byte, error) - - func NodeTypeFor(name string, extensions []ExtType, mappings map[string]FieldIndex) NodeType { digest := []byte("GRAPHVENT_NODE[" + name + "] - ") for _, ext := range(extensions) { @@ -193,15 +66,9 @@ func NodeTypeFor(name string, extensions []ExtType, mappings map[string]FieldInd return NodeType(binary.BigEndian.Uint64(hash[0:8])) } -func SerializedKindFor(kind reflect.Kind) SerializedType { - digest := []byte("GRAPHVENT_KIND - " + kind.String()) - hash := sha512.Sum512(digest) - return SerializedType(binary.BigEndian.Uint64(hash[0:8])) -} - func SerializedTypeFor[T any]() SerializedType { t := reflect.TypeFor[T]() - digest := []byte(t.PkgPath() + ":" + t.Name()) + digest := []byte(t.String()) hash := sha512.Sum512(digest) return SerializedType(binary.BigEndian.Uint64(hash[0:8])) } @@ -214,10 +81,6 @@ func SignalTypeFor[S Signal]() SignalType { return SignalType(SerializedTypeFor[S]()) } -func PolicyTypeFor[P Policy]() PolicyType { - return PolicyType(SerializedTypeFor[P]()) -} - func Hash(base, data string) SerializedType { digest := []byte(base + ":" + data) hash := sha512.Sum512(digest) @@ -227,957 +90,3 @@ func Hash(base, data string) SerializedType { func GetFieldTag(tag string) FieldTag { return FieldTag(Hash("GRAPHVENT_FIELD_TAG", tag)) } - -type FieldInfo struct { - Index []int - TypeStack []SerializedType - Type reflect.Type -} - -type StructInfo struct { - Type reflect.Type - FieldOrder []FieldTag - FieldMap map[FieldTag]FieldInfo - PostDeserialize bool - PostDeserializeIdx int -} - -type Deserializable interface { - PostDeserialize(*Context) error -} - -var deserializable_zero Deserializable = nil -var DeserializableType = reflect.TypeOf(&deserializable_zero).Elem() - -func GetStructInfo(ctx *Context, struct_type reflect.Type) (StructInfo, error) { - field_order := []FieldTag{} - field_map := map[FieldTag]FieldInfo{} - for _, field := range reflect.VisibleFields(struct_type) { - gv_tag, tagged_gv := 
field.Tag.Lookup("gv") - if tagged_gv == false { - continue - } else { - field_tag := GetFieldTag(gv_tag) - _, exists := field_map[field_tag] - if exists == true { - return StructInfo{}, fmt.Errorf("gv tag %s is repeated", gv_tag) - } else { - field_type_stack, err := SerializeType(ctx, field.Type) - if err != nil { - return StructInfo{}, err - } - field_map[field_tag] = FieldInfo{ - field.Index, - field_type_stack, - field.Type, - } - field_order = append(field_order, field_tag) - } - } - } - - sort.Slice(field_order, func(i, j int) bool { - return uint64(field_order[i]) < uint64(field_order[j]) - }) - - post_deserialize := false - post_deserialize_idx := 0 - ptr_type := reflect.PointerTo(struct_type) - if ptr_type.Implements(DeserializableType) { - post_deserialize = true - for i := 0; i < ptr_type.NumMethod(); i += 1 { - method := ptr_type.Method(i) - if method.Name == "PostDeserialize" { - post_deserialize_idx = i - break - } - } - } - - return StructInfo{ - struct_type, - field_order, - field_map, - post_deserialize, - post_deserialize_idx, - }, nil -} - -func SerializeStruct(info StructInfo)func(*Context, reflect.Value)(Chunks, error) { - return func(ctx *Context, value reflect.Value) (Chunks, error) { - struct_chunks := Chunks{} - for _, field_hash := range(info.FieldOrder) { - field_hash_bytes := make([]byte, 8) - binary.BigEndian.PutUint64(field_hash_bytes, uint64(field_hash)) - - field_info := info.FieldMap[field_hash] - field_value := value.FieldByIndex(field_info.Index) - - field_chunks, err := SerializeValue(ctx, field_value) - if err != nil { - return Chunks{}, err - } - - struct_chunks = struct_chunks.AddDataToEnd(field_hash_bytes).AddChunksToEnd(field_chunks) - ctx.Log.Logf("serialize", "STRUCT_FIELD_CHUNKS: %+v", field_chunks) - } - size_data := make([]byte, 8) - binary.BigEndian.PutUint64(size_data, uint64(len(info.FieldOrder))) - return NewChunks(size_data).AddChunksToEnd(struct_chunks), nil - } -} - -func DeserializeStruct(info StructInfo)func(*Context, reflect.Type, []byte)(reflect.Value, []byte, error) { - return func(ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - if len(data) < 8 { - return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize struct %d/8", len(data)) - } - - num_field_bytes := data[:8] - data = data[8:] - - num_fields := binary.BigEndian.Uint64(num_field_bytes) - - struct_value := reflect.New(reflect_type).Elem() - for i := uint64(0); i < num_fields; i ++ { - field_hash_bytes := data[:8] - data = data[8:] - field_tag := FieldTag(binary.BigEndian.Uint64(field_hash_bytes)) - field_info, exists := info.FieldMap[field_tag] - if exists == false { - return reflect.Value{}, nil, fmt.Errorf("%+v is not a field in %+v", field_tag, info.Type) - } - - var field_value reflect.Value - var err error - field_value, data, err = DeserializeValue(ctx, field_info.Type, data) - if err != nil { - return reflect.Value{}, nil, err - } - - field_reflect := struct_value.FieldByIndex(field_info.Index) - field_reflect.Set(field_value) - } - - if info.PostDeserialize == true { - post_deserialize_method := struct_value.Addr().Method(info.PostDeserializeIdx) - results := post_deserialize_method.Call([]reflect.Value{reflect.ValueOf(ctx)}) - err_if := results[0].Interface() - if err_if != nil { - return reflect.Value{}, nil, err_if.(error) - } - } - - return struct_value, data, nil - } -} - -func SerializeGob(ctx *Context, value reflect.Value) (Chunks, error) { - data := make([]byte, 8) - gob_ser, err := 
value.Interface().(gob.GobEncoder).GobEncode() - if err != nil { - return Chunks{}, err - } - - binary.BigEndian.PutUint64(data, uint64(len(gob_ser))) - return NewChunks(data, gob_ser), nil -} - -func DeserializeGob[T any, PT interface{gob.GobDecoder; *T}](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - if len(data) < 8 { - return reflect.Value{}, nil, fmt.Errorf("Not enough bytes to deserialize gob %d/8", len(data)) - } - - size_bytes := data[:8] - size := binary.BigEndian.Uint64(size_bytes) - gob_data := data[8:8+size] - data = data[8+size:] - - gob_ptr := reflect.New(reflect_type) - err := gob_ptr.Interface().(gob.GobDecoder).GobDecode(gob_data) - if err != nil { - return reflect.Value{}, nil, err - } - - return gob_ptr.Elem(), data, nil -} - -func SerializeInt8(ctx *Context, value reflect.Value) (Chunks, error) { - data := []byte{byte(value.Int())} - - return NewChunks(data), nil -} - -func SerializeInt16(ctx *Context, value reflect.Value) (Chunks, error) { - data := make([]byte, 2) - binary.BigEndian.PutUint16(data, uint16(value.Int())) - - return NewChunks(data), nil -} - -func SerializeInt32(ctx *Context, value reflect.Value) (Chunks, error) { - data := make([]byte, 4) - binary.BigEndian.PutUint32(data, uint32(value.Int())) - - return NewChunks(data), nil -} - -func SerializeInt64(ctx *Context, value reflect.Value) (Chunks, error) { - data := make([]byte, 8) - binary.BigEndian.PutUint64(data, uint64(value.Int())) - - return NewChunks(data), nil -} - -func SerializeUint8(ctx *Context, value reflect.Value) (Chunks, error) { - data := []byte{byte(value.Uint())} - - return NewChunks(data), nil -} - -func SerializeUint16(ctx *Context, value reflect.Value) (Chunks, error) { - data := make([]byte, 2) - binary.BigEndian.PutUint16(data, uint16(value.Uint())) - - return NewChunks(data), nil -} - -func SerializeUint32(ctx *Context, value reflect.Value) (Chunks, error) { - data := make([]byte, 4) - binary.BigEndian.PutUint32(data, uint32(value.Uint())) - - return NewChunks(data), nil -} - -func SerializeUint64(ctx *Context, value reflect.Value) (Chunks, error) { - data := make([]byte, 8) - binary.BigEndian.PutUint64(data, value.Uint()) - - return NewChunks(data), nil -} - -func DeserializeUint64[T ~uint64 | ~int64](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - uint_size := 8 - if len(data) < uint_size { - return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize uint %d/%d", len(data), uint_size) - } - - uint_bytes := data[:uint_size] - data = data[uint_size:] - uint_value := reflect.New(reflect_type).Elem() - - typed_value := T(binary.BigEndian.Uint64(uint_bytes)) - uint_value.Set(reflect.ValueOf(typed_value)) - - return uint_value, data, nil -} - -func DeserializeUint32[T ~uint32 | ~uint | ~int32 | ~int](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - uint_size := 4 - if len(data) < uint_size { - return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize uint %d/%d", len(data), uint_size) - } - - uint_bytes := data[:uint_size] - data = data[uint_size:] - uint_value := reflect.New(reflect_type).Elem() - - typed_value := T(binary.BigEndian.Uint32(uint_bytes)) - uint_value.Set(reflect.ValueOf(typed_value)) - - return uint_value, data, nil -} - -func DeserializeUint16[T ~uint16 | ~int16](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - uint_size := 2 - if len(data) < uint_size { - return reflect.Value{}, 
nil, fmt.Errorf("Not enough data to deserialize uint %d/%d", len(data), uint_size) - } - - uint_bytes := data[:uint_size] - data = data[uint_size:] - uint_value := reflect.New(reflect_type).Elem() - - typed_value := T(binary.BigEndian.Uint16(uint_bytes)) - uint_value.Set(reflect.ValueOf(typed_value)) - - return uint_value, data, nil -} - -func DeserializeUint8[T ~uint8 | ~int8](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - uint_size := 1 - if len(data) < uint_size { - return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize uint %d/%d", len(data), uint_size) - } - - uint_bytes := data[:uint_size] - data = data[uint_size:] - uint_value := reflect.New(reflect_type).Elem() - - typed_value := T(uint_bytes[0]) - uint_value.Set(reflect.ValueOf(typed_value)) - - return uint_value, data, nil -} - -func SerializeFloat64(ctx *Context, value reflect.Value) (Chunks, error) { - data := make([]byte, 8) - float_representation := math.Float64bits(value.Float()) - binary.BigEndian.PutUint64(data, float_representation) - return NewChunks(data), nil -} - -func DeserializeFloat64[T ~float64](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - if len(data) < 8 { - return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize float64 %d/8", len(data)) - } - - float_bytes := data[0:8] - data = data[8:] - - float_representation := binary.BigEndian.Uint64(float_bytes) - float := math.Float64frombits(float_representation) - - float_value := reflect.New(reflect_type).Elem() - float_value.Set(reflect.ValueOf(T(float))) - - return float_value, data, nil -} - -func SerializeFloat32(ctx *Context, value reflect.Value) (Chunks, error) { - data := make([]byte, 4) - float_representation := math.Float32bits(float32(value.Float())) - binary.BigEndian.PutUint32(data, float_representation) - return NewChunks(data), nil -} - -func DeserializeFloat32[T ~float32](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - if len(data) < 4 { - return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize float32 %d/4", len(data)) - } - - float_bytes := data[0:4] - data = data[4:] - - float_representation := binary.BigEndian.Uint32(float_bytes) - float := math.Float32frombits(float_representation) - - float_value := reflect.New(reflect_type).Elem() - float_value.Set(reflect.ValueOf(T(float))) - - return float_value, data, nil -} - -func SerializeString(ctx *Context, value reflect.Value) (Chunks, error) { - data := make([]byte, 8) - binary.BigEndian.PutUint64(data, uint64(value.Len())) - - return NewChunks(data, []byte(value.String())), nil -} - -func DeserializeString[T ~string](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - if len(data) < 8 { - return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize string %d/8", len(data)) - } - - size_bytes := data[0:8] - data = data[8:] - - size := binary.BigEndian.Uint64(size_bytes) - if len(data) < int(size) { - return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize string of len %d, %d/%d", size, len(data), size) - } - - string_value := reflect.New(reflect_type).Elem() - string_value.Set(reflect.ValueOf(T(string(data[:size])))) - data = data[size:] - - return string_value, data, nil -} - -func SerializeBool(ctx *Context, value reflect.Value) (Chunks, error) { - if value.Bool() == true { - return NewChunks([]byte{0xFF}), nil - } else { - return NewChunks([]byte{0x00}), nil - } 
-} - -func DeserializeBool[T ~bool](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - if len(data) < 1 { - return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize bool %d/1", len(data)) - } - byte := data[0] - data = data[1:] - - bool_value := reflect.New(reflect_type).Elem() - if byte == 0x00 { - bool_value.Set(reflect.ValueOf(T(false))) - } else { - bool_value.Set(reflect.ValueOf(T(true))) - } - - return bool_value, data, nil -} - -func DeserializeTypePointer(ctx *Context, type_stack []SerializedType) (reflect.Type, []SerializedType, error) { - elem_type, remaining, err := DeserializeType(ctx, type_stack) - if err != nil { - return nil, nil, err - } - - return reflect.PointerTo(elem_type), remaining, nil -} - -func SerializePointer(ctx *Context, value reflect.Value) (Chunks, error) { - if value.IsZero() { - return NewChunks([]byte{0x00}), nil - } else { - flags := NewChunks([]byte{0x01}) - - elem_chunks, err := SerializeValue(ctx, value.Elem()) - if err != nil { - return Chunks{}, err - } - - return flags.AddChunksToEnd(elem_chunks), nil - } -} - -func DeserializePointer(ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - if len(data) < 1 { - return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize pointer %d/1", len(data)) - } - - flags := data[0] - data = data[1:] - - pointer_value := reflect.New(reflect_type).Elem() - - if flags != 0x00 { - var element_value reflect.Value - var err error - element_value, data, err = DeserializeValue(ctx, reflect_type.Elem(), data) - if err != nil { - return reflect.Value{}, nil, err - } - - pointer_value.Set(element_value.Addr()) - } - - return pointer_value, data, nil -} - -func SerializeTypeStub(ctx *Context, reflect_type reflect.Type) ([]SerializedType, error) { - return nil, nil -} - -func DeserializeTypeStub[T any](ctx *Context, type_stack []SerializedType) (reflect.Type, []SerializedType, error) { - var zero T - return reflect.TypeOf(zero), type_stack, nil -} - -func SerializeTypeElem(ctx *Context, reflect_type reflect.Type) ([]SerializedType, error) { - return SerializeType(ctx, reflect_type.Elem()) -} - -func SerializeSlice(ctx *Context, value reflect.Value) (Chunks, error) { - if value.IsZero() { - return NewChunks([]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}), nil - } else if value.Len() == 0 { - return NewChunks([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), nil - } else { - slice_chunks := Chunks{} - for i := 0; i < value.Len(); i += 1 { - val := value.Index(i) - element_chunks, err := SerializeValue(ctx, val) - if err != nil { - return Chunks{}, err - } - slice_chunks = slice_chunks.AddChunksToEnd(element_chunks) - } - - size_data := make([]byte, 8) - binary.BigEndian.PutUint64(size_data, uint64(value.Len())) - - return NewChunks(size_data).AddChunksToEnd(slice_chunks), nil - } -} - -func DeserializeTypeSlice(ctx *Context, type_stack []SerializedType) (reflect.Type, []SerializedType, error) { - elem_type, remaining, err := DeserializeType(ctx, type_stack) - if err != nil { - return nil, nil, err - } - - reflect_type := reflect.SliceOf(elem_type) - return reflect_type, remaining, nil -} - -func DeserializeSlice(ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - if len(data) < 8 { - return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize slice %d/8", len(data)) - } - - slice_size := binary.BigEndian.Uint64(data[0:8]) - slice_value := 
reflect.New(reflect_type).Elem() - data = data[8:] - - if slice_size != 0xFFFFFFFFFFFFFFFF { - slice_unaddr := reflect.MakeSlice(reflect_type, int(slice_size), int(slice_size)) - slice_value.Set(slice_unaddr) - for i := uint64(0); i < slice_size; i += 1 { - var element_value reflect.Value - var err error - element_value, data, err = DeserializeValue(ctx, reflect_type.Elem(), data) - if err != nil { - return reflect.Value{}, nil, err - } - - slice_elem := slice_value.Index(int(i)) - slice_elem.Set(element_value) - } - } - - return slice_value, data, nil -} - -func SerializeTypeMap(ctx *Context, reflect_type reflect.Type) ([]SerializedType, error) { - key_stack, err := SerializeType(ctx, reflect_type.Key()) - if err != nil { - return nil, err - } - - elem_stack, err := SerializeType(ctx, reflect_type.Elem()) - if err != nil { - return nil, err - } - - return append(key_stack, elem_stack...), nil -} - -func DeserializeTypeMap(ctx *Context, type_stack []SerializedType) (reflect.Type, []SerializedType, error) { - key_type, after_key, err := DeserializeType(ctx, type_stack) - if err != nil { - return nil, nil, err - } - - elem_type, after_elem, err := DeserializeType(ctx, after_key) - if err != nil { - return nil, nil, err - } - - map_type := reflect.MapOf(key_type, elem_type) - return map_type, after_elem, nil -} - -func SerializeMap(ctx *Context, value reflect.Value) (Chunks, error) { - if value.IsZero() == true { - return NewChunks([]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}), nil - } - - map_chunks := []Chunks{} - map_size := uint64(0) - map_iter := value.MapRange() - for map_iter.Next() { - map_size = map_size + 1 - key := map_iter.Key() - val := map_iter.Value() - - key_chunks, err := SerializeValue(ctx, key) - if err != nil { - return Chunks{}, err - } - - val_chunks, err := SerializeValue(ctx, val) - if err != nil { - return Chunks{}, err - } - - chunks := key_chunks.AddChunksToEnd(val_chunks) - map_chunks = append(map_chunks, chunks) - } - - // Sort map_chunks - sort.Slice(map_chunks, func(i, j int) bool { - return bytes.Compare(map_chunks[i].First.Data, map_chunks[j].First.Data) < 0 - }) - chunks := Chunks{} - for _, chunk := range(map_chunks) { - chunks = chunks.AddChunksToEnd(chunk) - } - - - size_data := make([]byte, 8) - binary.BigEndian.PutUint64(size_data, map_size) - - return NewChunks(size_data).AddChunksToEnd(chunks), nil -} - -func DeserializeMap(ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - if len(data) < 8 { - return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize map %d/8", len(data)) - } - - size_bytes := data[:8] - data = data[8:] - - size := binary.BigEndian.Uint64(size_bytes) - - map_value := reflect.New(reflect_type).Elem() - if size == 0xFFFFFFFFFFFFFFFF { - return map_value, data, nil - } - - map_unaddr := reflect.MakeMapWithSize(reflect_type, int(size)) - map_value.Set(map_unaddr) - - for i := uint64(0); i < size; i++ { - var err error - var key_value reflect.Value - key_value, data, err = DeserializeValue(ctx, reflect_type.Key(), data) - if err != nil { - return reflect.Value{}, nil, err - } - - var val_value reflect.Value - val_value, data, err = DeserializeValue(ctx, reflect_type.Elem(), data) - if err != nil { - return reflect.Value{}, nil, err - } - - map_value.SetMapIndex(key_value, val_value) - } - - return map_value, data, nil -} - -func SerializeTypeArray(ctx *Context, reflect_type reflect.Type) ([]SerializedType, error) { - size := SerializedType(reflect_type.Len()) - elem_stack, err 
:= SerializeType(ctx, reflect_type.Elem()) - if err != nil { - return nil, err - } - - return append([]SerializedType{size}, elem_stack...), nil -} - -func SerializeUUID(ctx *Context, value reflect.Value) (Chunks, error) { - uuid_ser, err := value.Interface().(encoding.BinaryMarshaler).MarshalBinary() - if err != nil { - return Chunks{}, err - } - - if len(uuid_ser) != 16 { - return Chunks{}, fmt.Errorf("Wrong length of uuid: %d/16", len(uuid_ser)) - } - - return NewChunks(uuid_ser), nil -} - -func DeserializeUUID[T ~[16]byte](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - if len(data) < 16 { - return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize UUID %d/16", len(data)) - } - - uuid_bytes := data[:16] - data = data[16:] - - uuid_value := reflect.New(reflect_type).Elem() - uuid_value.Set(reflect.ValueOf(T(uuid_bytes))) - - return uuid_value, data, nil -} - -func SerializeArray(ctx *Context, value reflect.Value) (Chunks, error) { - data := Chunks{} - for i := 0; i < value.Len(); i += 1 { - element := value.Index(i) - element_chunks, err := SerializeValue(ctx, element) - if err != nil { - return Chunks{}, err - } - data = data.AddChunksToEnd(element_chunks) - } - - return data, nil -} - -func DeserializeTypeArray(ctx *Context, type_stack []SerializedType) (reflect.Type, []SerializedType, error) { - if len(type_stack) < 1 { - return nil, nil, fmt.Errorf("Not enough valued in type stack to deserialize array") - } - - size := int(type_stack[0]) - element_type, remaining, err := DeserializeType(ctx, type_stack[1:]) - if err != nil { - return nil, nil, err - } - - array_type := reflect.ArrayOf(size, element_type) - return array_type, remaining, nil -} - -func DeserializeArray(ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - array_value := reflect.New(reflect_type).Elem() - for i := 0; i < array_value.Len(); i += 1 { - var element_value reflect.Value - var err error - element_value, data, err = DeserializeValue(ctx, reflect_type.Elem(), data) - if err != nil { - return reflect.Value{}, nil, err - } - - element := array_value.Index(i) - element.Set(element_value) - } - - return array_value, data, nil -} - -func SerializeInterface(ctx *Context, value reflect.Value) (Chunks, error) { - if value.IsZero() == true { - return NewChunks([]byte{0xFF}), nil - } - - type_stack, err := SerializeType(ctx, value.Elem().Type()) - if err != nil { - return Chunks{}, err - } - - elem_chunks, err := SerializeValue(ctx, value.Elem()) - if err != nil { - return Chunks{}, err - } - - data := elem_chunks.Slice() - - serialized_chunks, err := SerializedValue{type_stack, data}.Chunks() - if err != nil { - return Chunks{}, err - } - - return NewChunks([]byte{0x00}).AddChunksToEnd(serialized_chunks), nil -} - -func DeserializeInterface(ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - if len(data) < 1 { - return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize interface %d/1", len(data)) - } - - flags := data[0] - data = data[1:] - if flags == 0xFF { - return reflect.New(reflect_type).Elem(), data, nil - } - - serialized_value, remaining, err := ParseSerializedValue(data) - elem_type, types_remaining, err := DeserializeType(ctx, serialized_value.TypeStack) - if err != nil { - return reflect.Value{}, nil, err - } else if len(types_remaining) > 0 { - return reflect.Value{}, nil, fmt.Errorf("Types remaining in interface stack after deserializing") - } - - elem_value, 
data_remaining, err := DeserializeValue(ctx, elem_type, serialized_value.Data) - if err != nil { - return reflect.Value{}, nil, err - } else if len(data_remaining) > 0 { - return reflect.Value{}, nil, fmt.Errorf("Data remaining in interface data after deserializing") - } - - interface_value := reflect.New(reflect_type).Elem() - interface_value.Set(elem_value) - - return interface_value, remaining, nil -} - -type SerializedValue struct { - TypeStack []SerializedType - Data []byte -} - -func SerializeAny[T any](ctx *Context, value T) (SerializedValue, error) { - reflect_value := reflect.ValueOf(value) - type_stack, err := SerializeType(ctx, reflect_value.Type()) - if err != nil { - return SerializedValue{}, err - } - data, err := SerializeValue(ctx, reflect_value) - if err != nil { - return SerializedValue{}, err - } - - return SerializedValue{type_stack, data.Slice()}, nil -} - -func SerializeType(ctx *Context, reflect_type reflect.Type) ([]SerializedType, error) { - ctx.Log.Logf("serialize", "Serializing type %+v", reflect_type) - - type_info, type_exists := ctx.TypeReflects[reflect_type] - var serialize_type TypeSerializeFn = nil - var ctx_type SerializedType - if type_exists == true { - serialize_type = type_info.TypeSerialize - ctx_type = type_info.Type - } - - if serialize_type == nil { - kind_info, handled := ctx.Kinds[reflect_type.Kind()] - if handled == true { - if type_exists == false { - ctx_type = kind_info.Type - } - serialize_type = kind_info.TypeSerialize - } - } - - type_stack := []SerializedType{ctx_type} - if serialize_type != nil { - extra_types, err := serialize_type(ctx, reflect_type) - if err != nil { - return nil, err - } - return append(type_stack, extra_types...), nil - } else { - return type_stack, nil - } -} - -func SerializeValue(ctx *Context, value reflect.Value) (Chunks, error) { - type_info, type_exists := ctx.TypeReflects[value.Type()] - var serialize SerializeFn = nil - if type_exists == true { - if type_info.Serialize != nil { - serialize = type_info.Serialize - } - } - - if serialize == nil { - kind_info, handled := ctx.Kinds[value.Kind()] - if handled { - serialize = kind_info.Serialize - } else { - return Chunks{}, fmt.Errorf("Don't know how to serialize %+v", value.Type()) - } - } - - return serialize(ctx, value) -} - -func ExtField(ctx *Context, ext Extension, field_name string) (reflect.Value, error) { - if ext == nil { - return reflect.Value{}, fmt.Errorf("Cannot get fields on nil Extension") - } - - ext_value := reflect.ValueOf(ext).Elem() - for _, field := range reflect.VisibleFields(ext_value.Type()) { - gv_tag, tagged := field.Tag.Lookup("gv") - if tagged == true && gv_tag == field_name { - return ext_value.FieldByIndex(field.Index), nil - } - } - - return reflect.Value{}, fmt.Errorf("%s is not a field in %+v", field_name, reflect.TypeOf(ext)) -} - -func SerializeField(ctx *Context, ext Extension, field_name string) (SerializedValue, error) { - field_value, err := ExtField(ctx, ext, field_name) - if err != nil { - return SerializedValue{}, err - } - type_stack, err := SerializeType(ctx, field_value.Type()) - if err != nil { - return SerializedValue{}, err - } - data, err := SerializeValue(ctx, field_value) - if err != nil { - return SerializedValue{}, err - } - return SerializedValue{type_stack, data.Slice()}, nil -} - -func (value SerializedValue) Chunks() (Chunks, error) { - header_data := make([]byte, 16) - binary.BigEndian.PutUint64(header_data[0:8], uint64(len(value.TypeStack))) - binary.BigEndian.PutUint64(header_data[8:16], 
uint64(len(value.Data))) - - type_stack_bytes := make([][]byte, len(value.TypeStack)) - for i, ctx_type := range(value.TypeStack) { - type_stack_bytes[i] = make([]byte, 8) - binary.BigEndian.PutUint64(type_stack_bytes[i], uint64(ctx_type)) - } - - return NewChunks(header_data).AddDataToEnd(type_stack_bytes...).AddDataToEnd(value.Data), nil -} - -func ParseSerializedValue(data []byte) (SerializedValue, []byte, error) { - if len(data) < 8 { - return SerializedValue{}, nil, fmt.Errorf("SerializedValue required to have at least 8 bytes when serialized") - } - num_types := int(binary.BigEndian.Uint64(data[0:8])) - data_size := int(binary.BigEndian.Uint64(data[8:16])) - type_stack := make([]SerializedType, num_types) - for i := 0; i < num_types; i += 1 { - type_start := (i + 2) * 8 - type_end := (i + 3) * 8 - type_stack[i] = SerializedType(binary.BigEndian.Uint64(data[type_start:type_end])) - } - - types_end := 8 * (num_types + 2) - data_end := types_end + data_size - return SerializedValue{ - type_stack, - data[types_end:data_end], - }, data[data_end:], nil -} - -func DeserializeValue(ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { - ctx.Log.Logf("serialize", "Deserializing %+v with %d bytes", reflect_type, len(data)) - var deserialize DeserializeFn = nil - - type_info, type_exists := ctx.TypeReflects[reflect_type] - if type_exists == true { - deserialize = type_info.Deserialize - } else { - kind_info, exists := ctx.Kinds[reflect_type.Kind()] - if exists == false { - return reflect.Value{}, nil, fmt.Errorf("Cannot deserialize %+v/%+v: unknown type/kind", reflect_type, reflect_type.Kind()) - } - deserialize = kind_info.Deserialize - } - - return deserialize(ctx, reflect_type, data) -} - -func DeserializeType(ctx *Context, type_stack []SerializedType) (reflect.Type, []SerializedType, error) { - ctx.Log.Logf("deserialize_types", "Deserializing type stack %+v", type_stack) - var deserialize_type TypeDeserializeFn = nil - var reflect_type reflect.Type = nil - - if len(type_stack) < 1 { - return nil, nil, fmt.Errorf("No elements in type stack to deserialize(DeserializeType)") - } - - ctx_type := type_stack[0] - type_stack = type_stack[1:] - - type_info, type_exists := ctx.Types[SerializedType(ctx_type)] - if type_exists == true { - deserialize_type = type_info.TypeDeserialize - reflect_type = type_info.Reflect - } else { - kind_info, exists := ctx.KindTypes[SerializedType(ctx_type)] - if exists == false { - return nil, nil, fmt.Errorf("Cannot deserialize 0x%x: unknown type/kind", ctx_type) - } - deserialize_type = kind_info.TypeDeserialize - reflect_type = kind_info.Base - } - - if deserialize_type == nil { - return reflect_type, type_stack, nil - } else { - return deserialize_type(ctx, type_stack) - } -} diff --git a/serialize_test.go b/serialize_test.go deleted file mode 100644 index e969088..0000000 --- a/serialize_test.go +++ /dev/null @@ -1,247 +0,0 @@ -package graphvent - -import ( - "fmt" - "reflect" - "testing" - "time" -) - -func TestSerializeTest(t *testing.T) { - ctx := logTestContext(t, []string{"test", "serialize", "deserialize_types"}) - testSerialize(t, ctx, map[string][]NodeID{"test_group": {RandID(), RandID(), RandID()}}) - testSerialize(t, ctx, map[NodeID]ReqState{ - RandID(): Locked, - RandID(): Unlocked, - }) -} - -func TestSerializeBasic(t *testing.T) { - ctx := logTestContext(t, []string{"test", "serialize", "deserialize_types"}) - testSerializeComparable[bool](t, ctx, true) - - type bool_wrapped bool - err := 
RegisterType[bool_wrapped](ctx, nil, nil, nil, DeserializeBool[bool_wrapped]) - fatalErr(t, err) - testSerializeComparable[bool_wrapped](t, ctx, true) - - testSerializeSlice[[]bool](t, ctx, []bool{false, false, true, false}) - testSerializeComparable[string](t, ctx, "test") - testSerializeComparable[float32](t, ctx, 0.05) - testSerializeComparable[float64](t, ctx, 0.05) - testSerializeComparable[uint](t, ctx, uint(1234)) - testSerializeComparable[uint8] (t, ctx, uint8(123)) - testSerializeComparable[uint16](t, ctx, uint16(1234)) - testSerializeComparable[uint32](t, ctx, uint32(12345)) - testSerializeComparable[uint64](t, ctx, uint64(123456)) - testSerializeComparable[int](t, ctx, 1234) - testSerializeComparable[int8] (t, ctx, int8(-123)) - testSerializeComparable[int16](t, ctx, int16(-1234)) - testSerializeComparable[int32](t, ctx, int32(-12345)) - testSerializeComparable[int64](t, ctx, int64(-123456)) - testSerializeComparable[time.Duration](t, ctx, time.Duration(100)) - testSerializeComparable[time.Time](t, ctx, time.Now().Truncate(0)) - testSerializeSlice[[]int](t, ctx, []int{123, 456, 789, 101112}) - testSerializeSlice[[]int](t, ctx, ([]int)(nil)) - testSerializeSliceSlice[[][]int](t, ctx, [][]int{{123, 456, 789, 101112}, {3253, 2341, 735, 212}, {123, 51}, nil}) - testSerializeSliceSlice[[][]string](t, ctx, [][]string{{"123", "456", "789", "101112"}, {"3253", "2341", "735", "212"}, {"123", "51"}, nil}) - - testSerialize(t, ctx, map[int8]map[*int8]string{}) - testSerialize(t, ctx, map[int8]time.Time{ - 1: time.Now(), - 3: time.Now().Add(time.Second), - 0: time.Now().Add(time.Second*2), - 4: time.Now().Add(time.Second*3), - }) - - testSerialize(t, ctx, Tree{ - SerializedTypeFor[NodeType](): nil, - SerializedTypeFor[SerializedType](): { - SerializedTypeFor[NodeType](): Tree{}, - }, - }) - - var i interface{} = nil - testSerialize(t, ctx, i) - - testSerializeMap(t, ctx, map[int8]interface{}{ - 0: "abcd", - 1: uint32(12345678), - 2: i, - 3: 123, - }) - - testSerializeMap(t, ctx, map[int8]int32{ - 0: 1234, - 2: 5678, - 4: 9101, - 6: 1121, - }) - - type test_struct struct { - Int int `gv:"int"` - String string `gv:"string"` - } - - err = RegisterStruct[test_struct](ctx) - fatalErr(t, err) - - testSerialize(t, ctx, test_struct{ - 12345, - "test_string", - }) - - testSerialize(t, ctx, Tree{ - SerializedKindFor(reflect.Map): nil, - SerializedKindFor(reflect.String): nil, - }) - - testSerialize(t, ctx, Tree{ - SerializedTypeFor[Tree](): nil, - }) - - testSerialize(t, ctx, Tree{ - SerializedTypeFor[Tree](): { - SerializedTypeFor[error](): Tree{}, - SerializedKindFor(reflect.Map): nil, - }, - SerializedKindFor(reflect.String): nil, - }) - - type test_slice []string - err = RegisterType[test_slice](ctx, SerializeTypeStub, SerializeSlice, DeserializeTypeStub[test_slice], DeserializeSlice) - fatalErr(t, err) - - testSerialize[[]string](t, ctx, []string{"test_1", "test_2", "test_3"}) - testSerialize[test_slice](t, ctx, test_slice{"test_1", "test_2", "test_3"}) -} - -type test struct { - Int int `gv:"int"` - Str string `gv:"string"` -} - -func (s test) String() string { - return fmt.Sprintf("%d:%s", s.Int, s.Str) -} - -func TestSerializeStructTags(t *testing.T) { - ctx := logTestContext(t, []string{"test"}) - - err := RegisterStruct[test](ctx) - fatalErr(t, err) - - test_int := 10 - test_string := "test" - - ret := testSerialize(t, ctx, test{ - test_int, - test_string, - }) - if ret.Int != test_int { - t.Fatalf("Deserialized int %d does not equal test %d", ret.Int, test_int) - } else if ret.Str != 
test_string { - t.Fatalf("Deserialized string %s does not equal test %s", ret.Str, test_string) - } - - testSerialize(t, ctx, []test{ - { - test_int, - test_string, - }, - { - test_int * 2, - fmt.Sprintf("%s%s", test_string, test_string), - }, - { - test_int * 4, - fmt.Sprintf("%s%s%s", test_string, test_string, test_string), - }, - }) -} - -func testSerializeMap[M map[T]R, T, R comparable](t *testing.T, ctx *Context, val M) { - v := testSerialize(t, ctx, val) - for key, value := range(val) { - recreated, exists := v[key] - if exists == false { - t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val) - } else if recreated != value { - t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val) - } - } - if len(v) != len(val) { - t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val) - } -} - -func testSerializeSliceSlice[S [][]T, T comparable](t *testing.T, ctx *Context, val S) { - v := testSerialize(t, ctx, val) - for i, original := range(val) { - if (original == nil && v[i] != nil) || (original != nil && v[i] == nil) { - t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val) - } - for j, o := range(original) { - if v[i][j] != o { - t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val) - } - } - } -} - -func testSerializeSlice[S []T, T comparable](t *testing.T, ctx *Context, val S) { - v := testSerialize(t, ctx, val) - for i, original := range(val) { - if v[i] != original { - t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val)) - } - } -} - -func testSerializeComparable[T comparable](t *testing.T, ctx *Context, val T) { - v := testSerialize(t, ctx, val) - if v != val { - t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val)) - } -} - -func testSerialize[T any](t *testing.T, ctx *Context, val T) T { - value := reflect.ValueOf(&val).Elem() - type_stack, err := SerializeType(ctx, value.Type()) - chunks, err := SerializeValue(ctx, value) - value_serialized := SerializedValue{type_stack, chunks.Slice()} - fatalErr(t, err) - ctx.Log.Logf("test", "Serialized %+v to %+v(%d)", val, value_serialized, len(value_serialized.Data)) - - value_chunks, err := value_serialized.Chunks() - fatalErr(t, err) - ctx.Log.Logf("test", "Binary: %+v", value_chunks.Slice()) - - val_parsed, remaining_parse, err := ParseSerializedValue(value_chunks.Slice()) - fatalErr(t, err) - ctx.Log.Logf("test", "Parsed: %+v", val_parsed) - - if len(remaining_parse) != 0 { - t.Fatal("Data remaining after deserializing value") - } - - val_type, remaining_types, err := DeserializeType(ctx, val_parsed.TypeStack) - deserialized_value, remaining_deserialize, err := DeserializeValue(ctx, val_type, val_parsed.Data) - fatalErr(t, err) - - if len(remaining_deserialize) != 0 { - t.Fatal("Data remaining after deserializing value") - } else if len(remaining_types) != 0 { - t.Fatal("TypeStack remaining after deserializing value") - } else if val_type != value.Type() { - t.Fatal(fmt.Sprintf("DeserializeValue returned wrong reflect.Type %+v - %+v", val_type, reflect.TypeOf(val))) - } else if deserialized_value.CanConvert(val_type) == false { - t.Fatal("DeserializeValue returned value that can't convert to original value") - } - ctx.Log.Logf("test", "Value: %+v", deserialized_value.Interface()) - if val_type.Kind() == reflect.Interface && deserialized_value.Interface() == nil { - var zero T - return zero - } - return deserialized_value.Interface().(T) -} diff --git a/signal.go b/signal.go index 46acec9..4c33984 100644 --- 
a/signal.go +++ b/signal.go @@ -7,20 +7,13 @@ import ( "github.com/google/uuid" ) -type SignalDirection uint8 -const ( - Up SignalDirection = iota - Down - Direct -) - type TimeoutSignal struct { ResponseHeader } func NewTimeoutSignal(req_id uuid.UUID) *TimeoutSignal { return &TimeoutSignal{ - NewResponseHeader(req_id, Direct), + NewResponseHeader(req_id), } } @@ -28,26 +21,23 @@ func (signal TimeoutSignal) String() string { return fmt.Sprintf("TimeoutSignal(%s)", &signal.ResponseHeader) } -// Timeouts are internal only, no permission allows sending them -func (signal TimeoutSignal) Permission() Tree { - return nil -} +type SignalDirection int +const ( + Up SignalDirection = iota + Down + Direct +) type SignalHeader struct { Id uuid.UUID `gv:"id"` - Dir SignalDirection `gv:"direction"` } func (signal SignalHeader) ID() uuid.UUID { return signal.Id } -func (signal SignalHeader) Direction() SignalDirection { - return signal.Dir -} - func (header SignalHeader) String() string { - return fmt.Sprintf("SignalHeader(%d, %s)", header.Dir, header.Id) + return fmt.Sprintf("SignalHeader(%s)", header.Id) } type ResponseSignal interface { @@ -65,14 +55,12 @@ func (header ResponseHeader) ResponseID() uuid.UUID { } func (header ResponseHeader) String() string { - return fmt.Sprintf("ResponseHeader(%d, %s->%s)", header.Dir, header.Id, header.ReqID) + return fmt.Sprintf("ResponseHeader(%s, %s)", header.Id, header.ReqID) } type Signal interface { fmt.Stringer ID() uuid.UUID - Direction() SignalDirection - Permission() Tree } func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UUID) (ResponseSignal, []Signal, error) { @@ -129,16 +117,15 @@ func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, check return zero, fmt.Errorf("LOOP_ENDED") } -func NewSignalHeader(direction SignalDirection) SignalHeader { +func NewSignalHeader() SignalHeader { return SignalHeader{ uuid.New(), - direction, } } -func NewResponseHeader(req_id uuid.UUID, direction SignalDirection) ResponseHeader { +func NewResponseHeader(req_id uuid.UUID) ResponseHeader { return ResponseHeader{ - NewSignalHeader(direction), + NewSignalHeader(), req_id, } } @@ -151,16 +138,9 @@ func (signal SuccessSignal) String() string { return fmt.Sprintf("SuccessSignal(%s)", signal.ResponseHeader) } -func (signal SuccessSignal) Permission() Tree { - return Tree{ - SerializedType(SignalTypeFor[ResponseSignal]()): { - SerializedType(SignalTypeFor[SuccessSignal]()): nil, - }, - } -} func NewSuccessSignal(req_id uuid.UUID) *SuccessSignal { return &SuccessSignal{ - NewResponseHeader(req_id, Direct), + NewResponseHeader(req_id), } } @@ -171,16 +151,9 @@ type ErrorSignal struct { func (signal ErrorSignal) String() string { return fmt.Sprintf("ErrorSignal(%s, %s)", signal.ResponseHeader, signal.Error) } -func (signal ErrorSignal) Permission() Tree { - return Tree{ - SerializedType(SignalTypeFor[ResponseSignal]()): { - SerializedType(SignalTypeFor[ErrorSignal]()): nil, - }, - } -} func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) *ErrorSignal { return &ErrorSignal{ - NewResponseHeader(req_id, Direct), + NewResponseHeader(req_id), fmt.Sprintf(fmt_string, args...), } } @@ -188,14 +161,9 @@ func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) *E type ACLTimeoutSignal struct { ResponseHeader } -func (signal ACLTimeoutSignal) Permission() Tree { - return Tree{ - SerializedType(SignalTypeFor[ACLTimeoutSignal]()): nil, - } -} func NewACLTimeoutSignal(req_id uuid.UUID) 
*ACLTimeoutSignal { sig := &ACLTimeoutSignal{ - NewResponseHeader(req_id, Direct), + NewResponseHeader(req_id), } return sig } @@ -205,17 +173,12 @@ type StatusSignal struct { Source NodeID `gv:"source"` Changes map[ExtType]Changes `gv:"changes"` } -func (signal StatusSignal) Permission() Tree { - return Tree{ - SerializedType(SignalTypeFor[StatusSignal]()): nil, - } -} func (signal StatusSignal) String() string { return fmt.Sprintf("StatusSignal(%s, %+v)", signal.SignalHeader, signal.Changes) } func NewStatusSignal(source NodeID, changes map[ExtType]Changes) *StatusSignal { return &StatusSignal{ - NewSignalHeader(Up), + NewSignalHeader(), source, changes, } @@ -232,17 +195,9 @@ const ( LinkActionAdd = "ADD" ) -func (signal LinkSignal) Permission() Tree { - return Tree{ - SerializedType(SignalTypeFor[LinkSignal]()): Tree{ - Hash(LinkActionBase, signal.Action): nil, - }, - } -} - func NewLinkSignal(action string, id NodeID) Signal { return &LinkSignal{ - NewSignalHeader(Direct), + NewSignalHeader(), id, action, } @@ -256,21 +211,9 @@ func (signal LockSignal) String() string { return fmt.Sprintf("LockSignal(%s, %s)", signal.SignalHeader, signal.State) } -const ( - LockStateBase = "LOCK_STATE" -) - -func (signal LockSignal) Permission() Tree { - return Tree{ - SerializedType(SignalTypeFor[LockSignal]()): Tree{ - Hash(LockStateBase, signal.State): nil, - }, - } -} - func NewLockSignal(state string) *LockSignal { return &LockSignal{ - NewSignalHeader(Direct), + NewSignalHeader(), state, } } @@ -284,21 +227,9 @@ func (signal ReadSignal) String() string { return fmt.Sprintf("ReadSignal(%s, %+v)", signal.SignalHeader, signal.Extensions) } -func (signal ReadSignal) Permission() Tree { - ret := Tree{} - for ext, fields := range(signal.Extensions) { - field_tree := Tree{} - for _, field := range(fields) { - field_tree[SerializedType(GetFieldTag(field))] = nil - } - ret[SerializedType(ext)] = field_tree - } - return Tree{SerializedType(SignalTypeFor[ReadSignal]()): ret} -} - func NewReadSignal(exts map[ExtType][]string) *ReadSignal { return &ReadSignal{ - NewSignalHeader(Direct), + NewSignalHeader(), exts, } } @@ -307,23 +238,16 @@ type ReadResultSignal struct { ResponseHeader NodeID NodeID NodeType NodeType - Extensions map[ExtType]map[string]SerializedValue + Extensions map[ExtType]map[string]any } func (signal ReadResultSignal) String() string { return fmt.Sprintf("ReadResultSignal(%s, %s, %+v, %+v)", signal.ResponseHeader, signal.NodeID, signal.NodeType, signal.Extensions) } -func (signal ReadResultSignal) Permission() Tree { - return Tree{ - SerializedType(SignalTypeFor[ResponseSignal]()): { - SerializedType(SignalTypeFor[ReadResultSignal]()): nil, - }, - } -} -func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]SerializedValue) *ReadResultSignal { +func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]any) *ReadResultSignal { return &ReadResultSignal{ - NewResponseHeader(req_id, Direct), + NewResponseHeader(req_id), node_id, node_type, exts,