From faab7eb52cf9116a4ade9f4f341bcaf3d1dd5174 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sun, 3 Mar 2024 15:45:45 -0700 Subject: [PATCH] Cleanup(disabled gql test temporarily) --- acl.go | 16 +- acl_test.go | 26 +-- context.go | 262 ++++++++++++++-------------- event.go | 12 +- event_test.go | 5 +- go.mod | 2 +- gql.go | 427 +++------------------------------------------- gql_signal.go | 82 +++------ gql_test.go | 57 +++---- graph_test.go | 6 +- group.go | 28 +-- group_test.go | 6 +- listener.go | 4 - lockable.go | 44 +++-- lockable_test.go | 17 +- node.go | 24 ++- node_test.go | 18 +- policy.go | 4 +- serialize.go | 234 +++++++++---------------- serialize_test.go | 47 +++-- signal.go | 32 ++-- 21 files changed, 432 insertions(+), 921 deletions(-) diff --git a/acl.go b/acl.go index ccb681a..147183a 100644 --- a/acl.go +++ b/acl.go @@ -21,12 +21,12 @@ func NewACLSignal(principal NodeID, action Tree) *ACLSignal { } var DefaultACLPolicy = NewAllNodesPolicy(Tree{ - SerializedType(ACLSignalType): nil, + SerializedType(SignalTypeFor[ACLSignal]()): nil, }) func (signal ACLSignal) Permission() Tree { return Tree{ - SerializedType(ACLSignalType): nil, + SerializedType(SignalTypeFor[ACLSignal]()): nil, } } @@ -51,7 +51,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa var changes = Changes{} info, waiting := ext.Pending[response.ResponseID()] if waiting == true { - changes.Add(ACLExtType, "pending") + AddChange[ACLExt](changes, "pending") delete(ext.Pending, response.ResponseID()) if response.ID() != info.Timeout { err := node.DequeueSignal(info.Timeout) @@ -78,7 +78,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa } } else { if ext.Policies[policy_index].ContinueAllows(ctx, acl_info, response) == Allow { - changes.Add(ACLExtType, "pending_acls") + AddChange[ACLExt](changes, "pending_acls") delete(ext.PendingACLs, info.ID) ctx.Log.Logf("acl", "Request delayed allow") messages = messages.Add(ctx, acl_info.Source, node, nil, NewSuccessSignal(info.ID)) @@ -87,7 +87,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err) } } else if acl_info.Counter == 0 { - changes.Add(ACLExtType, "pending_acls") + AddChange[ACLExt](changes, "pending_acls") delete(ext.PendingACLs, info.ID) ctx.Log.Logf("acl", "Request delayed deny") messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(info.ID, "acl_denied")) @@ -97,7 +97,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa } } else { node.PendingACLs[info.ID] = acl_info - changes.Add(ACLExtType, "pending_acls") + AddChange[ACLExt](changes, "pending_acls") } } } @@ -136,7 +136,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "acl_denied")) } else if acl_messages != nil { ctx.Log.Logf("acl", "Request pending") - changes.Add(ACLExtType, "pending") + AddChange[ACLExt](changes, "pending") total_messages := 0 // TODO: reasonable timeout/configurable timeout_time := time.Now().Add(time.Second) @@ -175,7 +175,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa acl_info, exists := ext.PendingACLs[sig.ReqID] if exists == true { delete(ext.PendingACLs, sig.ReqID) - changes.Add(ACLExtType, "pending_acls") + AddChange[ACLExt](changes, "pending_acls") ctx.Log.Logf("acl", "Request timeout deny") messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(sig.ReqID, "acl_timeout")) err := node.DequeueSignal(acl_info.TimeoutID) diff --git a/acl_test.go b/acl_test.go index 5497306..2050260 100644 --- a/acl_test.go +++ b/acl_test.go @@ -22,7 +22,7 @@ func checkSignal[S Signal](t *testing.T, signal Signal, check func(S)){ } func testSendACL[S Signal](t *testing.T, ctx *Context, listener *Node, action Tree, policies []Policy, check func(S)){ - acl_node, err := NewNode(ctx, nil, BaseNodeType, 100, []Policy{DefaultACLPolicy}, NewACLExt(policies)) + acl_node, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt(policies)) fatalErr(t, err) acl_signal := NewACLSignal(listener.ID, action) @@ -42,7 +42,7 @@ func testErrorSignal(t *testing.T, error_string string) func(*ErrorSignal){ func testSuccess(*SuccessSignal){} func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *Node) (ResponseSignal, []Signal) { - source_listener, err := GetExt[*ListenerExt](source, ListenerExtType) + source_listener, err := GetExt[ListenerExt](source) fatalErr(t, err) messages := Messages{} @@ -56,26 +56,29 @@ func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *No } func TestACLBasic(t *testing.T) { - ctx := logTestContext(t, []string{"serialize_types", "deserialize_types", "test", "listener_debug", "group", "acl", "policy"}) + ctx := logTestContext(t, []string{"test", "acl", "group", "read_field"}) - listener, err := NewNode(ctx, nil, BaseNodeType, 100, nil, NewListenerExt(100)) + listener, err := NewNode(ctx, nil, "Base", 100, nil, NewListenerExt(100)) fatalErr(t, err) + ctx.Log.Logf("test", "testing fail") testSendACL(t, ctx, listener, nil, nil, testErrorSignal(t, "acl_denied")) + ctx.Log.Logf("test", "testing allow all") testSendACL(t, ctx, listener, nil, []Policy{NewAllNodesPolicy(nil)}, testSuccess) - group, err := NewNode(ctx, nil, GroupNodeType, 100, []Policy{ + group, err := NewNode(ctx, nil, "Base", 100, []Policy{ DefaultGroupPolicy, NewPerNodePolicy(map[NodeID]Tree{ listener.ID: { - SerializedType(AddMemberSignalType): nil, - SerializedType(AddSubGroupSignalType): nil, + SerializedType(SignalTypeFor[AddSubGroupSignal]()): nil, + SerializedType(SignalTypeFor[AddMemberSignal]()): nil, }, }), }, NewGroupExt(nil)) fatalErr(t, err) + ctx.Log.Logf("test", "testing empty groups") testSendACL(t, ctx, listener, nil, []Policy{ NewMemberOfPolicy(map[NodeID]map[string]Tree{ group.ID: { @@ -84,14 +87,17 @@ func TestACLBasic(t *testing.T) { }), }, testErrorSignal(t, "acl_denied")) + ctx.Log.Logf("test", "testing adding group") add_subgroup_signal := NewAddSubGroupSignal("test_group") add_subgroup_response, _ := testSend(t, ctx, add_subgroup_signal, listener, group) checkSignal(t, add_subgroup_response, testSuccess) + ctx.Log.Logf("test", "testing adding member") add_member_signal := NewAddMemberSignal("test_group", listener.ID) add_member_response, _ := testSend(t, ctx, add_member_signal, listener, group) checkSignal(t, add_member_response, testSuccess) + ctx.Log.Logf("test", "testing group membership") testSendACL(t, ctx, listener, nil, []Policy{ NewMemberOfPolicy(map[NodeID]map[string]Tree{ group.ID: { @@ -104,21 +110,21 @@ func TestACLBasic(t *testing.T) { NewACLProxyPolicy(nil), }, testErrorSignal(t, "acl_denied")) - acl_proxy_1, err := NewNode(ctx, nil, BaseNodeType, 100, []Policy{DefaultACLPolicy}, NewACLExt(nil)) + acl_proxy_1, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt(nil)) fatalErr(t, err) testSendACL(t, ctx, listener, nil, []Policy{ NewACLProxyPolicy([]NodeID{acl_proxy_1.ID}), }, testErrorSignal(t, "acl_denied")) - acl_proxy_2, err := NewNode(ctx, nil, BaseNodeType, 100, []Policy{DefaultACLPolicy}, NewACLExt([]Policy{NewAllNodesPolicy(nil)})) + acl_proxy_2, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt([]Policy{NewAllNodesPolicy(nil)})) fatalErr(t, err) testSendACL(t, ctx, listener, nil, []Policy{ NewACLProxyPolicy([]NodeID{acl_proxy_2.ID}), }, testSuccess) - acl_proxy_3, err := NewNode(ctx, nil, BaseNodeType, 100, []Policy{DefaultACLPolicy}, + acl_proxy_3, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt([]Policy{ NewMemberOfPolicy(map[NodeID]map[string]Tree{ group.ID: { diff --git a/context.go b/context.go index 91d0118..e99cac2 100644 --- a/context.go +++ b/context.go @@ -9,6 +9,7 @@ import ( "sync" "time" "github.com/google/uuid" + "github.com/graphql-go/graphql" badger "github.com/dgraph-io/badger/v3" ) @@ -19,21 +20,36 @@ var ( ) type ExtensionInfo struct { - Type reflect.Type + Reflect reflect.Type + Interface graphql.Interface + Data interface{} } +type FieldIndex struct { + Extension ExtType + Field string +} + type NodeInfo struct { Extensions []ExtType + Policies []Policy + Fields map[string]FieldIndex } +type GQLValueConverter func(*Context, interface{})(reflect.Value, error) + type TypeInfo struct { Reflect reflect.Type + GQL graphql.Type + Type SerializedType TypeSerialize TypeSerializeFn Serialize SerializeFn TypeDeserialize TypeDeserializeFn Deserialize DeserializeFn + + GQLValue GQLValueConverter } type KindInfo struct { @@ -63,6 +79,8 @@ type Context struct { SignalTypes map[reflect.Type]SignalType // Map between database type hashes and the registered info Nodes map[NodeType]NodeInfo + NodeTypes map[string]NodeType + // Map between go types and registered info Types map[SerializedType]*TypeInfo TypeReflects map[reflect.Type]*TypeInfo @@ -76,7 +94,8 @@ type Context struct { } // Register a NodeType to the context, with the list of extensions it requires -func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) error { +func RegisterNodeType(ctx *Context, name string, extensions []ExtType, mappings map[string]FieldIndex) error { + node_type := NodeTypeFor(name, extensions, mappings) _, exists := ctx.Nodes[node_type] if exists == true { return fmt.Errorf("Cannot register node type %+v, type already exists in context", node_type) @@ -99,11 +118,17 @@ func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) e ctx.Nodes[node_type] = NodeInfo{ Extensions: extensions, + Fields: mappings, } + ctx.NodeTypes[name] = node_type + return nil } -func (ctx *Context) RegisterPolicy(reflect_type reflect.Type, policy_type PolicyType) error { +func RegisterPolicy[P Policy](ctx *Context) error { + reflect_type := reflect.TypeFor[P]() + policy_type := PolicyTypeFor[P]() + _, exists := ctx.Policies[policy_type] if exists == true { return fmt.Errorf("Cannot register policy of type %+v, type already exists in context", policy_type) @@ -114,7 +139,7 @@ func (ctx *Context) RegisterPolicy(reflect_type reflect.Type, policy_type Policy return err } - err = ctx.RegisterType(reflect_type, SerializedType(policy_type), nil, SerializeStruct(policy_info), nil, DeserializeStruct(policy_info)) + err = RegisterType[P](ctx, nil, SerializeStruct(policy_info), nil, DeserializeStruct(policy_info)) if err != nil { return err } @@ -126,7 +151,10 @@ func (ctx *Context) RegisterPolicy(reflect_type reflect.Type, policy_type Policy return nil } -func (ctx *Context)RegisterSignal(reflect_type reflect.Type, signal_type SignalType) error { +func RegisterSignal[S Signal](ctx *Context) error { + reflect_type := reflect.TypeFor[S]() + signal_type := SignalTypeFor[S]() + _, exists := ctx.Signals[signal_type] if exists == true { return fmt.Errorf("Cannot register signal of type %+v, type already exists in context", signal_type) @@ -137,7 +165,7 @@ func (ctx *Context)RegisterSignal(reflect_type reflect.Type, signal_type SignalT return err } - err = ctx.RegisterType(reflect_type, SerializedType(signal_type), nil, SerializeStruct(signal_info), nil, DeserializeStruct(signal_info)) + err = RegisterType[S](ctx, nil, SerializeStruct(signal_info), nil, DeserializeStruct(signal_info)) if err != nil { return err } @@ -149,10 +177,12 @@ func (ctx *Context)RegisterSignal(reflect_type reflect.Type, signal_type SignalT return nil } -func (ctx *Context)RegisterExtension(reflect_type reflect.Type, ext_type ExtType, data interface{}) error { +func RegisterExtension[E any, T interface { *E; Extension}](ctx *Context, data interface{}) error { + reflect_type := reflect.TypeFor[T]() + ext_type := ExtType(SerializedTypeFor[E]()) _, exists := ctx.Extensions[ext_type] if exists == true { - return fmt.Errorf("Cannot register extension of type %+v, type already exists in context", ext_type) + return fmt.Errorf("Cannot register extension %+v of type %+v, type already exists in context", reflect_type, ext_type) } elem_type := reflect_type.Elem() @@ -161,7 +191,7 @@ func (ctx *Context)RegisterExtension(reflect_type reflect.Type, ext_type ExtType return err } - err = ctx.RegisterType(elem_type, SerializedType(ext_type), nil, SerializeStruct(elem_info), nil, DeserializeStruct(elem_info)) + err = RegisterType[E](ctx, nil, SerializeStruct(elem_info), nil, DeserializeStruct(elem_info)) if err != nil { return err } @@ -169,7 +199,7 @@ func (ctx *Context)RegisterExtension(reflect_type reflect.Type, ext_type ExtType ctx.Log.Logf("serialize_types", "Registered ExtType: %+v - %+v", reflect_type, ext_type) ctx.Extensions[ext_type] = ExtensionInfo{ - Type: reflect_type, + Reflect: reflect_type, Data: data, } ctx.ExtensionTypes[reflect_type] = ext_type @@ -177,7 +207,8 @@ func (ctx *Context)RegisterExtension(reflect_type reflect.Type, ext_type ExtType return nil } -func (ctx *Context)RegisterKind(kind reflect.Kind, base reflect.Type, ctx_type SerializedType, type_serialize TypeSerializeFn, serialize SerializeFn, type_deserialize TypeDeserializeFn, deserialize DeserializeFn) error { +func RegisterKind(ctx *Context, kind reflect.Kind, base reflect.Type, type_serialize TypeSerializeFn, serialize SerializeFn, type_deserialize TypeDeserializeFn, deserialize DeserializeFn) error { + ctx_type := SerializedKindFor(kind) _, exists := ctx.Kinds[kind] if exists == true { return fmt.Errorf("Cannot register kind %+v, kind already exists in context", kind) @@ -195,8 +226,8 @@ func (ctx *Context)RegisterKind(kind reflect.Kind, base reflect.Type, ctx_type S info := KindInfo{ Reflect: kind, - Base: base, Type: ctx_type, + Base: base, TypeSerialize: type_serialize, Serialize: serialize, TypeDeserialize: type_deserialize, @@ -210,7 +241,10 @@ func (ctx *Context)RegisterKind(kind reflect.Kind, base reflect.Type, ctx_type S return nil } -func (ctx *Context)RegisterType(reflect_type reflect.Type, ctx_type SerializedType, type_serialize TypeSerializeFn, serialize SerializeFn, type_deserialize TypeDeserializeFn, deserialize DeserializeFn) error { +func RegisterType[T any](ctx *Context, type_serialize TypeSerializeFn, serialize SerializeFn, type_deserialize TypeDeserializeFn, deserialize DeserializeFn) error { + reflect_type := reflect.TypeFor[T]() + ctx_type := SerializedTypeFor[T]() + _, exists := ctx.Types[ctx_type] if exists == true { return fmt.Errorf("Cannot register field of type %+v, type already exists in context", ctx_type) @@ -263,6 +297,14 @@ func (ctx *Context)RegisterType(reflect_type reflect.Type, ctx_type SerializedTy return nil } +func RegisterStruct[T any](ctx *Context) error { + struct_info, err := GetStructInfo(ctx, reflect.TypeFor[T]()) + if err != nil { + return err + } + return RegisterType[T](ctx, nil, SerializeStruct(struct_info), nil, DeserializeStruct(struct_info)) +} + func (ctx *Context) AddNode(id NodeID, node *Node) { ctx.nodeMapLock.Lock() ctx.nodeMap[id] = node @@ -352,424 +394,388 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { Signals: map[SignalType]reflect.Type{}, SignalTypes: map[reflect.Type]SignalType{}, Nodes: map[NodeType]NodeInfo{}, - nodeMap: map[NodeID]*Node{}, + NodeTypes: map[string]NodeType{}, Types: map[SerializedType]*TypeInfo{}, TypeReflects: map[reflect.Type]*TypeInfo{}, Kinds: map[reflect.Kind]*KindInfo{}, KindTypes: map[SerializedType]*KindInfo{}, + + nodeMap: map[NodeID]*Node{}, } var err error - err = ctx.RegisterKind(reflect.Pointer, nil, PointerType, SerializeTypeElem, SerializePointer, DeserializeTypePointer, DeserializePointer) + err = RegisterKind(ctx, reflect.Pointer, nil, SerializeTypeElem, SerializePointer, DeserializeTypePointer, DeserializePointer) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Bool, reflect.TypeOf(true), BoolType, nil, SerializeBool, nil, DeserializeBool[bool]) + err = RegisterKind(ctx, reflect.Bool, reflect.TypeFor[bool](), nil, SerializeBool, nil, DeserializeBool[bool]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.String, reflect.TypeOf(""), StringType, nil, SerializeString, nil, DeserializeString[string]) + err = RegisterKind(ctx, reflect.String, reflect.TypeFor[string](), nil, SerializeString, nil, DeserializeString[string]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Float32, reflect.TypeOf(float32(0)), Float32Type, nil, SerializeFloat32, nil, DeserializeFloat32[float32]) + err = RegisterKind(ctx, reflect.Float32, reflect.TypeFor[float32](), nil, SerializeFloat32, nil, DeserializeFloat32[float32]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Float64, reflect.TypeOf(float64(0)), Float64Type, nil, SerializeFloat64, nil, DeserializeFloat64[float64]) + err = RegisterKind(ctx, reflect.Float64, reflect.TypeFor[float64](), nil, SerializeFloat64, nil, DeserializeFloat64[float64]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Uint, reflect.TypeOf(uint(0)), UIntType, nil, SerializeUint32, nil, DeserializeUint32[uint]) + err = RegisterKind(ctx, reflect.Uint, reflect.TypeFor[uint](), nil, SerializeUint32, nil, DeserializeUint32[uint]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Uint8, reflect.TypeOf(uint8(0)), UInt8Type, nil, SerializeUint8, nil, DeserializeUint8[uint8]) + err = RegisterKind(ctx, reflect.Uint8, reflect.TypeFor[uint8](), nil, SerializeUint8, nil, DeserializeUint8[uint8]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Uint16, reflect.TypeOf(uint16(0)), UInt16Type, nil, SerializeUint16, nil, DeserializeUint16[uint16]) + err = RegisterKind(ctx, reflect.Uint16, reflect.TypeFor[uint16](), nil, SerializeUint16, nil, DeserializeUint16[uint16]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Uint32, reflect.TypeOf(uint32(0)), UInt32Type, nil, SerializeUint32, nil, DeserializeUint32[uint32]) + err = RegisterKind(ctx, reflect.Uint32, reflect.TypeFor[uint32](), nil, SerializeUint32, nil, DeserializeUint32[uint32]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Uint64, reflect.TypeOf(uint64(0)), UInt64Type, nil, SerializeUint64, nil, DeserializeUint64[uint64]) + err = RegisterKind(ctx, reflect.Uint64, reflect.TypeFor[uint64](), nil, SerializeUint64, nil, DeserializeUint64[uint64]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Int, reflect.TypeOf(int(0)), IntType, nil, SerializeInt32, nil, DeserializeUint32[int]) + err = RegisterKind(ctx, reflect.Int, reflect.TypeFor[int](), nil, SerializeInt32, nil, DeserializeUint32[int]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Int8, reflect.TypeOf(int8(0)), Int8Type, nil, SerializeInt8, nil, DeserializeUint8[int8]) + err = RegisterKind(ctx, reflect.Int8, reflect.TypeFor[int8](), nil, SerializeInt8, nil, DeserializeUint8[int8]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Int16, reflect.TypeOf(int16(0)), Int16Type, nil, SerializeInt16, nil, DeserializeUint16[int16]) + err = RegisterKind(ctx, reflect.Int16, reflect.TypeFor[int16](), nil, SerializeInt16, nil, DeserializeUint16[int16]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Int32, reflect.TypeOf(int32(0)), Int32Type, nil, SerializeInt32, nil, DeserializeUint32[int32]) + err = RegisterKind(ctx, reflect.Int32, reflect.TypeFor[int32](), nil, SerializeInt32, nil, DeserializeUint32[int32]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Int64, reflect.TypeOf(int64(0)), Int64Type, nil, SerializeInt64, nil, DeserializeUint64[int64]) + err = RegisterKind(ctx, reflect.Int64, reflect.TypeFor[int64](), nil, SerializeInt64, nil, DeserializeUint64[int64]) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(WaitReason("")), WaitReasonType, nil, nil, nil, DeserializeString[WaitReason]) + err = RegisterType[WaitReason](ctx, nil, nil, nil, DeserializeString[WaitReason]) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(EventCommand("")), EventCommandType, nil, nil, nil, DeserializeString[EventCommand]) + err = RegisterType[EventCommand](ctx, nil, nil, nil, DeserializeString[EventCommand]) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(EventState("")), EventStateType, nil, nil, nil, DeserializeString[EventState]) + err = RegisterType[EventState](ctx, nil, nil, nil, DeserializeString[EventState]) if err != nil { return nil, err } - wait_info_type := reflect.TypeOf(WaitInfo{}) - wait_info_info, err := GetStructInfo(ctx, wait_info_type) - if err != nil { - return nil, err - } - err = ctx.RegisterType(wait_info_type, WaitInfoType, nil, SerializeStruct(wait_info_info), nil, DeserializeStruct(wait_info_info)) + err = RegisterStruct[WaitInfo](ctx) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(time.Duration(0)), DurationType, nil, nil, nil, DeserializeUint64[time.Duration]) + err = RegisterType[time.Duration](ctx, nil, nil, nil, DeserializeUint64[time.Duration]) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(time.Time{}), TimeType, nil, SerializeGob, nil, DeserializeGob[time.Time]) + err = RegisterType[time.Time](ctx, nil, SerializeGob, nil, DeserializeGob[time.Time]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Map, nil, MapType, SerializeTypeMap, SerializeMap, DeserializeTypeMap, DeserializeMap) + err = RegisterKind(ctx, reflect.Map, nil, SerializeTypeMap, SerializeMap, DeserializeTypeMap, DeserializeMap) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Array, nil, ArrayType, SerializeTypeArray, SerializeArray, DeserializeTypeArray, DeserializeArray) + err = RegisterKind(ctx, reflect.Array, nil, SerializeTypeArray, SerializeArray, DeserializeTypeArray, DeserializeArray) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Slice, nil, SliceType, SerializeTypeElem, SerializeSlice, DeserializeTypeSlice, DeserializeSlice) + err = RegisterKind(ctx, reflect.Slice, nil, SerializeTypeElem, SerializeSlice, DeserializeTypeSlice, DeserializeSlice) if err != nil { return nil, err } - var ptr interface{} = nil - err = ctx.RegisterKind(reflect.Interface, reflect.TypeOf(&ptr).Elem(), InterfaceType, nil, SerializeInterface, nil, DeserializeInterface) + err = RegisterKind(ctx, reflect.Interface, reflect.TypeFor[interface{}](), nil, SerializeInterface, nil, DeserializeInterface) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(SerializedType(0)), SerializedTypeSerialized, nil, SerializeUint64, nil, DeserializeUint64[SerializedType]) + err = RegisterType[SerializedType](ctx, nil, SerializeUint64, nil, DeserializeUint64[SerializedType]) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(Changes{}), ChangesSerialized, SerializeTypeStub, SerializeMap, DeserializeTypeStub[Changes], DeserializeMap) + err = RegisterType[Changes](ctx, SerializeTypeStub, SerializeMap, DeserializeTypeStub[Changes], DeserializeMap) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(ExtType(0)), ExtTypeSerialized, nil, SerializeUint64, nil, DeserializeUint64[ExtType]) + err = RegisterType[ExtType](ctx, nil, SerializeUint64, nil, DeserializeUint64[ExtType]) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(NodeType(0)), NodeTypeSerialized, nil, SerializeUint64, nil, DeserializeUint64[NodeType]) + err = RegisterType[NodeType](ctx, nil, SerializeUint64, nil, DeserializeUint64[NodeType]) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(PolicyType(0)), PolicyTypeSerialized, nil, SerializeUint64, nil, DeserializeUint64[PolicyType]) + err = RegisterType[PolicyType](ctx, nil, SerializeUint64, nil, DeserializeUint64[PolicyType]) if err != nil { return nil, err } - node_id_type := reflect.TypeOf(RandID()) - err = ctx.RegisterType(node_id_type, NodeIDType, SerializeTypeStub, SerializeUUID, DeserializeTypeStub[NodeID], DeserializeUUID[NodeID]) + err = RegisterType[NodeID](ctx, SerializeTypeStub, SerializeUUID, DeserializeTypeStub[NodeID], DeserializeUUID[NodeID]) if err != nil { return nil, err } - uuid_type := reflect.TypeOf(uuid.UUID{}) - err = ctx.RegisterType(uuid_type, UUIDType, SerializeTypeStub, SerializeUUID, DeserializeTypeStub[uuid.UUID], DeserializeUUID[uuid.UUID]) + err = RegisterType[uuid.UUID](ctx, SerializeTypeStub, SerializeUUID, DeserializeTypeStub[uuid.UUID], DeserializeUUID[uuid.UUID]) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(Up), SignalDirectionType, nil, SerializeUint8, nil, DeserializeUint8[SignalDirection]) + err = RegisterType[SignalDirection](ctx, nil, SerializeUint8, nil, DeserializeUint8[SignalDirection]) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(ReqState(0)), ReqStateType, nil, SerializeUint8, nil, DeserializeUint8[ReqState]) + err = RegisterType[ReqState](ctx, nil, SerializeUint8, nil, DeserializeUint8[ReqState]) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(Tree{}), TreeType, SerializeTypeStub, nil, DeserializeTypeStub[Tree], nil) - - var extension Extension = nil - err = ctx.RegisterType(reflect.ValueOf(&extension).Type().Elem(), ExtSerialized, nil, SerializeInterface, nil, DeserializeInterface) + err = RegisterType[Tree](ctx, SerializeTypeStub, nil, DeserializeTypeStub[Tree], nil) if err != nil { return nil, err } - var policy Policy = nil - err = ctx.RegisterType(reflect.ValueOf(&policy).Type().Elem(), PolicySerialized, nil, SerializeInterface, nil, DeserializeInterface) + err = RegisterType[Extension](ctx, nil, SerializeInterface, nil, DeserializeInterface) if err != nil { return nil, err } - var signal Signal = nil - err = ctx.RegisterType(reflect.ValueOf(&signal).Type().Elem(), SignalSerialized, nil, SerializeInterface, nil, DeserializeInterface) - if err != nil { - return nil, err - } - - pending_acl_type := reflect.TypeOf(PendingACL{}) - pending_acl_info, err := GetStructInfo(ctx, pending_acl_type) - if err != nil { - return nil, err - } - err = ctx.RegisterType(pending_acl_type, PendingACLType, nil, SerializeStruct(pending_acl_info), nil, DeserializeStruct(pending_acl_info)) + err = RegisterType[Policy](ctx, nil, SerializeInterface, nil, DeserializeInterface) if err != nil { return nil, err } - pending_signal_type := reflect.TypeOf(PendingACLSignal{}) - pending_signal_info, err := GetStructInfo(ctx, pending_signal_type) - if err != nil { - return nil, err - } - err = ctx.RegisterType(pending_signal_type, PendingACLSignalType, nil, SerializeStruct(pending_signal_info), nil, DeserializeStruct(pending_signal_info)) + err = RegisterType[Signal](ctx, nil, SerializeInterface, nil, DeserializeInterface) if err != nil { return nil, err } - queued_signal_type := reflect.TypeOf(QueuedSignal{}) - queued_signal_info, err := GetStructInfo(ctx, queued_signal_type) - if err != nil { - return nil, err - } - err = ctx.RegisterType(queued_signal_type, QueuedSignalType, nil, SerializeStruct(queued_signal_info), nil, DeserializeStruct(queued_signal_info)) + err = RegisterStruct[PendingACL](ctx) if err != nil { return nil, err } - node_type := reflect.TypeOf(Node{}) - node_info, err := GetStructInfo(ctx, node_type) - if err != nil { - return nil, err - } - err = ctx.RegisterType(node_type, NodeStructType, nil, SerializeStruct(node_info), nil, DeserializeStruct(node_info)) + err = RegisterStruct[PendingACLSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterExtension(reflect.TypeOf((*LockableExt)(nil)), LockableExtType, nil) + err = RegisterStruct[QueuedSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterExtension(reflect.TypeOf((*ListenerExt)(nil)), ListenerExtType, nil) + err = RegisterStruct[Node](ctx) if err != nil { return nil, err } - err = ctx.RegisterExtension(reflect.TypeOf((*GroupExt)(nil)), GroupExtType, nil) + err = RegisterExtension[LockableExt](ctx, nil) if err != nil { return nil, err } - gql_ctx := NewGQLExtContext() - err = ctx.RegisterExtension(reflect.TypeOf((*GQLExt)(nil)), GQLExtType, gql_ctx) + err = RegisterExtension[ListenerExt](ctx, nil) if err != nil { return nil, err } - err = ctx.RegisterExtension(reflect.TypeOf((*ACLExt)(nil)), ACLExtType, nil) + err = RegisterExtension[GroupExt](ctx, nil) if err != nil { return nil, err } - err = ctx.RegisterExtension(reflect.TypeOf((*EventExt)(nil)), EventExtType, nil) + gql_ctx := NewGQLExtContext() + err = RegisterExtension[GQLExt](ctx, gql_ctx) if err != nil { return nil, err } - err = ctx.RegisterPolicy(reflect.TypeOf(OwnerOfPolicy{}), OwnerOfPolicyType) + err = RegisterExtension[ACLExt](ctx, nil) if err != nil { return nil, err } - err = ctx.RegisterPolicy(reflect.TypeOf(ParentOfPolicy{}), ParentOfPolicyType) + err = RegisterExtension[EventExt](ctx, nil) if err != nil { return nil, err } - err = ctx.RegisterPolicy(reflect.TypeOf(MemberOfPolicy{}), MemberOfPolicyType) + err = RegisterPolicy[OwnerOfPolicy](ctx) if err != nil { return nil, err } - err = ctx.RegisterPolicy(reflect.TypeOf(AllNodesPolicy{}), AllNodesPolicyType) + err = RegisterPolicy[ParentOfPolicy](ctx) if err != nil { return nil, err } - err = ctx.RegisterPolicy(reflect.TypeOf(PerNodePolicy{}), PerNodePolicyType) + err = RegisterPolicy[MemberOfPolicy](ctx) if err != nil { return nil, err } - err = ctx.RegisterPolicy(reflect.TypeOf(ACLProxyPolicy{}), ACLProxyPolicyType) + err = RegisterPolicy[AllNodesPolicy](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(StoppedSignal{}), StoppedSignalType) + err = RegisterPolicy[PerNodePolicy](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(AddSubGroupSignal{}), AddSubGroupSignalType) + err = RegisterPolicy[ACLProxyPolicy](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(RemoveSubGroupSignal{}), RemoveSubGroupSignalType) + err = RegisterSignal[StoppedSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(ACLTimeoutSignal{}), ACLTimeoutSignalType) + err = RegisterSignal[AddSubGroupSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(ACLSignal{}), ACLSignalType) + err = RegisterSignal[RemoveSubGroupSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(RemoveMemberSignal{}), RemoveMemberSignalType) + err = RegisterSignal[ACLTimeoutSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(AddMemberSignal{}), AddMemberSignalType) + err = RegisterSignal[ACLSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(StopSignal{}), StopSignalType) + err = RegisterSignal[RemoveMemberSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(CreateSignal{}), CreateSignalType) + err = RegisterSignal[AddMemberSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(StartSignal{}), StartSignalType) + err = RegisterSignal[StopSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(StatusSignal{}), StatusSignalType) + err = RegisterSignal[CreateSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(ReadSignal{}), ReadSignalType) + err = RegisterSignal[StartSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(LockSignal{}), LockSignalType) + err = RegisterSignal[StatusSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(TimeoutSignal{}), TimeoutSignalType) + err = RegisterSignal[ReadSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(LinkSignal{}), LinkSignalType) + err = RegisterSignal[LockSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(ErrorSignal{}), ErrorSignalType) + err = RegisterSignal[TimeoutSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(SuccessSignal{}), SuccessSignalType) + err = RegisterSignal[LinkSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(ReadResultSignal{}), ReadResultSignalType) + err = RegisterSignal[ErrorSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(EventControlSignal{}), EventControlSignalType) + err = RegisterSignal[SuccessSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(reflect.TypeOf(EventStateSignal{}), EventStateSignalType) + err = RegisterSignal[ReadResultSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterNodeType(BaseNodeType, []ExtType{}) + err = RegisterSignal[EventControlSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterNodeType(GroupNodeType, []ExtType{GroupExtType}) + err = RegisterSignal[EventStateSignal](ctx) if err != nil { return nil, err } - err = ctx.RegisterNodeType(GQLNodeType, []ExtType{GQLExtType}) + err = RegisterNodeType(ctx, "Base", []ExtType{}, map[string]FieldIndex{}) if err != nil { return nil, err } diff --git a/event.go b/event.go index 0f99a1c..d101cf7 100644 --- a/event.go +++ b/event.go @@ -21,7 +21,7 @@ func NewParentOfPolicy(policy Tree) *ParentOfPolicy { } func (policy ParentOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { - event_ext, err := GetExt[*EventExt](node, EventExtType) + event_ext, err := GetExt[EventExt](node) if err != nil { ctx.Log.Logf("event", "ParentOfPolicy, node not event %s", node.ID) return nil, Deny @@ -39,7 +39,7 @@ func (policy ParentOfPolicy) ContinueAllows(ctx *Context, current PendingACL, si } var DefaultEventPolicy = NewParentOfPolicy(Tree{ - SerializedType(EventControlSignalType): nil, + SerializedType(SignalTypeFor[EventControlSignal]()): nil, }) type EventExt struct { @@ -66,7 +66,7 @@ type EventStateSignal struct { func (signal EventStateSignal) Permission() Tree { return Tree{ - SerializedType(StatusType): nil, + SerializedType(SignalTypeFor[StatusSignal]()): nil, } } @@ -101,7 +101,7 @@ func NewEventControlSignal(command EventCommand) *EventControlSignal { func (signal EventControlSignal) Permission() Tree { return Tree{ - SerializedType(EventControlSignalType): { + SerializedType(SignalTypeFor[EventControlSignal]()): { Hash("command", string(signal.Command)): nil, }, } @@ -110,7 +110,7 @@ func (signal EventControlSignal) Permission() Tree { func (ext *EventExt) UpdateState(node *Node, changes Changes, state EventState, state_start time.Time) { if ext.State != state { ext.StateStart = state_start - changes.Add(EventExtType, "state") + AddChange[EventExt](changes, "state") ext.State = state node.QueueSignal(time.Now(), NewEventStateSignal(node.ID, ext.State, time.Now())) } @@ -157,7 +157,7 @@ func (ext *TestEventExt) Process(ctx *Context, node *Node, source NodeID, signal switch sig := signal.(type) { case *EventControlSignal: - event_ext, err := GetExt[*EventExt](node, EventExtType) + event_ext, err := GetExt[EventExt](node) if err != nil { messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_event")) } else { diff --git a/event_test.go b/event_test.go index c6f3d66..1f06104 100644 --- a/event_test.go +++ b/event_test.go @@ -2,7 +2,6 @@ package graphvent import ( "crypto/ed25519" - "reflect" "testing" "time" "crypto/rand" @@ -10,13 +9,13 @@ import ( func TestEvent(t *testing.T) { ctx := logTestContext(t, []string{"event", "listener", "listener_debug"}) - err := ctx.RegisterExtension(reflect.TypeOf(&TestEventExt{}), NewExtType("TEST_EVENT"), nil) + err := RegisterExtension[TestEventExt](ctx, nil) fatalErr(t, err) event_public, event_private, err := ed25519.GenerateKey(rand.Reader) event_listener := NewListenerExt(100) - event, err := NewNode(ctx, event_private, BaseNodeType, 100, nil, NewEventExt(KeyID(event_public), "Test Event"), &TestEventExt{time.Second}, event_listener) + event, err := NewNode(ctx, event_private, "Base", 100, nil, NewEventExt(KeyID(event_public), "Test Event"), &TestEventExt{time.Second}, event_listener) fatalErr(t, err) response, signals := testSend(t, ctx, NewEventControlSignal("ready?"), event, event) diff --git a/go.mod b/go.mod index 65ed97f..f28c45b 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/mekkanized/graphvent -go 1.21.0 +go 1.22.0 require ( github.com/dgraph-io/badger/v3 v3.2103.5 diff --git a/gql.go b/gql.go index aa108c3..9414d77 100644 --- a/gql.go +++ b/gql.go @@ -540,7 +540,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt) (*ResolveCon Ext: gql_ext, Chans: map[uuid.UUID]chan Signal{}, Context: ctx, - GQLContext: ctx.Extensions[GQLExtType].Data.(*GQLExtContext), + GQLContext: ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(*GQLExtContext), NodeCache: map[NodeID]NodeResult{}, Server: server, Authorization: nil, @@ -585,7 +585,7 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW query := GQLPayload{} json.Unmarshal(str, &query) - gql_context := ctx.Extensions[GQLExtType].Data.(*GQLExtContext) + gql_context := ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(*GQLExtContext) params := graphql.Params{ Schema: gql_context.Schema, @@ -739,7 +739,7 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon } } else if msg.Type == "subscribe" { ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload) - gql_context := ctx.Extensions[GQLExtType].Data.(*GQLExtContext) + gql_context := ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(*GQLExtContext) params := graphql.Params{ Schema: gql_context.Schema, Context: req_ctx, @@ -843,10 +843,6 @@ type GQLExtContext struct { Types []graphql.Type Query *graphql.Object Mutation *graphql.Object - Subscription *graphql.Object - - TypeMap map[reflect.Type]GQLTypeInfo - KindMap map[reflect.Kind]GQLTypeInfo } func (ctx *GQLExtContext) GetACLFields(obj_name string, names []string) (map[ExtType][]string, error) { @@ -878,7 +874,6 @@ func BuildSchema(ctx *GQLExtContext) (graphql.Schema, error) { Types: ctx.Types, Query: ctx.Query, Mutation: ctx.Mutation, - Subscription: ctx.Subscription, } return graphql.NewSchema(schemaConfig) @@ -920,7 +915,7 @@ func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string, return nil, fmt.Errorf("%s is not in the fields of %+v in the result for %s - %+v", gv_tag, ext_type, gql_name, node) } - if val_ser.TypeStack[0] == ErrorType { + if val_ser.TypeStack[0] == SerializedTypeFor[error]() { return nil, fmt.Errorf(string(val_ser.Data)) } @@ -1151,421 +1146,39 @@ func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, name string, inte func NewGQLExtContext() *GQLExtContext { query := graphql.NewObject(graphql.ObjectConfig{ Name: "Query", - Fields: graphql.Fields{}, + Fields: graphql.Fields{ + "Test": &graphql.Field{ + Type: graphql.String, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return "Test Data", nil + }, + }, + }, }) mutation := graphql.NewObject(graphql.ObjectConfig{ Name: "Mutation", - Fields: graphql.Fields{}, - }) - - subscription := graphql.NewObject(graphql.ObjectConfig{ - Name: "Subscription", - Fields: graphql.Fields{}, - }) - - kind_map := map[reflect.Kind]GQLTypeInfo{ - reflect.String: { - func(ctx *GQLExtContext, reflect_type reflect.Type)(graphql.Type, error) { - return graphql.String, nil - }, - func(ctx *GQLExtContext, value interface{})(reflect.Value, error) { - return reflect.ValueOf(value), nil - }, - }, - reflect.Bool: { - func(ctx *GQLExtContext, reflect_type reflect.Type)(graphql.Type, error) { - return graphql.Boolean, nil - }, - func(ctx *GQLExtContext, value interface{})(reflect.Value, error) { - return reflect.ValueOf(value), nil - }, - }, - } - type_map := map[reflect.Type]GQLTypeInfo{ - reflect.TypeOf(EventCommand("")): { - func (ctx *GQLExtContext, reflect_type reflect.Type)(graphql.Type, error) { - return graphql.String, nil - }, - func(ctx *GQLExtContext, value interface{})(reflect.Value, error) { - return reflect.ValueOf(EventCommand(value.(string))), nil + Fields: graphql.Fields{ + "Test": &graphql.Field{ + Type: graphql.String, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return "Test Mutation Data", nil + }, }, }, - reflect.TypeOf([2]NodeID{}): { - func(ctx *GQLExtContext, reflect_type reflect.Type)(graphql.Type, error) { - return graphql.NewList(graphql.String), nil - }, - func(ctx *GQLExtContext, value interface{})(reflect.Value, error) { - l, ok := value.([]interface{}) - if ok == false { - return reflect.Value{}, fmt.Errorf("not list: %s", reflect.TypeOf(value)) - } else if len(l) != 2 { - return reflect.Value{}, fmt.Errorf("wrong length: %d/2", len(l)) - } - - id1_str, ok := l[0].(string) - if ok == false { - return reflect.Value{}, fmt.Errorf("not strg: %s", reflect.TypeOf(l[0])) - } - id1, err := ParseID(id1_str) - if err != nil { - return reflect.Value{}, err - } - id2_str, ok := l[1].(string) - if ok == false { - return reflect.Value{}, fmt.Errorf("not strg: %s", reflect.TypeOf(l[1])) - } - id2, err := ParseID(id2_str) - if err != nil { - return reflect.Value{}, err - } - return_value := reflect.New(reflect.TypeOf([2]NodeID{})).Elem() - return_value.Index(0).Set(reflect.ValueOf(id1)) - return_value.Index(1).Set(reflect.ValueOf(id2)) + }) - return return_value, nil - }, - }, - reflect.TypeOf(time.Time{}): { - func(ctx *GQLExtContext, reflect_type reflect.Type) (graphql.Type, error) { - return graphql.DateTime, nil - }, - func(ctx *GQLExtContext, value interface{}) (reflect.Value, error) { - return reflect.ValueOf(value), nil - }, - }, - reflect.TypeOf(&NodeID{}): { - func(ctx *GQLExtContext, reflect_type reflect.Type) (graphql.Type, error) { - return graphql.String, nil - }, - func(ctx *GQLExtContext, value interface{}) (reflect.Value, error) { - str, ok := value.(string) - if ok == false { - return reflect.Value{}, fmt.Errorf("value is not string") - } - - if str == "" { - return reflect.New(reflect.TypeOf(&NodeID{})).Elem(), nil - } - - id_parsed, err := ParseID(str) - if err != nil { - return reflect.Value{}, err - } - - return reflect.ValueOf(&id_parsed), nil - }, - }, - reflect.TypeOf(NodeID{}): { - func(ctx *GQLExtContext, reflect_type reflect.Type)(graphql.Type, error) { - return graphql.String, nil - }, - func(ctx *GQLExtContext, value interface{})(reflect.Value, error) { - str, ok := value.(string) - if ok == false { - return reflect.Value{}, fmt.Errorf("value is not string") - } - - id_parsed, err := ParseID(str) - if err != nil { - return reflect.Value{}, err - } - - return reflect.ValueOf(id_parsed), nil - }, - }, - } context := GQLExtContext{ Schema: graphql.Schema{}, Types: []graphql.Type{}, Query: query, Mutation: mutation, - Subscription: subscription, NodeTypes: map[NodeType]*graphql.Object{}, Interfaces: map[string]*Interface{}, Fields: map[string]Field{}, - KindMap: kind_map, - TypeMap: type_map, } - var err error - err = context.RegisterInterface("Node", "DefaultNode", []string{}, []string{}, map[string]SelfField{}, map[string]ListField{}) - if err != nil { - panic(err) - } - - err = context.RegisterField(graphql.String, "EventName", EventExtType, "name", func(p graphql.ResolveParams, ctx *ResolveContext, val reflect.Value)(interface{}, error) { - name := val.String() - return name, nil - }) - - err = context.RegisterField(graphql.String, "EventStateStart", EventExtType, "state_start", func(p graphql.ResolveParams, ctx *ResolveContext, val reflect.Value)(interface{}, error) { - state_start := val.Interface().(time.Time) - return state_start, nil - }) - - err = context.RegisterField(graphql.String, "EventState", EventExtType, "state", func(p graphql.ResolveParams, ctx *ResolveContext, val reflect.Value)(interface{}, error) { - state := val.String() - return state, nil - }) - - err = context.RegisterInterface("Event", "EventNode", []string{"Node"}, []string{"EventName", "EventStateStart", "EventState"}, map[string]SelfField{}, map[string]ListField{}) - if err != nil { - panic(err) - } - - sub_group_type := graphql.NewObject(graphql.ObjectConfig{ - Name: "SubGroup", - Interfaces: nil, - Fields: graphql.Fields{ - "Name": &graphql.Field{ - Type: graphql.String, - Resolve: func(p graphql.ResolveParams) (interface{}, error) { - val, ok := p.Source.(SubGroupGQL) - if ok == false { - return nil, fmt.Errorf("WRONG_TYPE_RETURNED") - } - return val.Name, nil - }, - }, - "Members": &graphql.Field{ - Type: context.Interfaces["Node"].List, - Resolve: func(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - val, ok := p.Source.(SubGroupGQL) - if ok == false { - return nil, fmt.Errorf("WRONG_TYPE_RETURNED") - } - - nodes, err := ResolveNodes(ctx, p, val.Members) - if err != nil { - return nil, err - } - - return nodes, nil - }, - }, - }, - IsTypeOf: func(p graphql.IsTypeOfParams) bool { - return reflect.TypeOf(p.Value) == reflect.TypeOf(SubGroupGQL{}) - }, - Description: "SubGroup within Group", - }) - context.Types = append(context.Types, sub_group_type) - - err = context.RegisterField(sub_group_type, "SubGroups", GroupExtType, "sub_groups", - func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) { - node_map, ok := value.Interface().(map[string]SubGroup) - if ok == false { - return nil, fmt.Errorf("value is %+v, not map[string]SubGroup", value.Type()) - } - - sub_groups := []SubGroupGQL{} - for name, sub_group := range(node_map) { - sub_groups = append(sub_groups, SubGroupGQL{ - name, - sub_group.Members, - }) - } - - return sub_groups, nil - }) - if err != nil { - panic(err) - } - - err = context.RegisterInterface("Group", "DefaultGroup", []string{"Node"}, []string{"SubGroups"}, map[string]SelfField{}, map[string]ListField{}) - if err != nil { - panic(err) - } - - err = context.RegisterField(graphql.String, "LockableState", LockableExtType, "state", - func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) { - state, ok := value.Interface().(ReqState) - if ok == false { - return nil, fmt.Errorf("value is %+v, not ReqState", value.Type()) - } - - return ReqStateStrings[state], nil - }) - if err != nil { - panic(err) - } - - err = context.RegisterInterface("Lockable", "DefaultLockable", []string{"Node"}, []string{"LockableState"}, map[string]SelfField{ - "Owner": { - "owner", - LockableExtType, - func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value) (*NodeID, error) { - id, ok := value.Interface().(*NodeID) - if ok == false { - return nil, fmt.Errorf("can't parse %+v as *NodeID", value.Type()) - } - - return id, nil - }, - }, - }, map[string]ListField{ - "Requirements": { - "requirements", - LockableExtType, - func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value) ([]NodeID, error) { - id_strs, ok := value.Interface().(map[NodeID]ReqState) - if ok == false { - return nil, fmt.Errorf("can't parse requirements %+v as map[NodeID]ReqState", value.Type()) - } - - ids := []NodeID{} - for id := range(id_strs) { - ids = append(ids, id) - } - return ids, nil - }, - }, - }) - - if err != nil { - panic(err) - } - - err = context.RegisterField(graphql.String, "Listen", GQLExtType, "listen", func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value) (interface{}, error) { - return value.String(), nil - }) - if err != nil { - panic(err) - } - - err = context.RegisterNodeType(GQLNodeType, "GQLServer", []string{"Node", "Lockable", "Group"}, []string{"LockableState", "Listen", "Owner", "Requirements", "SubGroups"}) - if err != nil { - panic(err) - } - - err = context.AddSignalMutation("stop", "node_id", reflect.TypeOf(StopSignal{})) - if err != nil { - panic(err) - } - - err = context.AddSignalMutation("addMember", "group_id", reflect.TypeOf(AddMemberSignal{})) - if err != nil { - panic(err) - } - - err = context.AddSignalMutation("removeMember", "group_id", reflect.TypeOf(RemoveMemberSignal{})) - if err != nil { - panic(err) - } - - err = context.AddSignalMutation("eventControl", "event_id", reflect.TypeOf(EventControlSignal{})) - if err != nil { - panic(err) - } - - context.Subscription.AddFieldConfig("Self", &graphql.Field{ - Type: context.Interfaces["Node"].Interface, - Subscribe: func(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - c, err := ctx.Ext.AddSubscription(ctx.ID) - if err != nil { - return nil, err - } - - nodes, err := ResolveNodes(ctx, p, []NodeID{ctx.Server.ID}) - if err != nil { - return nil, err - } else if len(nodes) != 1 { - return nil, fmt.Errorf("wrong length of nodes returned") - } - - c <- nodes[0] - - return c, nil - }, - Resolve: func(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - ctx.Context.Log.Logf("gql_subscribe", "SUBSCRIBE_RESOLVE: %+v", p.Source) - - switch source := p.Source.(type) { - case NodeResult: - case *StatusSignal: - delete(ctx.NodeCache, source.Source) - ctx.Context.Log.Logf("gql_subscribe", "Deleting %+v from NodeCache", source.Source) - if source.Source == ctx.Server.ID { - nodes, err := ResolveNodes(ctx, p, []NodeID{ctx.Server.ID}) - if err != nil { - return nil, err - } else if len(nodes) != 1 { - return nil, fmt.Errorf("wrong length of nodes returned") - } - ctx.NodeCache[ctx.Server.ID] = nodes[0] - } - default: - return nil, fmt.Errorf("Don't know how to handle %+v", source) - } - - return ctx.NodeCache[ctx.Server.ID], nil - }, - }) - - context.Query.AddFieldConfig("Self", &graphql.Field{ - Type: context.Interfaces["Node"].Interface, - Resolve: func(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - nodes, err := ResolveNodes(ctx, p, []NodeID{ctx.Server.ID}) - if err != nil { - return nil, err - } else if len(nodes) != 1 { - return nil, fmt.Errorf("wrong length of resolved nodes returned") - } - - return nodes[0], nil - }, - }) - - context.Query.AddFieldConfig("Node", &graphql.Field{ - Type: context.Interfaces["Node"].Interface, - Args: graphql.FieldConfigArgument{ - "id": &graphql.ArgumentConfig{ - Type: graphql.String, - }, - }, - Resolve: func(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - id, err := ExtractID(p, "id") - if err != nil { - return nil, err - } - - nodes, err := ResolveNodes(ctx, p, []NodeID{id}) - if err != nil { - return nil, err - } else if len(nodes) != 1 { - return nil, fmt.Errorf("wrong length of resolved nodes returned") - } - - return nodes[0], nil - }, - }) - schema, err := BuildSchema(&context) if err != nil { panic(err) @@ -1725,7 +1338,7 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa err := ext.StartGQLServer(ctx, node) if err == nil { ctx.Log.Logf("gql", "started gql server on %s", ext.Listen) - changes.Add(GQLExtType, "state") + AddChange[GQLExt](changes, "state") } else { ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err) } diff --git a/gql_signal.go b/gql_signal.go index 4c5899b..1f5382c 100644 --- a/gql_signal.go +++ b/gql_signal.go @@ -8,35 +8,26 @@ import ( "time" ) -type GQLTypeConverter func(*GQLExtContext, reflect.Type)(graphql.Type, error) -type GQLValueConverter func(*GQLExtContext, interface{})(reflect.Value, error) -type GQLTypeInfo struct { - Type GQLTypeConverter - Value GQLValueConverter -} - -func GetGQLTypeInfo(ctx *GQLExtContext, reflect_type reflect.Type) (*GQLTypeInfo, error) { - type_info, type_mapped := ctx.TypeMap[reflect_type] - if type_mapped == false { - kind_info, kind_mapped := ctx.KindMap[reflect_type.Kind()] - if kind_mapped == false { - return nil, fmt.Errorf("Signal has unsupported type/kind: %s/%s", reflect_type, reflect_type.Kind()) - } else { - return &kind_info, nil - } - } else { - return &type_info, nil - } -} - type StructFieldInfo struct { Name string - Type reflect.Type - GQL *GQLTypeInfo + Type *TypeInfo Index []int } -func SignalFromArgs(ctx *GQLExtContext, signal_type reflect.Type, fields []StructFieldInfo, args map[string]interface{}, id_index, direction_index []int) (Signal, error) { +func ArgumentInfo(ctx *Context, field reflect.StructField, gv_tag string) (StructFieldInfo, error) { + type_info, mapped := ctx.TypeReflects[field.Type] + if mapped == false { + return StructFieldInfo{}, fmt.Errorf("field %+v is of unregistered type %+v ", field.Name, field.Type) + } + + return StructFieldInfo{ + Name: gv_tag, + Type: type_info, + Index: field.Index, + }, nil +} + +func SignalFromArgs(ctx *Context, signal_type reflect.Type, fields []StructFieldInfo, args map[string]interface{}, id_index, direction_index []int) (Signal, error) { fmt.Printf("FIELD: %+v\n", fields) signal_value := reflect.New(signal_type) @@ -52,10 +43,10 @@ func SignalFromArgs(ctx *GQLExtContext, signal_type reflect.Type, fields []Struc return nil, fmt.Errorf("No arg provided named %s", field.Name) } field_value := signal_value.Elem().FieldByIndex(field.Index) - if field_value.CanConvert(field.Type) == false { - return nil, fmt.Errorf("Arg %s wrong type %s/%s", field.Name, field_value.Type(), field.Type) + if field_value.CanConvert(field.Type.Reflect) == false { + return nil, fmt.Errorf("Arg %s wrong type %s/%s", field.Name, field_value.Type(), field.Type.Reflect) } - value, err := field.GQL.Value(ctx, arg) + value, err := field.Type.GQLValue(ctx, arg) if err != nil { return nil, err } @@ -65,26 +56,7 @@ func SignalFromArgs(ctx *GQLExtContext, signal_type reflect.Type, fields []Struc return signal_value.Interface().(Signal), nil } -func ArgumentInfo(ctx *GQLExtContext, field reflect.StructField, gv_tag string) (*graphql.ArgumentConfig, StructFieldInfo, error) { - gql_info, err := GetGQLTypeInfo(ctx, field.Type) - if err != nil { - return nil, StructFieldInfo{}, err - } - gql_type, err := gql_info.Type(ctx, field.Type) - if err != nil { - return nil, StructFieldInfo{}, err - } - return &graphql.ArgumentConfig{ - Type: gql_type, - }, StructFieldInfo{ - gv_tag, - field.Type, - gql_info, - field.Index, - }, nil -} - -func (ext *GQLExtContext) AddSignalMutation(name string, send_id_key string, signal_type reflect.Type) error { +func NewSignalMutation(ctx *Context, name string, send_id_key string, signal_type reflect.Type) (*graphql.Field, error) { args := graphql.FieldConfigArgument{} arg_info := []StructFieldInfo{} var id_index []int = nil @@ -100,13 +72,14 @@ func (ext *GQLExtContext) AddSignalMutation(name string, send_id_key string, sig } else { _, exists := args[gv_tag] if exists == true { - return fmt.Errorf("Signal has repeated tag %s", gv_tag) + return nil, fmt.Errorf("Signal has repeated tag %s", gv_tag) } else { - config, info, err := ArgumentInfo(ext, field, gv_tag) + info, err := ArgumentInfo(ctx, field, gv_tag) if err != nil { - return err + return nil, err + } + args[gv_tag] = &graphql.ArgumentConfig{ } - args[gv_tag] = config arg_info = append(arg_info, info) } } @@ -131,7 +104,7 @@ func (ext *GQLExtContext) AddSignalMutation(name string, send_id_key string, sig return nil, err } - signal, err := SignalFromArgs(ctx.GQLContext, signal_type, arg_info, p.Args, id_index, direction_index) + signal, err := SignalFromArgs(ctx.Context, signal_type, arg_info, p.Args, id_index, direction_index) if err != nil { return nil, err } @@ -164,10 +137,9 @@ func (ext *GQLExtContext) AddSignalMutation(name string, send_id_key string, sig return nil, fmt.Errorf("response of unhandled type %s", reflect.TypeOf(response)) } - ext.Mutation.AddFieldConfig(name, &graphql.Field{ + return &graphql.Field{ Type: graphql.String, Args: args, Resolve: resolve_signal, - }) - return nil + }, nil } diff --git a/gql_test.go b/gql_test.go index b5f5ac4..9b3b131 100644 --- a/gql_test.go +++ b/gql_test.go @@ -1,6 +1,6 @@ package graphvent -import ( +/*import ( "testing" "time" "fmt" @@ -20,11 +20,11 @@ func TestGQLAuth(t *testing.T) { ctx := logTestContext(t, []string{"test"}) listener_1 := NewListenerExt(10) - node_1, err := NewNode(ctx, nil, BaseNodeType, 10, nil, listener_1) + node_1, err := NewNode(ctx, nil, "Base", 10, nil, listener_1) fatalErr(t, err) listener_2 := NewListenerExt(10) - node_2, err := NewNode(ctx, nil, BaseNodeType, 10, nil, listener_2) + node_2, err := NewNode(ctx, nil, "Base", 10, nil, listener_2) fatalErr(t, err) auth_header, err := AuthB64(node_1.Key, node_2.Key.Public().(ed25519.PublicKey)) @@ -43,48 +43,44 @@ func TestGQLAuth(t *testing.T) { } func TestGQLServer(t *testing.T) { - ctx := logTestContext(t, []string{"test", "deserialize_types", "serialize_types", "gqlws", "gql"}) - - TestNodeType := NewNodeType("TEST") - err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType}) - fatalErr(t, err) + ctx := logTestContext(t, []string{"test", "gqlws", "gql"}) pub, gql_key, err := ed25519.GenerateKey(rand.Reader) fatalErr(t, err) gql_id := KeyID(pub) group_policy_1 := NewAllNodesPolicy(Tree{ - SerializedType(ReadSignalType): Tree{ - SerializedType(GroupExtType): Tree{ - Hash(FieldNameBase, "members"): Tree{}, + SerializedType(SignalTypeFor[ReadSignal]()): Tree{ + SerializedType(ExtTypeFor[GroupExt]()): Tree{ + SerializedType(GetFieldTag("members")): Tree{}, }, }, - SerializedType(ReadResultSignalType): nil, - SerializedType(ErrorSignalType): nil, + SerializedType(SignalTypeFor[ReadResultSignal]()): nil, + SerializedType(SignalTypeFor[ErrorSignal]()): nil, }) group_policy_2 := NewMemberOfPolicy(map[NodeID]map[string]Tree{ gql_id: { "test_group": { - SerializedType(LinkSignalType): nil, - SerializedType(LockSignalType): nil, - SerializedType(StatusSignalType): nil, - SerializedType(ReadSignalType): nil, + SerializedType(SignalTypeFor[LinkSignal]()): nil, + SerializedType(SignalTypeFor[LockSignal]()): nil, + SerializedType(SignalTypeFor[StatusSignal]()): nil, + SerializedType(SignalTypeFor[ReadSignal]()): nil, }, }, }) user_policy_1 := NewAllNodesPolicy(Tree{ - SerializedType(ReadResultSignalType): nil, - SerializedType(ErrorSignalType): nil, + SerializedType(SignalTypeFor[ReadResultSignal]()): nil, + SerializedType(SignalTypeFor[ErrorSignal]()): nil, }) user_policy_2 := NewMemberOfPolicy(map[NodeID]map[string]Tree{ gql_id: { "test_group": { - SerializedType(LinkSignalType): nil, - SerializedType(ReadSignalType): nil, - SerializedType(LockSignalType): nil, + SerializedType(SignalTypeFor[LinkSignal]()): nil, + SerializedType(SignalTypeFor[ReadSignal]()): nil, + SerializedType(SignalTypeFor[LockSignal]()): nil, }, }, }) @@ -93,10 +89,10 @@ func TestGQLServer(t *testing.T) { fatalErr(t, err) listener_ext := NewListenerExt(10) - n1, err := NewNode(ctx, nil, TestNodeType, 10, []Policy{user_policy_2, user_policy_1}, NewLockableExt(nil)) + n1, err := NewNode(ctx, nil, "Base", 10, []Policy{user_policy_2, user_policy_1}, NewLockableExt(nil)) fatalErr(t, err) - gql, err := NewNode(ctx, gql_key, GQLNodeType, 10, []Policy{group_policy_2, group_policy_1}, + gql, err := NewNode(ctx, gql_key, "Base", 10, []Policy{group_policy_2, group_policy_1}, NewLockableExt([]NodeID{n1.ID}), gql_ext, NewGroupExt(map[string][]NodeID{"test_group": {n1.ID, gql_id}}), listener_ext) fatalErr(t, err) @@ -219,7 +215,7 @@ func TestGQLServer(t *testing.T) { msgs := Messages{} test_changes := Changes{} - test_changes.Add(GQLExtType, "state") + AddChange[GQLExt](test_changes, "state") msgs = msgs.Add(ctx, gql.ID, gql, nil, NewStatusSignal(gql.ID, test_changes)) err = ctx.Send(msgs) fatalErr(t, err) @@ -246,10 +242,7 @@ func TestGQLServer(t *testing.T) { func TestGQLDB(t *testing.T) { ctx := logTestContext(t, []string{"test", "db", "node"}) - TestUserNodeType := NewNodeType("TEST_USER") - err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) - fatalErr(t, err) - u1, err := NewNode(ctx, nil, TestUserNodeType, 10, nil) + u1, err := NewNode(ctx, nil, "Base", 10, nil) fatalErr(t, err) ctx.Log.Logf("test", "U1_ID: %s", u1.ID) @@ -257,7 +250,7 @@ func TestGQLDB(t *testing.T) { gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) fatalErr(t, err) listener_ext := NewListenerExt(10) - gql, err := NewNode(ctx, nil, GQLNodeType, 10, nil, + gql, err := NewNode(ctx, nil, "Base", 10, nil, gql_ext, listener_ext, NewGroupExt(nil)) @@ -278,7 +271,7 @@ func TestGQLDB(t *testing.T) { gql_loaded, err := LoadNode(ctx, gql.ID) fatalErr(t, err) - listener_ext, err = GetExt[*ListenerExt](gql_loaded, ListenerExtType) + listener_ext, err = GetExt[ListenerExt](gql_loaded) fatalErr(t, err) msgs = Messages{} msgs = msgs.Add(ctx, gql_loaded.ID, gql_loaded, nil, NewStopSignal()) @@ -289,4 +282,4 @@ func TestGQLDB(t *testing.T) { }) fatalErr(t, err) } - +*/ diff --git a/graph_test.go b/graph_test.go index 04fa6e6..499f9d4 100644 --- a/graph_test.go +++ b/graph_test.go @@ -6,13 +6,11 @@ import ( badger "github.com/dgraph-io/badger/v3" ) -var SimpleListenerNodeType = NewNodeType("SIMPLE_LISTENER") - func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt, error) { listener_extension := NewListenerExt(buffer) listener, err := NewNode(ctx, nil, - SimpleListenerNodeType, + "LockableListener", 10, nil, listener_extension, @@ -30,7 +28,7 @@ func logTestContext(t * testing.T, components []string) *Context { ctx, err := NewContext(db, NewConsoleLogger(components)) fatalErr(t, err) - err = ctx.RegisterNodeType(SimpleListenerNodeType, []ExtType{ListenerExtType, LockableExtType}) + err = RegisterNodeType(ctx, "LockableListener", []ExtType{ExtTypeFor[ListenerExt](), ExtTypeFor[LockableExt]()}, map[string]FieldIndex{}) fatalErr(t, err) return ctx diff --git a/group.go b/group.go index ae19368..dcfa14f 100644 --- a/group.go +++ b/group.go @@ -18,7 +18,7 @@ func NewAddSubGroupSignal(name string) *AddSubGroupSignal { func (signal AddSubGroupSignal) Permission() Tree { return Tree{ - SerializedType(AddSubGroupSignalType): { + SerializedType(SignalTypeFor[AddSubGroupSignal]()): { Hash("name", signal.Name): nil, }, } @@ -38,7 +38,7 @@ func NewRemoveSubGroupSignal(name string) *RemoveSubGroupSignal { func (signal RemoveSubGroupSignal) Permission() Tree { return Tree{ - SerializedType(RemoveSubGroupSignalType): { + SerializedType(SignalTypeFor[RemoveSubGroupSignal]()): { Hash("command", signal.Name): nil, }, } @@ -57,7 +57,7 @@ type SubGroupGQL struct { func (signal AddMemberSignal) Permission() Tree { return Tree{ - SerializedType(AddMemberSignalType): { + SerializedType(SignalTypeFor[AddMemberSignal]()): { Hash("sub_group", signal.SubGroup): nil, }, } @@ -79,7 +79,7 @@ type RemoveMemberSignal struct { func (signal RemoveMemberSignal) Permission() Tree { return Tree{ - SerializedType(RemoveMemberSignalType): { + SerializedType(SignalTypeFor[RemoveMemberSignal]()): { Hash("sub_group", signal.SubGroup): nil, }, } @@ -94,9 +94,9 @@ func NewRemoveMemberSignal(sub_group string, member_id NodeID) *RemoveMemberSign } var DefaultGroupPolicy = NewAllNodesPolicy(Tree{ - SerializedType(ReadSignalType): { - SerializedType(GroupExtType): { - Hash(FieldNameBase, "sub_groups"): nil, + SerializedType(SignalTypeFor[ReadSignal]()): { + SerializedType(ExtTypeFor[GroupExt]()): { + SerializedType(GetFieldTag("sub_groups")): nil, }, }, }) @@ -125,7 +125,7 @@ func (policy MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, si } ctx.Log.Logf("group", "member_of_read_result: %+v", sig.Extensions) - group_ext_data, ok := sig.Extensions[GroupExtType] + group_ext_data, ok := sig.Extensions[ExtTypeFor[GroupExt]()] if ok == false { return Deny } @@ -178,7 +178,7 @@ func (policy MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tr var messages Messages = nil for group_id, sub_groups := range(policy.Groups) { if group_id == node.ID { - ext, err := GetExt[*GroupExt](node, GroupExtType) + ext, err := GetExt[GroupExt](node) if err != nil { ctx.Log.Logf("group", "MemberOfPolicy with self ID error: %s", err) } else { @@ -199,7 +199,7 @@ func (policy MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tr } else { // Send the read request to the group so that ContinueAllows can parse the response to check membership messages = messages.Add(ctx, group_id, node, nil, NewReadSignal(map[ExtType][]string{ - GroupExtType: {"sub_groups"}, + ExtTypeFor[GroupExt](): {"sub_groups"}, })) } } @@ -240,7 +240,7 @@ func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Sig ext.SubGroups[sig.SubGroup] = sub_group messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) - changes.Add(GroupExtType, "sub_groups") + AddChange[GroupExt](changes, "sub_groups") } } @@ -257,7 +257,7 @@ func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Sig ext.SubGroups[sig.SubGroup] = sub_group messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) - changes.Add(GroupExtType, "sub_groups") + AddChange[GroupExt](changes, "sub_groups") } } @@ -268,7 +268,7 @@ func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Sig } else { ext.SubGroups[sig.Name] = []NodeID{} - changes.Add(GroupExtType, "sub_groups") + AddChange[GroupExt](changes, "sub_groups") messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) } case *RemoveSubGroupSignal: @@ -278,7 +278,7 @@ func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Sig } else { delete(ext.SubGroups, sig.Name) - changes.Add(GroupExtType, "sub_groups") + AddChange[GroupExt](changes, "sub_groups") messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) } } diff --git a/group_test.go b/group_test.go index 115674a..209f7fb 100644 --- a/group_test.go +++ b/group_test.go @@ -9,7 +9,7 @@ func TestGroupAdd(t *testing.T) { ctx := logTestContext(t, []string{"listener", "test"}) group_listener := NewListenerExt(10) - group, err := NewNode(ctx, nil, GroupNodeType, 10, nil, group_listener, NewGroupExt(nil)) + group, err := NewNode(ctx, nil, "Base", 10, nil, group_listener, NewGroupExt(nil)) fatalErr(t, err) add_subgroup_signal := NewAddSubGroupSignal("test_group") @@ -41,7 +41,7 @@ func TestGroupAdd(t *testing.T) { } read_signal := NewReadSignal(map[ExtType][]string{ - GroupExtType: {"sub_groups"}, + ExtTypeFor[GroupExt](): {"sub_groups"}, }) messages = Messages{} @@ -53,7 +53,7 @@ func TestGroupAdd(t *testing.T) { read_response := response.(*ReadResultSignal) - sub_groups_serialized := read_response.Extensions[GroupExtType]["sub_groups"] + sub_groups_serialized := read_response.Extensions[ExtTypeFor[GroupExt]()]["sub_groups"] sub_groups_type, remaining_types, err := DeserializeType(ctx, sub_groups_serialized.TypeStack) fatalErr(t, err) diff --git a/listener.go b/listener.go index 2ee70fd..2d88c5e 100644 --- a/listener.go +++ b/listener.go @@ -23,10 +23,6 @@ func NewListenerExt(buffer int) *ListenerExt { } } -func (listener *ListenerExt) Type() ExtType { - return ListenerExtType -} - // Send the signal to the channel, logging an overflow if it occurs func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) { ctx.Log.Logf("listener", "%s - %+v", node.ID, reflect.TypeOf(signal)) diff --git a/lockable.go b/lockable.go index 25c8a96..86daa4a 100644 --- a/lockable.go +++ b/lockable.go @@ -6,13 +6,13 @@ import ( ) var AllowParentUnlockPolicy = NewOwnerOfPolicy(Tree{ - SerializedType(LockSignalType): { + SerializedType(SignalTypeFor[LockSignal]()): { Hash(LockStateBase, "unlock"): nil, }, }) var AllowAnyLockPolicy = NewAllNodesPolicy(Tree{ - SerializedType(LockSignalType): { + SerializedType(SignalTypeFor[LockSignal]()): { Hash(LockStateBase, "lock"): nil, }, }) @@ -44,10 +44,6 @@ type LockableExt struct{ WaitInfos WaitMap `gv:"wait_infos"` } -func (ext *LockableExt) Type() ExtType { - return LockableExtType -} - func NewLockableExt(requirements []NodeID) *LockableExt { var reqs map[NodeID]ReqState = nil if requirements != nil { @@ -87,7 +83,7 @@ func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeI if info_found { state, found := ext.Requirements[info.Destination] if found == true { - changes.Add(LockableExtType, "wait_infos") + AddChange[LockableExt](changes, "wait_infos") ctx.Log.Logf("lockable", "got mapped response %+v for %+v in state %s while in %s", signal, info, ReqStateStrings[state], ReqStateStrings[ext.State]) switch ext.State { case AbortingLock: @@ -100,11 +96,11 @@ func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeI } } if all_unlocked == true { - changes.Add(LockableExtType, "state") + AddChange[LockableExt](changes, "state") ext.State = Unlocked } case Locking: - changes.Add(LockableExtType, "state") + AddChange[LockableExt](changes, "state") ext.Requirements[info.Destination] = Unlocked unlocked := 0 for _, state := range(ext.Requirements) { @@ -164,7 +160,7 @@ func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID ext.Requirements = map[NodeID]ReqState{} } ext.Requirements[signal.NodeID] = Unlocked - changes.Add(LockableExtType, "requirements") + AddChange[LockableExt](changes, "requirements") messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(signal.ID())) } case "remove": @@ -173,7 +169,7 @@ func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "can't link: not_requirement")) } else { delete(ext.Requirements, signal.NodeID) - changes.Add(LockableExtType, "requirements") + AddChange[LockableExt](changes, "requirements") messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(signal.ID())) } default: @@ -214,10 +210,10 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod ctx.Log.Logf("lockable", "WHOLE LOCK: %s - %s - %+v", node.ID, ext.PendingID, ext.PendingOwner) ext.State = Locked ext.Owner = ext.PendingOwner - changes.Add(LockableExtType, "state", "owner", "requirements") + AddChange[LockableExt](changes, "state", "owner", "requirements") messages = messages.Add(ctx, *ext.Owner, node, nil, NewSuccessSignal(ext.PendingID)) } else { - changes.Add(LockableExtType, "requirements") + AddChange[LockableExt](changes, "requirements") ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, len(ext.Requirements)) } case AbortingLock: @@ -250,15 +246,15 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod previous_owner := *ext.Owner ext.Owner = ext.PendingOwner ext.ReqID = nil - changes.Add(LockableExtType, "state", "owner", "req_id") + AddChange[LockableExt](changes, "state", "owner", "req_id") messages = messages.Add(ctx, previous_owner, node, nil, NewSuccessSignal(ext.PendingID)) } else if old_state == AbortingLock { - changes.Add(LockableExtType, "state", "pending_owner") + AddChange[LockableExt](changes, "state", "pending_owner") messages = messages.Add(ctx, *ext.PendingOwner, node, nil, NewErrorSignal(*ext.ReqID, "not_unlocked")) ext.PendingOwner = ext.Owner } } else { - changes.Add(LockableExtType, "state") + AddChange[LockableExt](changes, "state") ctx.Log.Logf("lockable", "PARTIAL UNLOCK: %s - %d/%d", node.ID, unlocked, len(ext.Requirements)) } } @@ -282,7 +278,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID new_owner := source ext.PendingOwner = &new_owner ext.Owner = &new_owner - changes.Add(LockableExtType, "state", "pending_owner", "owner") + AddChange[LockableExt](changes, "state", "pending_owner", "owner") messages = messages.Add(ctx, new_owner, node, nil, NewSuccessSignal(signal.ID())) } else { ext.State = Locking @@ -291,7 +287,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID new_owner := source ext.PendingOwner = &new_owner ext.PendingID = signal.ID() - changes.Add(LockableExtType, "state", "req_id", "pending_owner", "pending_id") + AddChange[LockableExt](changes, "state", "req_id", "pending_owner", "pending_id") for id, state := range(ext.Requirements) { if state != Unlocked { ctx.Log.Logf("lockable", "REQ_NOT_UNLOCKED_WHEN_LOCKING") @@ -315,7 +311,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID new_owner := source ext.PendingOwner = nil ext.Owner = nil - changes.Add(LockableExtType, "state", "pending_owner", "owner") + AddChange[LockableExt](changes, "state", "pending_owner", "owner") messages = messages.Add(ctx, new_owner, node, nil, NewSuccessSignal(signal.ID())) } else if source == *ext.Owner { ext.State = Unlocking @@ -323,7 +319,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID ext.ReqID = &id ext.PendingOwner = nil ext.PendingID = signal.ID() - changes.Add(LockableExtType, "state", "pending_owner", "pending_id", "req_id") + AddChange[LockableExt](changes, "state", "pending_owner", "pending_id", "req_id") for id, state := range(ext.Requirements) { if state != Locked { ctx.Log.Logf("lockable", "REQ_NOT_LOCKED_WHEN_UNLOCKING") @@ -351,7 +347,7 @@ func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source Nod wait_info, found := node.ProcessResponse(ext.WaitInfos, signal) if found == true { - changes.Add(LockableExtType, "wait_infos") + AddChange[LockableExt](changes, "wait_infos") state, found := ext.Requirements[wait_info.Destination] if found == true { ctx.Log.Logf("lockable", "%s timed out %s while %s was %s", wait_info.Destination, ReqStateStrings[state], node.ID, ReqStateStrings[state]) @@ -366,7 +362,7 @@ func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source Nod } } if all_unlocked == true { - changes.Add(LockableExtType, "state") + AddChange[LockableExt](changes, "state") ext.State = Unlocked } case Locking: @@ -457,7 +453,7 @@ func (policy OwnerOfPolicy) ContinueAllows(ctx *Context, current PendingACL, sig } func (policy OwnerOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { - l_ext, err := GetExt[*LockableExt](node, LockableExtType) + l_ext, err := GetExt[LockableExt](node) if err != nil { ctx.Log.Logf("lockable", "OwnerOfPolicy.Allows called on node without LockableExt") return nil, Deny @@ -490,7 +486,7 @@ func (policy RequirementOfPolicy) ContinueAllows(ctx *Context, current PendingAC return Deny } - ext, ok := sig.Extensions[LockableExtType] + ext, ok := sig.Extensions[ExtTypeFor[LockableExt]()] if ok == false { return Deny } diff --git a/lockable_test.go b/lockable_test.go index 8e8f235..86d2b70 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -7,11 +7,10 @@ import ( "crypto/rand" ) -var TestLockableType = NewNodeType("TEST_LOCKABLE") func lockableTestContext(t *testing.T, logs []string) *Context { ctx := logTestContext(t, logs) - err := ctx.RegisterNodeType(TestLockableType, []ExtType{LockableExtType}) + err := RegisterNodeType(ctx, "Lockable", []ExtType{ExtTypeFor[LockableExt]()}, map[string]FieldIndex{}) fatalErr(t, err) return ctx @@ -28,7 +27,7 @@ func TestLink(t *testing.T) { }) l2_listener := NewListenerExt(10) - l2, err := NewNode(ctx, nil, TestLockableType, 10, []Policy{policy}, + l2, err := NewNode(ctx, nil, "Lockable", 10, []Policy{policy}, l2_listener, NewLockableExt(nil), ) @@ -36,7 +35,7 @@ func TestLink(t *testing.T) { l1_lockable := NewLockableExt(nil) l1_listener := NewListenerExt(10) - l1, err := NewNode(ctx, l1_key, TestLockableType, 10, nil, + l1, err := NewNode(ctx, l1_key, "Lockable", 10, nil, l1_listener, l1_lockable, ) @@ -76,11 +75,11 @@ func Test100Lock(t *testing.T) { listener_id := KeyID(l_pub) child_policy := NewPerNodePolicy(map[NodeID]Tree{ listener_id: { - SerializedType(LockSignalType): nil, + SerializedType(SignalTypeFor[LockSignal]()): nil, }, }) NewLockable := func()(*Node) { - l, err := NewNode(ctx, nil, TestLockableType, 10, []Policy{child_policy}, + l, err := NewNode(ctx, nil, "Lockable", 10, []Policy{child_policy}, NewLockableExt(nil), ) fatalErr(t, err) @@ -95,11 +94,11 @@ func Test100Lock(t *testing.T) { ctx.Log.Logf("test", "CREATED_100") l_policy := NewAllNodesPolicy(Tree{ - SerializedType(LockSignalType): nil, + SerializedType(SignalTypeFor[LockSignal]()): nil, }) listener := NewListenerExt(5000) - node, err := NewNode(ctx, listener_key, TestLockableType, 5000, []Policy{l_policy}, + node, err := NewNode(ctx, listener_key, "Lockable", 5000, []Policy{l_policy}, listener, NewLockableExt(reqs), ) @@ -128,7 +127,7 @@ func TestLock(t *testing.T) { NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) { listener := NewListenerExt(100) - l, err := NewNode(ctx, nil, TestLockableType, 10, []Policy{policy}, + l, err := NewNode(ctx, nil, "Lockable", 10, []Policy{policy}, listener, NewLockableExt(reqs), ) diff --git a/node.go b/node.go index 40f558b..57ec34a 100644 --- a/node.go +++ b/node.go @@ -49,6 +49,11 @@ func RandID() NodeID { type Changes map[ExtType][]string +func AddChange[E any, T interface { *E; Extension}](changes Changes, fields ...string) { + ext_type := ExtType(SerializedTypeFor[E]()) + changes.Add(ext_type, fields...) +} + func (changes Changes) Add(ext ExtType, fields ...string) { current, exists := changes[ext] if exists == false { @@ -265,6 +270,7 @@ func (err StringError) Error() string { func (err StringError) MarshalBinary() ([]byte, error) { return []byte(string(err)), nil } + func NewErrorField(fstring string, args ...interface{}) SerializedValue { str := StringError(fmt.Sprintf(fstring, args...)) str_ser, err := str.MarshalBinary() @@ -272,12 +278,13 @@ func NewErrorField(fstring string, args ...interface{}) SerializedValue { panic(err) } return SerializedValue{ - TypeStack: []SerializedType{ErrorType}, + TypeStack: []SerializedType{SerializedTypeFor[error]()}, Data: str_ser, } } func (node *Node) ReadFields(ctx *Context, reqs map[ExtType][]string)map[ExtType]map[string]SerializedValue { + ctx.Log.Logf("read_field", "Reading %+v on %+v", reqs, node.ID) exts := map[ExtType]map[string]SerializedValue{} for ext_type, field_reqs := range(reqs) { fields := map[string]SerializedValue{} @@ -585,8 +592,9 @@ func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error { return nil } -func GetCtx[C any](ctx *Context, ext_type ExtType) (C, error) { +func GetCtx[C any, E any, T interface { *E; Extension}](ctx *Context) (C, error) { var zero_ctx C + ext_type := ExtType(SerializedTypeFor[E]()) ext_info, ok := ctx.Extensions[ext_type] if ok == false { return zero_ctx, fmt.Errorf("%+v is not an extension in ctx", ext_type) @@ -600,8 +608,9 @@ func GetCtx[C any](ctx *Context, ext_type ExtType) (C, error) { return ext_ctx, nil } -func GetExt[T Extension](node *Node, ext_type ExtType) (T, error) { +func GetExt[E any, T interface { *E; Extension}](node *Node) (T, error) { var zero T + ext_type := ExtType(SerializedTypeFor[E]()) ext, exists := node.Extensions[ext_type] if exists == false { return zero, fmt.Errorf("%+v does not have %+v extension - %+v", node.ID, ext_type, node.Extensions) @@ -621,7 +630,12 @@ func KeyID(pub ed25519.PublicKey) NodeID { } // Create a new node in memory and start it's event loop -func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_size uint32, policies []Policy, extensions ...Extension) (*Node, error) { +func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size uint32, policies []Policy, extensions ...Extension) (*Node, error) { + node_type, known_type := ctx.NodeTypes[type_name] + if known_type == false { + return nil, fmt.Errorf("%s is not a known node type", type_name) + } + var err error var public ed25519.PublicKey if key == nil { @@ -894,7 +908,7 @@ func LoadNode(ctx *Context, id NodeID) (*Node, error) { return nil, fmt.Errorf("0x%0x is not a known extension type", ext_type) } - ext_value, remaining, err := DeserializeValue(ctx, ext_info.Type, data) + ext_value, remaining, err := DeserializeValue(ctx, ext_info.Reflect, data) if err != nil { return nil, err } else if len(remaining) > 0 { diff --git a/node_test.go b/node_test.go index 3fa61c8..05e393f 100644 --- a/node_test.go +++ b/node_test.go @@ -10,16 +10,13 @@ import ( func TestNodeDB(t *testing.T) { ctx := logTestContext(t, []string{"signal", "serialize", "node", "db", "listener"}) - node_type := NewNodeType("test") - err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) - fatalErr(t, err) node_listener := NewListenerExt(10) - node, err := NewNode(ctx, nil, node_type, 10, nil, NewGroupExt(nil), NewLockableExt(nil), node_listener) + node, err := NewNode(ctx, nil, "Base", 10, nil, NewGroupExt(nil), NewLockableExt(nil), node_listener) fatalErr(t, err) _, err = WaitForSignal(node_listener.Chan, 10*time.Millisecond, func(sig *StatusSignal) bool { - gql_changes, has_gql := sig.Changes[GQLExtType] + gql_changes, has_gql := sig.Changes[ExtTypeFor[GQLExt]()] if has_gql == true { return slices.Contains(gql_changes, "state") && sig.Source == node.ID } @@ -43,9 +40,6 @@ func TestNodeDB(t *testing.T) { func TestNodeRead(t *testing.T) { ctx := logTestContext(t, []string{"test"}) - node_type := NewNodeType("TEST") - err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) - fatalErr(t, err) n1_pub, n1_key, err := ed25519.GenerateKey(rand.Reader) fatalErr(t, err) @@ -60,19 +54,19 @@ func TestNodeRead(t *testing.T) { n1_policy := NewPerNodePolicy(map[NodeID]Tree{ n2_id: { - SerializedType(ReadSignalType): nil, + SerializedType(SignalTypeFor[ReadSignal]()): nil, }, }) n2_listener := NewListenerExt(10) - n2, err := NewNode(ctx, n2_key, node_type, 10, nil, NewGroupExt(nil), n2_listener) + n2, err := NewNode(ctx, n2_key, "Base", 10, nil, NewGroupExt(nil), n2_listener) fatalErr(t, err) - n1, err := NewNode(ctx, n1_key, node_type, 10, []Policy{n1_policy}, NewGroupExt(nil)) + n1, err := NewNode(ctx, n1_key, "Base", 10, []Policy{n1_policy}, NewGroupExt(nil)) fatalErr(t, err) read_sig := NewReadSignal(map[ExtType][]string{ - GroupExtType: {"members"}, + ExtTypeFor[GroupExt](): {"members"}, }) msgs := Messages{} msgs = msgs.Add(ctx, n1.ID, n2, nil, read_sig) diff --git a/policy.go b/policy.go index 33c30f8..2ac65b3 100644 --- a/policy.go +++ b/policy.go @@ -134,6 +134,6 @@ type AllNodesPolicy struct { } var DefaultPolicy = NewAllNodesPolicy(Tree{ - ResponseType: nil, - StatusType: nil, + SerializedType(SignalTypeFor[ResponseSignal]()): nil, + SerializedType(SignalTypeFor[StatusSignal]()): nil, }) diff --git a/serialize.go b/serialize.go index 05c0624..1b56187 100644 --- a/serialize.go +++ b/serialize.go @@ -1,34 +1,18 @@ package graphvent import ( - "crypto/sha512" - "encoding" - "encoding/binary" - "encoding/gob" - "fmt" - "math" - "reflect" - "sort" - "bytes" + "bytes" + "crypto/sha512" + "encoding" + "encoding/binary" + "encoding/gob" + "fmt" + "math" + "reflect" + "slices" + "sort" ) -const ( - TagBase = "GraphventTag" - ExtTypeBase = "ExtType" - NodeTypeBase = "NodeType" - SignalTypeBase = "SignalType" - PolicyTypeBase = "PolicyType" - SerializedTypeBase = "SerializedType" - FieldNameBase = "FieldName" -) - -func Hash(base string, name string) SerializedType { - digest := append([]byte(base), 0x00) - digest = append(digest, []byte(name)...) - hash := sha512.Sum512(digest) - return SerializedType(binary.BigEndian.Uint64(hash[0:8])) -} - type SerializedType uint64 func (t SerializedType) String() string { @@ -180,119 +164,69 @@ type SerializeFn func(*Context, reflect.Value) (Chunks, error) type TypeDeserializeFn func(*Context, []SerializedType) (reflect.Type, []SerializedType, error) type DeserializeFn func(*Context, reflect.Type, []byte) (reflect.Value, []byte, error) -func NewExtType(name string) ExtType { - return ExtType(Hash(ExtTypeBase, name)) -} - -func NewNodeType(name string) NodeType { - return NodeType(Hash(NodeTypeBase, name)) -} - -func NewSignalType(name string) SignalType { - return SignalType(Hash(SignalTypeBase, name)) -} - -func NewPolicyType(name string) PolicyType { - return PolicyType(Hash(PolicyTypeBase, name)) -} - -func NewFieldTag(tag_string string) FieldTag { - return FieldTag(Hash(FieldNameBase, tag_string)) -} - -func NewSerializedType(name string) SerializedType { - return Hash(SerializedTypeBase, name) -} - -var ( - ListenerExtType = NewExtType("LISTENER") - LockableExtType = NewExtType("LOCKABLE") - GQLExtType = NewExtType("GQL") - GroupExtType = NewExtType("GROUP") - ACLExtType = NewExtType("ACL") - EventExtType = NewExtType("EVENT") - - GQLNodeType = NewNodeType("GQL") - BaseNodeType = NewNodeType("BASE") - GroupNodeType = NewNodeType("GROUP") - - StopSignalType = NewSignalType("STOP") - CreateSignalType = NewSignalType("CREATE") - StartSignalType = NewSignalType("START") - StatusSignalType = NewSignalType("STATUS") - LinkSignalType = NewSignalType("LINK") - LockSignalType = NewSignalType("LOCK") - TimeoutSignalType = NewSignalType("TIMEOUT") - ReadSignalType = NewSignalType("READ") - ACLTimeoutSignalType = NewSignalType("ACL_TIMEOUT") - ErrorSignalType = NewSignalType("ERROR") - SuccessSignalType = NewSignalType("SUCCESS") - ReadResultSignalType = NewSignalType("READ_RESULT") - RemoveMemberSignalType = NewSignalType("REMOVE_MEMBER") - AddMemberSignalType = NewSignalType("ADD_MEMBER") - ACLSignalType = NewSignalType("ACL") - AddSubGroupSignalType = NewSignalType("ADD_SUBGROUP") - RemoveSubGroupSignalType = NewSignalType("REMOVE_SUBGROUP") - StoppedSignalType = NewSignalType("STOPPED") - EventControlSignalType = NewSignalType("EVENT_CONTORL") - EventStateSignalType = NewSignalType("VEX_MATCH_STATUS") - - MemberOfPolicyType = NewPolicyType("MEMBER_OF") - OwnerOfPolicyType = NewPolicyType("OWNER_OF") - ParentOfPolicyType = NewPolicyType("PARENT_OF") - RequirementOfPolicyType = NewPolicyType("REQUIEMENT_OF") - PerNodePolicyType = NewPolicyType("PER_NODE") - AllNodesPolicyType = NewPolicyType("ALL_NODES") - ACLProxyPolicyType = NewPolicyType("ACL_PROXY") - - ErrorType = NewSerializedType("ERROR") - PointerType = NewSerializedType("POINTER") - SliceType = NewSerializedType("SLICE") - StructType = NewSerializedType("STRUCT") - IntType = NewSerializedType("INT") - UIntType = NewSerializedType("UINT") - BoolType = NewSerializedType("BOOL") - Float64Type = NewSerializedType("FLOAT64") - Float32Type = NewSerializedType("FLOAT32") - UInt8Type = NewSerializedType("UINT8") - UInt16Type = NewSerializedType("UINT16") - UInt32Type = NewSerializedType("UINT32") - UInt64Type = NewSerializedType("UINT64") - Int8Type = NewSerializedType("INT8") - Int16Type = NewSerializedType("INT16") - Int32Type = NewSerializedType("INT32") - Int64Type = NewSerializedType("INT64") - StringType = NewSerializedType("STRING") - ArrayType = NewSerializedType("ARRAY") - InterfaceType = NewSerializedType("INTERFACE") - MapType = NewSerializedType("MAP") - - EventStateType = NewSerializedType("EVENT_STATE") - WaitReasonType = NewSerializedType("WAIT_REASON") - EventCommandType = NewSerializedType("EVENT_COMMAND") - ReqStateType = NewSerializedType("REQ_STATE") - WaitInfoType = NewSerializedType("WAIT_INFO") - SignalDirectionType = NewSerializedType("SIGNAL_DIRECTION") - NodeStructType = NewSerializedType("NODE_STRUCT") - QueuedSignalType = NewSerializedType("QUEUED_SIGNAL") - NodeTypeSerialized = NewSerializedType("NODE_TYPE") - ChangesSerialized = NewSerializedType("CHANGES") - ExtTypeSerialized = NewSerializedType("EXT_TYPE") - PolicyTypeSerialized = NewSerializedType("POLICY_TYPE") - ExtSerialized = NewSerializedType("EXTENSION") - PolicySerialized = NewSerializedType("POLICY") - SignalSerialized = NewSerializedType("SIGNAL") - NodeIDType = NewSerializedType("NODE_ID") - UUIDType = NewSerializedType("UUID") - PendingACLType = NewSerializedType("PENDING_ACL") - PendingACLSignalType = NewSerializedType("PENDING_ACL_SIGNAL") - TimeType = NewSerializedType("TIME") - DurationType = NewSerializedType("DURATION") - ResponseType = NewSerializedType("RESPONSE") - StatusType = NewSerializedType("STATUS") - TreeType = NewSerializedType("TREE") - SerializedTypeSerialized = NewSerializedType("SERIALIZED_TYPE") -) + +func NodeTypeFor(name string, extensions []ExtType, mappings map[string]FieldIndex) NodeType { + digest := []byte("GRAPHVENT_NODE[" + name + "] - ") + for _, ext := range(extensions) { + digest = binary.BigEndian.AppendUint64(digest, uint64(ext)) + } + + digest = binary.BigEndian.AppendUint64(digest, 0) + + sorted_keys := make([]string, len(mappings)) + i := 0 + for key := range(mappings) { + sorted_keys[i] = key + i += 1 + } + slices.Sort(sorted_keys) + + + + for _, key := range(sorted_keys) { + digest = append(digest, []byte(key + ":")...) + digest = binary.BigEndian.AppendUint64(digest, uint64(mappings[key].Extension)) + digest = append(digest, []byte(mappings[key].Field + "|")...) + } + + hash := sha512.Sum512(digest) + return NodeType(binary.BigEndian.Uint64(hash[0:8])) +} + +func SerializedKindFor(kind reflect.Kind) SerializedType { + digest := []byte("GRAPHVENT_KIND - " + kind.String()) + hash := sha512.Sum512(digest) + return SerializedType(binary.BigEndian.Uint64(hash[0:8])) +} + +func SerializedTypeFor[T any]() SerializedType { + t := reflect.TypeFor[T]() + digest := []byte(t.PkgPath() + ":" + t.Name()) + hash := sha512.Sum512(digest) + return SerializedType(binary.BigEndian.Uint64(hash[0:8])) +} + +func ExtTypeFor[E any, T interface { *E; Extension}]() ExtType { + return ExtType(SerializedTypeFor[E]()) +} + +func SignalTypeFor[S Signal]() SignalType { + return SignalType(SerializedTypeFor[S]()) +} + +func PolicyTypeFor[P Policy]() PolicyType { + return PolicyType(SerializedTypeFor[P]()) +} + +func Hash(base, data string) SerializedType { + digest := []byte(base + ":" + data) + hash := sha512.Sum512(digest) + return SerializedType(binary.BigEndian.Uint64(hash[0:8])) +} + +func GetFieldTag(tag string) FieldTag { + return FieldTag(Hash("GRAPHVENT_FIELD_TAG", tag)) +} type FieldInfo struct { Index []int @@ -302,8 +236,8 @@ type FieldInfo struct { type StructInfo struct { Type reflect.Type - FieldOrder []SerializedType - FieldMap map[SerializedType]FieldInfo + FieldOrder []FieldTag + FieldMap map[FieldTag]FieldInfo PostDeserialize bool PostDeserializeIdx int } @@ -316,15 +250,15 @@ var deserializable_zero Deserializable = nil var DeserializableType = reflect.TypeOf(&deserializable_zero).Elem() func GetStructInfo(ctx *Context, struct_type reflect.Type) (StructInfo, error) { - field_order := []SerializedType{} - field_map := map[SerializedType]FieldInfo{} + field_order := []FieldTag{} + field_map := map[FieldTag]FieldInfo{} for _, field := range reflect.VisibleFields(struct_type) { gv_tag, tagged_gv := field.Tag.Lookup("gv") if tagged_gv == false { continue } else { - field_hash := Hash(FieldNameBase, gv_tag) - _, exists := field_map[field_hash] + field_tag := GetFieldTag(gv_tag) + _, exists := field_map[field_tag] if exists == true { return StructInfo{}, fmt.Errorf("gv tag %s is repeated", gv_tag) } else { @@ -332,12 +266,12 @@ func GetStructInfo(ctx *Context, struct_type reflect.Type) (StructInfo, error) { if err != nil { return StructInfo{}, err } - field_map[field_hash] = FieldInfo{ + field_map[field_tag] = FieldInfo{ field.Index, field_type_stack, field.Type, } - field_order = append(field_order, field_hash) + field_order = append(field_order, field_tag) } } } @@ -408,10 +342,10 @@ func DeserializeStruct(info StructInfo)func(*Context, reflect.Type, []byte)(refl for i := uint64(0); i < num_fields; i ++ { field_hash_bytes := data[:8] data = data[8:] - field_hash := SerializedType(binary.BigEndian.Uint64(field_hash_bytes)) - field_info, exists := info.FieldMap[field_hash] + field_tag := FieldTag(binary.BigEndian.Uint64(field_hash_bytes)) + field_info, exists := info.FieldMap[field_tag] if exists == false { - return reflect.Value{}, nil, fmt.Errorf("%+v is not a field in %+v", field_hash, info.Type) + return reflect.Value{}, nil, fmt.Errorf("%+v is not a field in %+v", field_tag, info.Type) } var field_value reflect.Value diff --git a/serialize_test.go b/serialize_test.go index 15796b4..e969088 100644 --- a/serialize_test.go +++ b/serialize_test.go @@ -1,10 +1,10 @@ package graphvent import ( - "testing" - "reflect" - "fmt" - "time" + "fmt" + "reflect" + "testing" + "time" ) func TestSerializeTest(t *testing.T) { @@ -17,11 +17,11 @@ func TestSerializeTest(t *testing.T) { } func TestSerializeBasic(t *testing.T) { - ctx := logTestContext(t, []string{"test", "serialize"}) + ctx := logTestContext(t, []string{"test", "serialize", "deserialize_types"}) testSerializeComparable[bool](t, ctx, true) type bool_wrapped bool - err := ctx.RegisterType(reflect.TypeOf(bool_wrapped(true)), NewSerializedType("BOOL_WRAPPED"), nil, nil, nil, DeserializeBool[bool_wrapped]) + err := RegisterType[bool_wrapped](ctx, nil, nil, nil, DeserializeBool[bool_wrapped]) fatalErr(t, err) testSerializeComparable[bool_wrapped](t, ctx, true) @@ -55,9 +55,9 @@ func TestSerializeBasic(t *testing.T) { }) testSerialize(t, ctx, Tree{ - NodeTypeSerialized: nil, - SerializedTypeSerialized: Tree{ - NodeTypeSerialized: Tree{}, + SerializedTypeFor[NodeType](): nil, + SerializedTypeFor[SerializedType](): { + SerializedTypeFor[NodeType](): Tree{}, }, }) @@ -83,11 +83,7 @@ func TestSerializeBasic(t *testing.T) { String string `gv:"string"` } - test_struct_type := reflect.TypeOf(test_struct{}) - test_struct_info, err := GetStructInfo(ctx, test_struct_type) - fatalErr(t, err) - - err = ctx.RegisterType(test_struct_type, NewSerializedType("TEST_STRUCT"), nil, SerializeStruct(test_struct_info), nil, DeserializeStruct(test_struct_info)) + err = RegisterStruct[test_struct](ctx) fatalErr(t, err) testSerialize(t, ctx, test_struct{ @@ -96,25 +92,24 @@ func TestSerializeBasic(t *testing.T) { }) testSerialize(t, ctx, Tree{ - MapType: nil, - StringType: nil, + SerializedKindFor(reflect.Map): nil, + SerializedKindFor(reflect.String): nil, }) testSerialize(t, ctx, Tree{ - TreeType: nil, + SerializedTypeFor[Tree](): nil, }) testSerialize(t, ctx, Tree{ - TreeType: { - ErrorType: Tree{}, - MapType: nil, + SerializedTypeFor[Tree](): { + SerializedTypeFor[error](): Tree{}, + SerializedKindFor(reflect.Map): nil, }, - StringType: nil, + SerializedKindFor(reflect.String): nil, }) type test_slice []string - test_slice_type := reflect.TypeOf(test_slice{}) - err = ctx.RegisterType(test_slice_type, NewSerializedType("TEST_SLICE"), SerializeTypeStub, SerializeSlice, DeserializeTypeStub[test_slice], DeserializeSlice) + err = RegisterType[test_slice](ctx, SerializeTypeStub, SerializeSlice, DeserializeTypeStub[test_slice], DeserializeSlice) fatalErr(t, err) testSerialize[[]string](t, ctx, []string{"test_1", "test_2", "test_3"}) @@ -133,12 +128,8 @@ func (s test) String() string { func TestSerializeStructTags(t *testing.T) { ctx := logTestContext(t, []string{"test"}) - test_type := NewSerializedType("TEST_STRUCT") - test_struct_type := reflect.TypeOf(test{}) - ctx.Log.Logf("test", "TEST_TYPE: %+v", test_type) - test_struct_info, err := GetStructInfo(ctx, test_struct_type) + err := RegisterStruct[test](ctx) fatalErr(t, err) - ctx.RegisterType(test_struct_type, test_type, nil, SerializeStruct(test_struct_info), nil, DeserializeStruct(test_struct_info)) test_int := 10 test_string := "test" diff --git a/signal.go b/signal.go index 4f8b4e6..ba9bd1d 100644 --- a/signal.go +++ b/signal.go @@ -149,7 +149,7 @@ type CreateSignal struct { func (signal CreateSignal) Permission() Tree { return Tree{ - SerializedType(CreateSignalType): nil, + SerializedType(SignalTypeFor[CreateSignal]()): nil, } } @@ -164,7 +164,7 @@ type StartSignal struct { } func (signal StartSignal) Permission() Tree { return Tree{ - SerializedType(StartSignalType): nil, + SerializedType(SignalTypeFor[StartSignal]()): nil, } } func NewStartSignal() *StartSignal { @@ -179,7 +179,7 @@ type StoppedSignal struct { } func (signal StoppedSignal) Permission() Tree { return Tree{ - ResponseType: nil, + SerializedType(SignalTypeFor[ResponseSignal]()): nil, } } func NewStoppedSignal(sig *StopSignal, source NodeID) *StoppedSignal { @@ -194,7 +194,7 @@ type StopSignal struct { } func (signal StopSignal) Permission() Tree { return Tree{ - SerializedType(StopSignalType): nil, + SerializedType(SignalTypeFor[StopSignal]()): nil, } } func NewStopSignal() *StopSignal { @@ -213,8 +213,8 @@ func (signal SuccessSignal) String() string { func (signal SuccessSignal) Permission() Tree { return Tree{ - ResponseType: { - SerializedType(SuccessSignalType): nil, + SerializedType(SignalTypeFor[ResponseSignal]()): { + SerializedType(SignalTypeFor[SuccessSignal]()): nil, }, } } @@ -233,8 +233,8 @@ func (signal ErrorSignal) String() string { } func (signal ErrorSignal) Permission() Tree { return Tree{ - ResponseType: { - SerializedType(ErrorSignalType): nil, + SerializedType(SignalTypeFor[ResponseSignal]()): { + SerializedType(SignalTypeFor[ErrorSignal]()): nil, }, } } @@ -250,7 +250,7 @@ type ACLTimeoutSignal struct { } func (signal ACLTimeoutSignal) Permission() Tree { return Tree{ - SerializedType(ACLTimeoutSignalType): nil, + SerializedType(SignalTypeFor[ACLTimeoutSignal]()): nil, } } func NewACLTimeoutSignal(req_id uuid.UUID) *ACLTimeoutSignal { @@ -267,7 +267,7 @@ type StatusSignal struct { } func (signal StatusSignal) Permission() Tree { return Tree{ - StatusType: nil, + SerializedType(SignalTypeFor[StatusSignal]()): nil, } } func (signal StatusSignal) String() string { @@ -294,7 +294,7 @@ const ( func (signal LinkSignal) Permission() Tree { return Tree{ - SerializedType(LinkSignalType): Tree{ + SerializedType(SignalTypeFor[LinkSignal]()): Tree{ Hash(LinkActionBase, signal.Action): nil, }, } @@ -322,7 +322,7 @@ const ( func (signal LockSignal) Permission() Tree { return Tree{ - SerializedType(LockSignalType): Tree{ + SerializedType(SignalTypeFor[LockSignal]()): Tree{ Hash(LockStateBase, signal.State): nil, }, } @@ -349,11 +349,11 @@ func (signal ReadSignal) Permission() Tree { for ext, fields := range(signal.Extensions) { field_tree := Tree{} for _, field := range(fields) { - field_tree[Hash(FieldNameBase, field)] = nil + field_tree[SerializedType(GetFieldTag(field))] = nil } ret[SerializedType(ext)] = field_tree } - return Tree{SerializedType(ReadSignalType): ret} + return Tree{SerializedType(SignalTypeFor[ReadSignal]()): ret} } func NewReadSignal(exts map[ExtType][]string) *ReadSignal { @@ -376,8 +376,8 @@ func (signal ReadResultSignal) String() string { func (signal ReadResultSignal) Permission() Tree { return Tree{ - ResponseType: { - SerializedType(ReadResultSignalType): nil, + SerializedType(SignalTypeFor[ResponseSignal]()): { + SerializedType(SignalTypeFor[ReadResultSignal]()): nil, }, } }