diff --git a/context.go b/context.go index 73e9409..cf5573b 100644 --- a/context.go +++ b/context.go @@ -150,7 +150,7 @@ type Pair struct { func RegisterMap(ctx *Context, reflect_type reflect.Type, node_type string) error { ctx.Log.Logf("gql", "Registering map %s with node_type %s", reflect_type, node_type) - node_types := strings.Split(node_type, ":") + node_types := strings.SplitN(node_type, ":", 2) if len(node_types) != 2 { return fmt.Errorf("Invalid node tag for map type %s: \"%s\"", reflect_type, node_type) @@ -782,13 +782,31 @@ func (ctx *Context) Stop() { ctx.nodeMapLock.Unlock() } +func (ctx *Context) Load(id NodeID) (*Node, error) { + node, err := LoadNode(ctx, id) + if err != nil { + return nil, err + } + + ctx.AddNode(id, node) + started := make(chan error, 1) + go runNode(ctx, node, started) + + err = <- started + if err != nil { + return nil, err + } + + return node, nil +} + // Get a node from the context, or load from the database if not loaded func (ctx *Context) getNode(id NodeID) (*Node, error) { target, exists := ctx.Node(id) if exists == false { var err error - target, err = LoadNode(ctx, id) + target, err = ctx.Load(id) if err != nil { return nil, err } diff --git a/gql.go b/gql.go index 82f3f13..05797eb 100644 --- a/gql.go +++ b/gql.go @@ -336,10 +336,10 @@ func GQLWSDo(ctx * Context, p graphql.Params) chan *graphql.Result { if operation == ast.OperationTypeSubscription { return graphql.Subscribe(p) + } else { + res := graphql.Do(p) + return sendOneResultAndClose(res) } - - res := graphql.Do(p) - return sendOneResultAndClose(res) } func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) { diff --git a/gql_test.go b/gql_test.go index 9b3b131..05778c4 100644 --- a/gql_test.go +++ b/gql_test.go @@ -1,109 +1,34 @@ package graphvent -/*import ( +import ( "testing" - "time" "fmt" "encoding/json" "io" "net/http" "net" "crypto/tls" - "crypto/rand" - "crypto/ed25519" "bytes" "golang.org/x/net/websocket" "github.com/google/uuid" ) -func TestGQLAuth(t *testing.T) { - ctx := logTestContext(t, []string{"test"}) - - listener_1 := NewListenerExt(10) - node_1, err := NewNode(ctx, nil, "Base", 10, nil, listener_1) - fatalErr(t, err) - - listener_2 := NewListenerExt(10) - node_2, err := NewNode(ctx, nil, "Base", 10, nil, listener_2) - fatalErr(t, err) - - auth_header, err := AuthB64(node_1.Key, node_2.Key.Public().(ed25519.PublicKey)) - fatalErr(t, err) - - auth, err := ParseAuthB64(auth_header, node_2.Key) - fatalErr(t, err) - - err = ValidateAuthorization(Authorization{ - AuthInfo: auth.AuthInfo, - Key: auth.Key.Public().(ed25519.PublicKey), - }, time.Second) - fatalErr(t, err) - - ctx.Log.Logf("test", "AUTH: %+v", auth) -} - func TestGQLServer(t *testing.T) { - ctx := logTestContext(t, []string{"test", "gqlws", "gql"}) - - pub, gql_key, err := ed25519.GenerateKey(rand.Reader) - fatalErr(t, err) - gql_id := KeyID(pub) - - group_policy_1 := NewAllNodesPolicy(Tree{ - SerializedType(SignalTypeFor[ReadSignal]()): Tree{ - SerializedType(ExtTypeFor[GroupExt]()): Tree{ - SerializedType(GetFieldTag("members")): Tree{}, - }, - }, - SerializedType(SignalTypeFor[ReadResultSignal]()): nil, - SerializedType(SignalTypeFor[ErrorSignal]()): nil, - }) - - group_policy_2 := NewMemberOfPolicy(map[NodeID]map[string]Tree{ - gql_id: { - "test_group": { - SerializedType(SignalTypeFor[LinkSignal]()): nil, - SerializedType(SignalTypeFor[LockSignal]()): nil, - SerializedType(SignalTypeFor[StatusSignal]()): nil, - SerializedType(SignalTypeFor[ReadSignal]()): nil, - }, - }, - }) - - user_policy_1 := NewAllNodesPolicy(Tree{ - SerializedType(SignalTypeFor[ReadResultSignal]()): nil, - SerializedType(SignalTypeFor[ErrorSignal]()): nil, - }) - - user_policy_2 := NewMemberOfPolicy(map[NodeID]map[string]Tree{ - gql_id: { - "test_group": { - SerializedType(SignalTypeFor[LinkSignal]()): nil, - SerializedType(SignalTypeFor[ReadSignal]()): nil, - SerializedType(SignalTypeFor[LockSignal]()): nil, - }, - }, - }) + ctx := logTestContext(t, []string{"test", "gqlws", "gql", "gql_subscribe"}) gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) fatalErr(t, err) listener_ext := NewListenerExt(10) - n1, err := NewNode(ctx, nil, "Base", 10, []Policy{user_policy_2, user_policy_1}, NewLockableExt(nil)) + n1, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt(nil)) fatalErr(t, err) - gql, err := NewNode(ctx, gql_key, "Base", 10, []Policy{group_policy_2, group_policy_1}, - NewLockableExt([]NodeID{n1.ID}), gql_ext, NewGroupExt(map[string][]NodeID{"test_group": {n1.ID, gql_id}}), listener_ext) + gql, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt([]NodeID{n1.ID}), gql_ext, listener_ext) fatalErr(t, err) ctx.Log.Logf("test", "GQL: %s", gql.ID) ctx.Log.Logf("test", "NODE: %s", n1.ID) - _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StatusSignal) bool { - return sig.Source == gql_id - }) - fatalErr(t, err) - skipVerifyTransport := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } @@ -113,22 +38,16 @@ func TestGQLServer(t *testing.T) { ws_url := fmt.Sprintf("ws://127.0.0.1:%d/gqlws", port) req_1 := GQLPayload{ - Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash } }", + Query: "query Node($id:graphvent_NodeID) { Node(id:$id) { ID, Type } }", Variables: map[string]interface{}{ "id": n1.ID.String(), }, } req_2 := GQLPayload{ - Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash, ... on GQLServer { SubGroups { Name, Members { ID } } , Listen, Requirements { ID, TypeHash Owner { ID } } } } }", - Variables: map[string]interface{}{ - "id": gql.ID.String(), - }, + Query: "query Self { Self { ID, Type } }", } - auth_header, err := AuthB64(n1.Key, gql.Key.Public().(ed25519.PublicKey)) - fatalErr(t, err) - SendGQL := func(payload GQLPayload) []byte { ser, err := json.MarshalIndent(&payload, "", " ") fatalErr(t, err) @@ -137,7 +56,6 @@ func TestGQLServer(t *testing.T) { req, err := http.NewRequest("GET", url, req_data) fatalErr(t, err) - req.Header.Add("Authorization", auth_header) resp, err := client.Do(req) fatalErr(t, err) @@ -154,7 +72,7 @@ func TestGQLServer(t *testing.T) { ctx.Log.Logf("test", "RESP_2: %s", resp_2) sub_1 := GQLPayload{ - Query: "subscription { Self { ID, TypeHash, ... on Lockable { Requirements { ID }}}}", + Query: "subscription Self { Self { ID, Type } }", } SubGQL := func(payload GQLPayload) { @@ -174,11 +92,9 @@ func TestGQLServer(t *testing.T) { init := struct{ ID uuid.UUID `json:"id"` Type string `json:"type"` - Payload payload_struct `json:"payload"` }{ uuid.New(), "connection_init", - payload_struct{ auth_header }, } ser, err := json.Marshal(&init) @@ -213,11 +129,12 @@ func TestGQLServer(t *testing.T) { fatalErr(t, err) ctx.Log.Logf("test", "SUB: %s", resp[:n]) - msgs := Messages{} - test_changes := Changes{} - AddChange[GQLExt](test_changes, "state") - msgs = msgs.Add(ctx, gql.ID, gql, nil, NewStatusSignal(gql.ID, test_changes)) - err = ctx.Send(msgs) + err = ctx.Send(gql, []SendMsg{{ + Dest: gql.ID, + Signal: NewStatusSignal(gql.ID, map[ExtType]Changes{ + ExtTypeFor[GQLExt](): {"state"}, + }), + }}) fatalErr(t, err) n, err = ws.Read(resp) @@ -228,58 +145,25 @@ func TestGQLServer(t *testing.T) { } SubGQL(sub_1) - - msgs := Messages{} - msgs = msgs.Add(ctx, gql.ID, gql, nil, NewStopSignal()) - err = ctx.Send(msgs) - fatalErr(t, err) - _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StoppedSignal) bool { - return sig.Source == gql_id - }) - fatalErr(t, err) } func TestGQLDB(t *testing.T) { ctx := logTestContext(t, []string{"test", "db", "node"}) - u1, err := NewNode(ctx, nil, "Base", 10, nil) - fatalErr(t, err) - - ctx.Log.Logf("test", "U1_ID: %s", u1.ID) - gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) fatalErr(t, err) listener_ext := NewListenerExt(10) - gql, err := NewNode(ctx, nil, "Base", 10, nil, - gql_ext, - listener_ext, - NewGroupExt(nil)) - fatalErr(t, err) - ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) - msgs := Messages{} - msgs = msgs.Add(ctx, gql.ID, gql, nil, NewStopSignal()) - err = ctx.Send(msgs) + gql, err := NewNode(ctx, nil, "Base", 10, gql_ext, listener_ext) fatalErr(t, err) - _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StoppedSignal) bool { - return sig.Source == gql.ID - }) + ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) + + err = ctx.Unload(gql.ID) fatalErr(t, err) - // Clear all loaded nodes from the context so it loads them from the database - ctx.nodeMap = map[NodeID]*Node{} - gql_loaded, err := LoadNode(ctx, gql.ID) + gql_loaded, err := ctx.Load(gql.ID) fatalErr(t, err) listener_ext, err = GetExt[ListenerExt](gql_loaded) fatalErr(t, err) - msgs = Messages{} - msgs = msgs.Add(ctx, gql_loaded.ID, gql_loaded, nil, NewStopSignal()) - err = ctx.Send(msgs) - fatalErr(t, err) - _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StoppedSignal) bool { - return sig.Source == gql_loaded.ID - }) - fatalErr(t, err) } -*/ diff --git a/listener.go b/listener.go index d2b7633..fbdc4a5 100644 --- a/listener.go +++ b/listener.go @@ -10,12 +10,35 @@ type ListenerExt struct { Chan chan Signal } +type LoadedSignal struct { + SignalHeader +} + +func NewLoadedSignal() *LoadedSignal { + return &LoadedSignal{ + SignalHeader: NewSignalHeader(), + } +} + +type UnloadedSignal struct { + SignalHeader +} + +func NewUnloadedSignal() *UnloadedSignal { + return &UnloadedSignal{ + SignalHeader: NewSignalHeader(), + } +} + func (ext *ListenerExt) Load(ctx *Context, node *Node) error { ext.Chan = make(chan Signal, ext.Buffer) + ext.Chan <- NewLoadedSignal() return nil } func (ext *ListenerExt) Unload(ctx *Context, node *Node) { + ext.Chan <- NewUnloadedSignal() + close(ext.Chan) } // Create a new listener extension with a given buffer size diff --git a/node.go b/node.go index c55531d..99da216 100644 --- a/node.go +++ b/node.go @@ -4,7 +4,6 @@ import ( "crypto/ed25519" "crypto/rand" "crypto/sha512" - "encoding/binary" "fmt" "reflect" "sync/atomic" @@ -193,12 +192,11 @@ func SoonestSignal(signals []QueuedSignal) (*QueuedSignal, <-chan time.Time) { } } -func runNode(ctx *Context, node *Node) { +func runNode(ctx *Context, node *Node, started chan error) { ctx.Log.Logf("node", "RUN_START: %s", node.ID) - err := nodeLoop(ctx, node) + err := nodeLoop(ctx, node, started) if err != nil { ctx.Log.Logf("node", "%s runNode err %s", node.ID, err) - panic(err) } ctx.Log.Logf("node", "RUN_STOP: %s", node.ID) } @@ -236,12 +234,25 @@ func (node *Node) ReadFields(ctx *Context, reqs map[ExtType][]string)map[ExtType } // Main Loop for nodes -func nodeLoop(ctx *Context, node *Node) error { - started := node.Active.CompareAndSwap(false, true) - if started == false { +func nodeLoop(ctx *Context, node *Node, started chan error) error { + is_started := node.Active.CompareAndSwap(false, true) + if is_started == false { return fmt.Errorf("%s is already started, will not start again", node.ID) + } else { + ctx.Log.Logf("node", "Set %s active", node.ID) + } + + for _, extension := range(node.Extensions) { + err := extension.Load(ctx, node) + if err != nil { + node.Active.Store(false) + ctx.Log.Logf("node", "Failed to load extension %s on node %s", reflect.TypeOf(extension), node.ID) + return err + } } + started <- nil + run := true for run == true { var signal Signal @@ -422,9 +433,13 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size ext_map := map[ExtType]Extension{} for _, ext := range(extensions) { + if ext == nil { + return nil, fmt.Errorf("Cannot create node with nil extension") + } + ext_type, exists := ctx.ExtensionTypes[reflect.TypeOf(ext).Elem()] if exists == false { - return nil, fmt.Errorf(fmt.Sprintf("%+v is not a known Extension", reflect.TypeOf(ext))) + return nil, fmt.Errorf("%+v is not a known Extension", reflect.TypeOf(ext)) } _, exists = ext_map[ext_type.ExtType] if exists == true { @@ -456,25 +471,14 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size return nil, err } - // Load each extension before starting the main loop - for _, extension := range(node.Extensions) { - err := extension.Load(ctx, node) - if err != nil { - return nil, err - } - } - ctx.AddNode(id, node) - go runNode(ctx, node) + started := make(chan error, 1) + go runNode(ctx, node, started) - return node, nil -} + err = <- started + if err != nil { + return nil, err + } -var extension_suffix = []byte{0xEE, 0xFF, 0xEE, 0xFF} -var signal_queue_suffix = []byte{0xAB, 0xBA, 0xAB, 0xBA} -func ExtTypeSuffix(ext_type ExtType) []byte { - ret := make([]byte, 12) - copy(ret[0:4], extension_suffix) - binary.BigEndian.PutUint64(ret[4:], uint64(ext_type)) - return ret + return node, nil }