diff --git a/gql.go b/gql.go index 667e9cd..85d6498 100644 --- a/gql.go +++ b/gql.go @@ -165,9 +165,6 @@ func checkForAuthHeader(header http.Header) (string, bool) { // Context passed to each resolve execution type ResolveContext struct { - // ID generated for the context so the gql extension can route data to it - ID uuid.UUID - // Channels for the gql extension to route data to this context Chans map[uuid.UUID]chan Signal @@ -184,13 +181,15 @@ type ResolveContext struct { Ext *GQLExt // ID of the user that made this request - // TODO: figure out auth User NodeID + + // Key for the user that made this request, to sign resolver requests + // TODO: figure out some way to use a generated key so that the server can't impersonate the user afterwards + Key *ecdsa.PrivateKey } -const GQL_RESOLVER_CHAN_SIZE = 10 func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) { - username, _, ok := r.BasicAuth() + username, key_bytes, ok := r.BasicAuth() if ok == false { return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header") } @@ -200,14 +199,24 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username) } + key, err := x509.ParseECPrivateKey([]byte(key_bytes)) + if err != nil { + return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ecdsa key from auth password: %s", key_bytes) + } + + key_id := KeyID(&key.PublicKey) + if auth_id != key_id { + return nil, fmt.Errorf("GQL_REQUEST_ERR: key_id(%s) != auth_id(%s)", auth_id, key_id) + } + return &ResolveContext{ Ext: gql_ext, - ID: uuid.New(), Chans: map[uuid.UUID]chan Signal{}, Context: ctx, GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext), Server: server, - User: auth_id, + User: key_id, + Key: key, }, nil } @@ -1050,7 +1059,7 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa } } else { - ctx.Log.Logf("gql", "received error signal response %s with no mapped resolver", sig.UUID) + ctx.Log.Logf("gql", "received error signal response %+v with no mapped resolver", sig) } } else if signal.Type() == ReadResultSignalType { sig := signal.(ReadResultSignal) diff --git a/gql_query.go b/gql_query.go index 9f2c26c..b68d462 100644 --- a/gql_query.go +++ b/gql_query.go @@ -2,6 +2,7 @@ package graphvent import ( "time" "reflect" + "fmt" "github.com/graphql-go/graphql" "github.com/graphql-go/graphql/language/ast" "github.com/google/uuid" @@ -50,13 +51,17 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ } // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error read_signal := NewReadSignal(ext_fields) + auth_signal, err := NewAuthorizedSignal(ctx.Key, read_signal) + if err != nil { + return nil, err + } response_chan := ctx.Ext.GetResponseChannel(read_signal.ID()) resp_channels[read_signal.ID()] = response_chan node_ids[read_signal.ID()] = id - err = ctx.Context.Send(ctx.Server.ID, id, read_signal) + err = ctx.Context.Send(ctx.Server.ID, id, auth_signal) if err != nil { ctx.Ext.FreeResponseChannel(read_signal.ID()) return nil, err @@ -70,7 +75,14 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ if err != nil { return nil, err } - responses = append(responses, NodeResult{node_ids[sig_id], response.(ReadResultSignal)}) + switch resp := response.(type) { + case ReadResultSignal: + responses = append(responses, NodeResult{node_ids[sig_id], resp}) + case ErrorSignal: + return nil, resp.Error + default: + return nil, fmt.Errorf("BAD_TYPE: %s", reflect.TypeOf(resp)) + } } return responses, nil diff --git a/gql_test.go b/gql_test.go index 3e41b72..e01d189 100644 --- a/gql_test.go +++ b/gql_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net" "crypto/tls" + "crypto/x509" "bytes" "github.com/google/uuid" ) @@ -66,7 +67,9 @@ func TestGQL(t *testing.T) { req, err := http.NewRequest("GET", url, req_data) fatalErr(t, err) - req.SetBasicAuth(n1.ID.String(), "BAD_PASSWORD") + key_bytes, err := x509.MarshalECPrivateKey(n1.Key) + fatalErr(t, err) + req.SetBasicAuth(n1.ID.String(), string(key_bytes)) resp, err := client.Do(req) fatalErr(t, err) diff --git a/node.go b/node.go index f3319a6..4000ea4 100644 --- a/node.go +++ b/node.go @@ -238,6 +238,35 @@ func nodeLoop(ctx *Context, node *Node) error { } } + // Unwrap Authorized Signals + if signal.Type() == AuthorizedSignalType { + sig, ok := signal.(AuthorizedSignal) + if ok == false { + ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: bad cast %+v", reflect.TypeOf(signal)) + } else { + // Validate + sig_data, err := sig.Signal.Serialize() + if err != nil { + } else { + sig_hash := sha512.Sum512(sig_data) + validated := ecdsa.VerifyASN1(sig.Principal, sig_hash[:], sig.Signature) + if validated == true { + err := Allowed(ctx, KeyID(sig.Principal), sig.Signal.Permission(), node) + if err != nil { + ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL_POLICY_ERR: %s", err) + ctx.Send(node.ID, source, NewErrorSignal(sig.ID(), err)) + } else { + // Unwrap the signal without changing the source + signal = sig.Signal + } + } else { + ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: failed to validate") + ctx.Send(node.ID, source, NewErrorSignal(sig.ID(), fmt.Errorf("failed to validate signature"))) + } + } + } + } + // Handle special signal types if signal.Type() == StopSignalType { ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), nil)) diff --git a/policy.go b/policy.go index 1960861..93bac51 100644 --- a/policy.go +++ b/policy.go @@ -330,8 +330,9 @@ func (ext *ACLExt) Field(name string) interface{} { var ErrorSignalAction = Action{"ERROR_RESP"} var ReadResultSignalAction = Action{"READ_RESULT"} +var AuthorizedSignalAction = Action{"AUTHORIZED_READ"} var DefaultACLPolicies = []Policy{ - NewAllNodesPolicy(Actions{ErrorSignalAction, ReadResultSignalAction}), + NewAllNodesPolicy(Actions{ErrorSignalAction, ReadResultSignalAction, AuthorizedSignalAction}), } func NewACLExt(policies ...Policy) *ACLExt { diff --git a/signal.go b/signal.go index c800639..8c0d733 100644 --- a/signal.go +++ b/signal.go @@ -23,6 +23,7 @@ const ( LinkSignalType = "LINK" LockSignalType = "LOCK" ReadSignalType = "READ" + AuthorizedSignalType = "AUTHORIZED" ReadResultSignalType = "READ_RESULT" LinkStartSignalType = "LINK_START" ECDHSignalType = "ECDH" @@ -48,22 +49,16 @@ type Signal interface { func WaitForResult(listener chan Signal, timeout time.Duration, id uuid.UUID) (Signal, error) { timeout_channel := time.After(timeout) - var err error = nil - var result Signal = nil - run := true - for run == true { - select { - case result=<-listener: - if result.ID() == id { - run = false - } - case <-timeout_channel: - result = nil - err = fmt.Errorf("timeout waiting for read response to %s", id) - run = false + select { + case result:=<-listener: + if result.ID() == id { + return result, nil + } else { + return result, fmt.Errorf("WRONG_ID: %s", result.ID()) } + case <-timeout_channel: + return nil, fmt.Errorf("timeout waiting for read response to %s", id) } - return result, err } func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) { @@ -257,6 +252,37 @@ type ReadSignal struct { Extensions map[ExtType][]string `json:"extensions"` } +type AuthorizedSignal struct { + BaseSignal + Principal *ecdsa.PublicKey + Signal Signal + Signature []byte +} + +func (signal AuthorizedSignal) Permission() Action { + return AuthorizedSignalAction +} + +func NewAuthorizedSignal(principal *ecdsa.PrivateKey, signal Signal) (AuthorizedSignal, error) { + sig_data, err := signal.Serialize() + if err != nil { + return AuthorizedSignal{}, err + } + + sig_hash := sha512.Sum512(sig_data) + sig, err := ecdsa.SignASN1(rand.Reader, principal, sig_hash[:]) + if err != nil { + return AuthorizedSignal{}, err + } + + return AuthorizedSignal{ + BaseSignal: NewDirectSignal(AuthorizedSignalType), + Principal: &principal.PublicKey, + Signal: signal, + Signature: sig, + }, nil +} + func (signal ReadSignal) Serialize() ([]byte, error) { return json.Marshal(&signal) }