diff --git a/context.go b/context.go index 0519a77..a0d37b0 100644 --- a/context.go +++ b/context.go @@ -278,6 +278,11 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = gql_ctx.RegisterField("Listen", GQLExtType, "listen") + if err != nil { + return nil, err + } + schema, err := BuildSchema(gql_ctx) if err != nil { return nil, err diff --git a/gql.go b/gql.go index 0dc4c89..c8b2b9d 100644 --- a/gql.go +++ b/gql.go @@ -465,11 +465,11 @@ func NewGQLNodeType(node_type NodeType, interfaces []*graphql.Interface, init fu Name: string(node_type), Interfaces: interfaces, IsTypeOf: func(p graphql.IsTypeOfParams) bool { - node, ok := p.Value.(*Node) + node, ok := p.Value.(NodeResult) if ok == false { return false } - return node.Type == node_type + return node.Result.NodeType == node_type }, Fields: graphql.Fields{}, }) @@ -609,6 +609,7 @@ func NewGQLExtContext() *GQLExtContext { Subscription: subscription, NodeTypes: map[NodeType]*graphql.Object{}, Interfaces: []*Interface{}, + Fields: map[string]Field{}, } var err error diff --git a/gql_query.go b/gql_query.go index 184ca03..5a37e6b 100644 --- a/gql_query.go +++ b/gql_query.go @@ -1,22 +1,40 @@ package graphvent import ( "time" + "reflect" "github.com/graphql-go/graphql" "github.com/graphql-go/graphql/language/ast" ) -func GetFieldNames(p graphql.ResolveParams) []string { +func GetFieldNames(ctx *Context, selection_set *ast.SelectionSet) []string { names := []string{} + if selection_set == nil { + return names + } - for _, node := range(p.Info.FieldASTs) { - for _, sel := range(node.SelectionSet.Selections) { - names = append(names, sel.(*ast.Field).Name.Value) + for _, sel := range(selection_set.Selections) { + switch field := sel.(type) { + case *ast.Field: + names = append(names, field.Name.Value) + case *ast.InlineFragment: + names = append(names, GetFieldNames(ctx, field.SelectionSet)...) + default: + ctx.Log.Logf("gql", "Unknown selection type: %s", reflect.TypeOf(field)) } } return names } +func GetResolveFields(ctx *Context, p graphql.ResolveParams) []string { + names := []string{} + for _, field := range(p.Info.FieldASTs) { + names = append(names, GetFieldNames(ctx, field.SelectionSet)...) + } + + return names +} + var QueryNode = &graphql.Field{ Type: InterfaceNode.Interface, Args: graphql.FieldConfigArgument{ @@ -36,7 +54,7 @@ var QueryNode = &graphql.Field{ return nil, err } - fields := GetFieldNames(p) + fields := GetResolveFields(ctx.Context, p) ctx.Context.Log.Logf("gql", "RESOLVE_NODE(%s): %+v", id, fields) // Get a list of fields that will be written @@ -77,7 +95,7 @@ var QuerySelf = &graphql.Field{ return nil, err } - ctx.Context.Log.Logf("gql", "FIELDS: %+v", GetFieldNames(p)) + ctx.Context.Log.Logf("gql", "FIELDS: %+v", GetResolveFields(ctx.Context, p)) return nil, nil }, diff --git a/gql_test.go b/gql_test.go index 5887856..5c60870 100644 --- a/gql_test.go +++ b/gql_test.go @@ -23,7 +23,7 @@ func TestGQL(t *testing.T) { fatalErr(t, err) listener_ext := NewListenerExt(10) policy := NewAllNodesPolicy(Actions{MakeAction("+")}) - gql := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt(), NewACLExt(policy), gql_ext) + gql := NewNode(ctx, nil, GQLNodeType, 10, nil, NewLockableExt(), NewACLExt(policy), gql_ext, NewGroupExt(nil)) n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt(), NewACLExt(policy), listener_ext) ctx.Send(n1.ID, gql.ID, StateSignal{NewDirectSignal(GQLStateSignalType), "start_server"}) @@ -39,29 +39,43 @@ func TestGQL(t *testing.T) { port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port url := fmt.Sprintf("https://localhost:%d/gql", port) - ser, err := json.MarshalIndent(&GQLPayload{ + req_1 := GQLPayload{ Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash } }", Variables: map[string]interface{}{ "id": n1.ID.String(), }, - }, "", " ") - fatalErr(t, err) + } + + req_2 := GQLPayload{ + Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash, ... on GQL { Listen } } }", + Variables: map[string]interface{}{ + "id": gql.ID.String(), + }, + } - req_data := bytes.NewBuffer(ser) + SendGQL := func(payload GQLPayload) []byte { + ser, err := json.MarshalIndent(&payload, "", " ") + fatalErr(t, err) - req, err := http.NewRequest("GET", url, req_data) - req.SetBasicAuth(n1.ID.String(), "BAD_PASSWORD") - fatalErr(t, err) + req_data := bytes.NewBuffer(ser) + req, err := http.NewRequest("GET", url, req_data) + fatalErr(t, err) - resp, err := client.Do(req) - fatalErr(t, err) + req.SetBasicAuth(n1.ID.String(), "BAD_PASSWORD") + resp, err := client.Do(req) + fatalErr(t, err) - body, err := io.ReadAll(resp.Body) - fatalErr(t, err) + body, err := io.ReadAll(resp.Body) + fatalErr(t, err) - resp.Body.Close() + resp.Body.Close() + return body + } - ctx.Log.Logf("test", "TEST_RESP: %s", body) + resp_1 := SendGQL(req_1) + ctx.Log.Logf("test", "RESP_1: %s", resp_1) + resp_2 := SendGQL(req_2) + ctx.Log.Logf("test", "RESP_2: %s", resp_2) ctx.Send(n1.ID, gql.ID, StateSignal{NewDirectSignal(GQLStateSignalType), "stop_server"}) _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig StateSignal) bool {