From 8f9a759b260e3ea3152eaee67142c0cd2cfab4cb Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sun, 17 Mar 2024 14:25:34 -0600 Subject: [PATCH] Added GQL enum --- context.go | 59 +++++++++++++------- gql.go | 2 +- gql_node.go | 57 +++++++++++--------- gql_test.go | 153 +++++++++++++++++++++++++++++++++------------------- lockable.go | 2 +- 5 files changed, 172 insertions(+), 101 deletions(-) diff --git a/context.go b/context.go index cf5573b..19211f9 100644 --- a/context.go +++ b/context.go @@ -149,7 +149,6 @@ type Pair struct { } func RegisterMap(ctx *Context, reflect_type reflect.Type, node_type string) error { - ctx.Log.Logf("gql", "Registering map %s with node_type %s", reflect_type, node_type) node_types := strings.SplitN(node_type, ":", 2) if len(node_types) != 2 { @@ -173,7 +172,6 @@ func RegisterMap(ctx *Context, reflect_type reflect.Type, node_type string) erro gql_name := strings.ReplaceAll(reflect_type.String(), ".", "_") gql_name = strings.ReplaceAll(gql_name, "[", "_") gql_name = strings.ReplaceAll(gql_name, "]", "_") - ctx.Log.Logf("gql", "Registering %s with gql name %s", reflect_type, gql_name) gql_pair := graphql.NewObject(graphql.ObjectConfig{ Name: gql_name, @@ -211,7 +209,6 @@ func RegisterMap(ctx *Context, reflect_type reflect.Type, node_type string) erro }, }) - ctx.Log.Logf("gql", "Registering new map with pair type %+v", gql_pair) gql_map := graphql.NewList(gql_pair) serialized_type := SerializeType(reflect_type) @@ -248,14 +245,13 @@ func BuildSchema(ctx *Context, query, mutation *graphql.Object) (graphql.Schema, ctx.Log.Logf("gql", "Building Schema") for _, info := range(ctx.TypeMap) { - ctx.Log.Logf("gql", "Adding type %+v", info.Type) - types = append(types, info.Type) + if info.Type != nil { + types = append(types, info.Type) + } } for _, info := range(ctx.Nodes) { - ctx.Log.Logf("gql", "Adding node type object %+v", info.Type) types = append(types, info.Type) - ctx.Log.Logf("gql", "Adding node type interface %+v", info.Interface) types = append(types, info.Interface) } @@ -329,8 +325,6 @@ func RegisterExtension[E any, T interface { *E; Extension}](ctx *Context, data i } } - ctx.Log.Logf("serialize_types", "Registered ExtType: %+v - %+v", reflect_type, ext_type) - ctx.Extensions[ext_type] = &ExtensionInfo{ ExtType: ext_type, Data: data, @@ -342,7 +336,6 @@ func RegisterExtension[E any, T interface { *E; Extension}](ctx *Context, data i } func RegisterNodeType(ctx *Context, name string, extensions []ExtType) error { - ctx.Log.Logf("gql", "Registering NodeType %s with extensions %+v", name, extensions) node_type := NodeTypeFor(extensions) _, exists := ctx.Nodes[node_type] if exists == true { @@ -373,7 +366,7 @@ func RegisterNodeType(ctx *Context, name string, extensions []ExtType) error { fields[field_name] = extension } } - + gql_interface := graphql.NewInterface(graphql.InterfaceConfig{ Name: name, Fields: graphql.Fields{ @@ -522,7 +515,6 @@ func RegisterObject[T any](ctx *Context) error { } gql_name := strings.ReplaceAll(reflect_type.String(), ".", "_") - ctx.Log.Logf("gql", "Registering %s with gql name %s", reflect_type, gql_name) gql := graphql.NewObject(graphql.ObjectConfig{ Name: gql_name, IsTypeOf: func(p graphql.IsTypeOfParams) bool { @@ -549,14 +541,12 @@ func RegisterObject[T any](ctx *Context) error { NodeTag: node_tag, Tag: gv_tag, } - gql_type, err := ctx.GQLType(field.Type, node_tag) if err != nil { return err } gql_resolve := ctx.GQLResolve(field.Type, node_tag) - gql.AddFieldConfig(gv_tag, &graphql.Field{ Type: gql_type, Resolve: func(p graphql.ResolveParams) (interface{}, error) { @@ -564,7 +554,7 @@ func RegisterObject[T any](ctx *Context) error { if ok == false { return nil, fmt.Errorf("%s is not %s", reflect.TypeOf(p.Source), reflect_type) } - + value, err := reflect.ValueOf(val).FieldByIndexErr(field.Index) if err != nil { return nil, err @@ -709,6 +699,40 @@ func astInt[T constraints.Integer](value ast.Value) interface{} { } } +func RegisterEnum[E comparable](ctx *Context, str_map map[E]string) error { + reflect_type := reflect.TypeFor[E]() + serialized_type := SerializedTypeFor[E]() + + _, exists := ctx.TypeTypes[reflect_type] + if exists { + return fmt.Errorf("%+v already registered in TypeMap", reflect_type) + } + + value_config := graphql.EnumValueConfigMap{} + + for value, value_name := range(str_map) { + value_config[value_name] = &graphql.EnumValueConfig{ + Value: value, + } + } + + gql_name := strings.ReplaceAll(reflect_type.String(), ".", "_") + gql := graphql.NewEnum(graphql.EnumConfig{ + Name: gql_name, + Values: value_config, + }) + + ctx.TypeMap[serialized_type] = &TypeInfo{ + Serialized: serialized_type, + Reflect: reflect_type, + Type: gql, + Resolve: nil, + } + ctx.TypeTypes[reflect_type] = ctx.TypeMap[serialized_type] + + return nil +} + func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, from_json func(interface{})interface{}, from_ast func(ast.Value)interface{}, resolve func(interface{},graphql.ResolveParams)(interface{},error)) error { reflect_type := reflect.TypeFor[S]() serialized_type := SerializedTypeFor[S]() @@ -719,7 +743,6 @@ func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, f } gql_name := strings.ReplaceAll(reflect_type.String(), ".", "_") - ctx.Log.Logf("gql", "Registering %s with gql name %s", reflect_type, gql_name) gql := graphql.NewScalar(graphql.ScalarConfig{ Name: gql_name, Serialize: to_json, @@ -966,7 +989,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = RegisterScalar[ReqState](ctx, identity, coerce[ReqState], astInt[ReqState], nil) + err = RegisterEnum[ReqState](ctx, ReqStateStrings) if err != nil { return nil, err } @@ -1040,7 +1063,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { Name: "Query", Fields: graphql.Fields{ "Self": &graphql.Field{ - Type: ctx.NodeTypes["Lockable"].Interface, + Type: ctx.NodeTypes["Base"].Interface, Resolve: func(p graphql.ResolveParams) (interface{}, error) { ctx, err := PrepResolve(p) if err != nil { diff --git a/gql.go b/gql.go index 05797eb..0a51ca0 100644 --- a/gql.go +++ b/gql.go @@ -544,7 +544,7 @@ type GQLExt struct { State string `gv:"state"` TLSKey []byte `gv:"tls_key"` TLSCert []byte `gv:"tls_cert"` - Listen string `gv:"listen"` + Listen string `gv:"listen" gql:"GQLListen"` } func (ext *GQLExt) Load(ctx *Context, node *Node) error { diff --git a/gql_node.go b/gql_node.go index 9ea187e..3ef0991 100644 --- a/gql_node.go +++ b/gql_node.go @@ -25,17 +25,36 @@ func ResolveNodeType(p graphql.ResolveParams) (interface{}, error) { return uint64(node.NodeType), nil } -func GetFieldNames(ctx *Context, selection_set *ast.SelectionSet) []string { - names := []string{} +type FieldIndex struct { + Extension ExtType + Tag string +} + +func GetFields(ctx *Context, node_type string, selection_set *ast.SelectionSet) []FieldIndex { + names := []FieldIndex{} if selection_set == nil { return names } + node_info, mapped := ctx.NodeTypes[node_type] + if mapped == false { + return nil + } + for _, sel := range(selection_set.Selections) { switch field := sel.(type) { case *ast.Field: - names = append(names, field.Name.Value) + if field.Name.Value == "ID" || field.Name.Value == "Type" { + continue + } + + extension, mapped := node_info.Fields[field.Name.Value] + if mapped == false { + continue + } + names = append(names, FieldIndex{extension, field.Name.Value}) case *ast.InlineFragment: + names = append(names, GetFields(ctx, field.TypeCondition.Name.Value, field.SelectionSet)...) default: ctx.Log.Logf("gql", "Unknown selection type: %s", reflect.TypeOf(field)) } @@ -46,47 +65,33 @@ func GetFieldNames(ctx *Context, selection_set *ast.SelectionSet) []string { // Returns the fields that need to be resolved func GetResolveFields(id NodeID, ctx *ResolveContext, p graphql.ResolveParams) (map[ExtType][]string, error) { - node_info, mapped := ctx.Context.NodeTypes[p.Info.ReturnType.Name()] - if mapped == false { - return nil, fmt.Errorf("No NodeType %s", p.Info.ReturnType.Name()) - } - - fields := map[ExtType][]string{} - names := []string{} + m := map[ExtType][]string{} + fields := []FieldIndex{} for _, field := range(p.Info.FieldASTs) { - names = append(names, GetFieldNames(ctx.Context, field.SelectionSet)...) + fields = append(fields, GetFields(ctx.Context, p.Info.ReturnType.Name(), field.SelectionSet)...) } cache, node_cached := ctx.NodeCache[id] - for _, name := range(names) { - if name == "ID" || name == "Type" { - continue - } - - ext_type, field_mapped := node_info.Fields[name] - if field_mapped == false { - return nil, fmt.Errorf("NodeType %s does not have field %s", p.Info.ReturnType.Name(), name) - } - - ext_fields, exists := fields[ext_type] + for _, field := range(fields) { + ext_fields, exists := m[field.Extension] if exists == false { ext_fields = []string{} } if node_cached { - ext_cache, ext_cached := cache.Data[ext_type] + ext_cache, ext_cached := cache.Data[field.Extension] if ext_cached { - _, field_cached := ext_cache[name] + _, field_cached := ext_cache[field.Tag] if field_cached { continue } } } - fields[ext_type] = append(ext_fields, name) + m[field.Extension] = append(ext_fields, field.Tag) } - return fields, nil + return m, nil } func ResolveNode(id NodeID, p graphql.ResolveParams) (NodeResult, error) { diff --git a/gql_test.go b/gql_test.go index 05778c4..de4b31f 100644 --- a/gql_test.go +++ b/gql_test.go @@ -1,26 +1,31 @@ package graphvent import ( - "testing" - "fmt" - "encoding/json" - "io" - "net/http" - "net" - "crypto/tls" - "bytes" - "golang.org/x/net/websocket" - "github.com/google/uuid" + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "reflect" + "testing" + "time" + + "github.com/google/uuid" + "golang.org/x/net/websocket" ) -func TestGQLServer(t *testing.T) { - ctx := logTestContext(t, []string{"test", "gqlws", "gql", "gql_subscribe"}) - gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) +func TestGQLSubscribe(t *testing.T) { + ctx := logTestContext(t, []string{"test"}) + + n1, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt(nil)) fatalErr(t, err) listener_ext := NewListenerExt(10) - n1, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt(nil)) + + gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) fatalErr(t, err) gql, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt([]NodeID{n1.ID}), gql_ext, listener_ext) @@ -29,52 +34,14 @@ func TestGQLServer(t *testing.T) { ctx.Log.Logf("test", "GQL: %s", gql.ID) ctx.Log.Logf("test", "NODE: %s", n1.ID) - skipVerifyTransport := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + sub_1 := GQLPayload{ + Query: "subscription Self { Self { ID, Type } }", } - client := &http.Client{Transport: skipVerifyTransport} + port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port url := fmt.Sprintf("http://localhost:%d/gql", port) ws_url := fmt.Sprintf("ws://127.0.0.1:%d/gqlws", port) - req_1 := GQLPayload{ - Query: "query Node($id:graphvent_NodeID) { Node(id:$id) { ID, Type } }", - Variables: map[string]interface{}{ - "id": n1.ID.String(), - }, - } - - req_2 := GQLPayload{ - Query: "query Self { Self { ID, Type } }", - } - - SendGQL := func(payload GQLPayload) []byte { - ser, err := json.MarshalIndent(&payload, "", " ") - fatalErr(t, err) - - req_data := bytes.NewBuffer(ser) - req, err := http.NewRequest("GET", url, req_data) - fatalErr(t, err) - - resp, err := client.Do(req) - fatalErr(t, err) - - body, err := io.ReadAll(resp.Body) - fatalErr(t, err) - - resp.Body.Close() - return body - } - - resp_1 := SendGQL(req_1) - ctx.Log.Logf("test", "RESP_1: %s", resp_1) - resp_2 := SendGQL(req_2) - ctx.Log.Logf("test", "RESP_2: %s", resp_2) - - sub_1 := GQLPayload{ - Query: "subscription Self { Self { ID, Type } }", - } - SubGQL := func(payload GQLPayload) { config, err := websocket.NewConfig(ws_url, url) fatalErr(t, err) @@ -147,6 +114,82 @@ func TestGQLServer(t *testing.T) { SubGQL(sub_1) } +func TestGQLQuery(t *testing.T) { + ctx := logTestContext(t, []string{"test", "lockable"}) + + n1_listener := NewListenerExt(10) + n1, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt(nil), n1_listener) + fatalErr(t, err) + + gql_listener := NewListenerExt(10) + gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) + fatalErr(t, err) + + gql, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt([]NodeID{n1.ID}), gql_ext, gql_listener) + fatalErr(t, err) + + ctx.Log.Logf("test", "GQL: %s", gql.ID) + ctx.Log.Logf("test", "NODE: %s", n1.ID) + + skipVerifyTransport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + client := &http.Client{Transport: skipVerifyTransport} + port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port + url := fmt.Sprintf("http://localhost:%d/gql", port) + + req_1 := GQLPayload{ + Query: "query Node($id:graphvent_NodeID) { Node(id:$id) { ID, Type, ... on Lockable { lockable_state } } }", + Variables: map[string]interface{}{ + "id": n1.ID.String(), + }, + } + + req_2 := GQLPayload{ + Query: "query Self { Self { ID, Type, ... on Lockable { lockable_state, requirements { Key { ID ... on Lockable { lockable_state } } } } } }", + } + + SendGQL := func(payload GQLPayload) []byte { + ser, err := json.MarshalIndent(&payload, "", " ") + fatalErr(t, err) + + req_data := bytes.NewBuffer(ser) + req, err := http.NewRequest("GET", url, req_data) + fatalErr(t, err) + + resp, err := client.Do(req) + fatalErr(t, err) + + body, err := io.ReadAll(resp.Body) + fatalErr(t, err) + + resp.Body.Close() + return body + } + + resp_1 := SendGQL(req_1) + ctx.Log.Logf("test", "RESP_1: %s", resp_1) + resp_2 := SendGQL(req_2) + ctx.Log.Logf("test", "RESP_2: %s", resp_2) + + lock_id, err := LockLockable(ctx, n1) + fatalErr(t, err) + + response, _, err := WaitForResponse(n1_listener.Chan, 100*time.Millisecond, lock_id) + fatalErr(t, err) + switch response := response.(type) { + case *SuccessSignal: + default: + t.Fatalf("Wrong response: %s", reflect.TypeOf(response)) + } + + resp_3 := SendGQL(req_1) + ctx.Log.Logf("test", "RESP_3: %s", resp_3) + + resp_4 := SendGQL(req_2) + ctx.Log.Logf("test", "RESP_4: %s", resp_4) +} + func TestGQLDB(t *testing.T) { ctx := logTestContext(t, []string{"test", "db", "node"}) diff --git a/lockable.go b/lockable.go index b7e8a60..05ab54c 100644 --- a/lockable.go +++ b/lockable.go @@ -23,7 +23,7 @@ var ReqStateStrings = map[ReqState]string { } type LockableExt struct{ - State ReqState `gv:"state"` + State ReqState `gv:"lockable_state"` ReqID *uuid.UUID `gv:"req_id"` Owner *NodeID `gv:"owner" node:"Base"` PendingOwner *NodeID `gv:"pending_owner" node:"Base"`