diff --git a/gql_query.go b/gql_query.go index 28544c3..20070e3 100644 --- a/gql_query.go +++ b/gql_query.go @@ -68,7 +68,7 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ responses := []NodeResult{} for sig_id, response_chan := range(resp_channels) { // Wait for the response, returning an error on timeout - response, err := WaitForSignal(ctx.Context, response_chan, time.Millisecond*100, ReadResultSignalType, func(sig *ReadResultSignal)bool{ + response, err := WaitForSignal(response_chan, time.Millisecond*100, ReadResultSignalType, func(sig *ReadResultSignal)bool{ return sig.ReqID() == sig_id }) if err != nil { diff --git a/gql_test.go b/gql_test.go index 9f8c675..accae7e 100644 --- a/gql_test.go +++ b/gql_test.go @@ -85,7 +85,7 @@ func TestGQLServer(t *testing.T) { err = ctx.Send(msgs) fatalErr(t, err) - _, err = WaitForSignal(ctx, listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { + _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "server_started" }) fatalErr(t, err) @@ -141,7 +141,7 @@ func TestGQLServer(t *testing.T) { msgs = msgs.Add(gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "stop_server"}, gql.ID) err = ctx.Send(msgs) fatalErr(t, err) - _, err = WaitForSignal(ctx, listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { + _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "server_stopped" }) fatalErr(t, err) @@ -170,7 +170,7 @@ func TestGQLDB(t *testing.T) { msgs = msgs.Add(gql.ID, gql.Key, &StopSignal, gql.ID) err = ctx.Send(msgs) fatalErr(t, err) - _, err = WaitForSignal(ctx, listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { + _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "stopped" && sig.NodeID == gql.ID }) fatalErr(t, err) @@ -192,7 +192,7 @@ func TestGQLDB(t *testing.T) { msgs = msgs.Add(gql_loaded.ID, gql_loaded.Key, &StopSignal, gql_loaded.ID) err = ctx.Send(msgs) fatalErr(t, err) - _, err = WaitForSignal(ctx, listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { + _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID }) fatalErr(t, err) diff --git a/lockable_test.go b/lockable_test.go index a7f1f75..d68c4f6 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -48,12 +48,12 @@ func TestLink(t *testing.T) { err = ctx.Send(msgs) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, "TEST", func(sig *BaseSignal) bool { + _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, "TEST", func(sig *BaseSignal) bool { return sig.ID() == s.ID() }) fatalErr(t, err) - _, err = WaitForSignal(ctx, l2_listener.Chan, time.Millisecond*10, "TEST", func(sig *BaseSignal) bool { + _, err = WaitForSignal(l2_listener.Chan, time.Millisecond*10, "TEST", func(sig *BaseSignal) bool { return sig.ID() == s.ID() }) fatalErr(t, err) @@ -103,13 +103,13 @@ func TestLink10K(t *testing.T) { err = LockLockable(ctx, node) fatalErr(t, err) - _, err = WaitForSignal(ctx, listener.Chan, time.Millisecond*1000, LockSignalType, func(sig *StringSignal) bool { + _, err = WaitForSignal(listener.Chan, time.Millisecond*1000, LockSignalType, func(sig *StringSignal) bool { return sig.Str == "locked" }) fatalErr(t, err) for _, _ = range(reqs) { - _, err := WaitForSignal(ctx, listener.Chan, time.Millisecond*100, LockSignalType, func(sig *StringSignal) bool { + _, err := WaitForSignal(listener.Chan, time.Millisecond*100, LockSignalType, func(sig *StringSignal) bool { return sig.Str == "locked" }) fatalErr(t, err) @@ -147,15 +147,15 @@ func TestLock(t *testing.T) { err := LockLockable(ctx, l1) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) + _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) + _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) + _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) + _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) + _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) fatalErr(t, err) err = UnlockLockable(ctx, l1) diff --git a/node_test.go b/node_test.go index a9179e7..d6458e6 100644 --- a/node_test.go +++ b/node_test.go @@ -56,7 +56,7 @@ func TestNodeRead(t *testing.T) { err = ctx.Send(msgs) fatalErr(t, err) - res, err := WaitForSignal(ctx, n2_listener.Chan, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool { + res, err := WaitForSignal(n2_listener.Chan, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool { return true }) fatalErr(t, err) diff --git a/signal.go b/signal.go index 1d48fbb..7e023ca 100644 --- a/signal.go +++ b/signal.go @@ -49,7 +49,29 @@ type Signal interface { Permission() Tree } -func WaitForSignal[S Signal](ctx * Context, listener chan Signal, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) { +func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UUID) (Signal, error) { + var timeout_channel <- chan time.Time + if timeout > 0 { + timeout_channel = time.After(timeout) + } + + for true { + select { + case signal := <- listener: + if signal == nil { + return nil, fmt.Errorf("LISTENER_CLOSED") + } + if signal.ReqID() == req_id { + return signal, nil + } + case <-timeout_channel: + return nil, fmt.Errorf("LISTENER_TIMEOUT") + } + } + return nil, fmt.Errorf("UNREACHABLE") +} + +func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) { var zero S var timeout_channel <- chan time.Time if timeout > 0 {