Added AuthorizedSignal wrapper to wrap a signal with an ecdsa signature to prove it originated from a different node

gql_cataclysm
noah metz 2023-08-01 14:09:29 -06:00
parent b9a2cceaf1
commit 42cd8f4188
6 changed files with 107 additions and 27 deletions

@ -165,9 +165,6 @@ func checkForAuthHeader(header http.Header) (string, bool) {
// Context passed to each resolve execution // Context passed to each resolve execution
type ResolveContext struct { type ResolveContext struct {
// ID generated for the context so the gql extension can route data to it
ID uuid.UUID
// Channels for the gql extension to route data to this context // Channels for the gql extension to route data to this context
Chans map[uuid.UUID]chan Signal Chans map[uuid.UUID]chan Signal
@ -184,13 +181,15 @@ type ResolveContext struct {
Ext *GQLExt Ext *GQLExt
// ID of the user that made this request // ID of the user that made this request
// TODO: figure out auth
User NodeID User NodeID
// Key for the user that made this request, to sign resolver requests
// TODO: figure out some way to use a generated key so that the server can't impersonate the user afterwards
Key *ecdsa.PrivateKey
} }
const GQL_RESOLVER_CHAN_SIZE = 10
func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) { func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) {
username, _, ok := r.BasicAuth() username, key_bytes, ok := r.BasicAuth()
if ok == false { if ok == false {
return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header") return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header")
} }
@ -200,14 +199,24 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ
return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username) return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username)
} }
key, err := x509.ParseECPrivateKey([]byte(key_bytes))
if err != nil {
return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ecdsa key from auth password: %s", key_bytes)
}
key_id := KeyID(&key.PublicKey)
if auth_id != key_id {
return nil, fmt.Errorf("GQL_REQUEST_ERR: key_id(%s) != auth_id(%s)", auth_id, key_id)
}
return &ResolveContext{ return &ResolveContext{
Ext: gql_ext, Ext: gql_ext,
ID: uuid.New(),
Chans: map[uuid.UUID]chan Signal{}, Chans: map[uuid.UUID]chan Signal{},
Context: ctx, Context: ctx,
GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext), GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext),
Server: server, Server: server,
User: auth_id, User: key_id,
Key: key,
}, nil }, nil
} }
@ -1050,7 +1059,7 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
} }
} else { } else {
ctx.Log.Logf("gql", "received error signal response %s with no mapped resolver", sig.UUID) ctx.Log.Logf("gql", "received error signal response %+v with no mapped resolver", sig)
} }
} else if signal.Type() == ReadResultSignalType { } else if signal.Type() == ReadResultSignalType {
sig := signal.(ReadResultSignal) sig := signal.(ReadResultSignal)

@ -2,6 +2,7 @@ package graphvent
import ( import (
"time" "time"
"reflect" "reflect"
"fmt"
"github.com/graphql-go/graphql" "github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/ast" "github.com/graphql-go/graphql/language/ast"
"github.com/google/uuid" "github.com/google/uuid"
@ -50,13 +51,17 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
} }
// Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error
read_signal := NewReadSignal(ext_fields) read_signal := NewReadSignal(ext_fields)
auth_signal, err := NewAuthorizedSignal(ctx.Key, read_signal)
if err != nil {
return nil, err
}
response_chan := ctx.Ext.GetResponseChannel(read_signal.ID()) response_chan := ctx.Ext.GetResponseChannel(read_signal.ID())
resp_channels[read_signal.ID()] = response_chan resp_channels[read_signal.ID()] = response_chan
node_ids[read_signal.ID()] = id node_ids[read_signal.ID()] = id
err = ctx.Context.Send(ctx.Server.ID, id, read_signal) err = ctx.Context.Send(ctx.Server.ID, id, auth_signal)
if err != nil { if err != nil {
ctx.Ext.FreeResponseChannel(read_signal.ID()) ctx.Ext.FreeResponseChannel(read_signal.ID())
return nil, err return nil, err
@ -70,7 +75,14 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
if err != nil { if err != nil {
return nil, err return nil, err
} }
responses = append(responses, NodeResult{node_ids[sig_id], response.(ReadResultSignal)}) switch resp := response.(type) {
case ReadResultSignal:
responses = append(responses, NodeResult{node_ids[sig_id], resp})
case ErrorSignal:
return nil, resp.Error
default:
return nil, fmt.Errorf("BAD_TYPE: %s", reflect.TypeOf(resp))
}
} }
return responses, nil return responses, nil

@ -9,6 +9,7 @@ import (
"net/http" "net/http"
"net" "net"
"crypto/tls" "crypto/tls"
"crypto/x509"
"bytes" "bytes"
"github.com/google/uuid" "github.com/google/uuid"
) )
@ -66,7 +67,9 @@ func TestGQL(t *testing.T) {
req, err := http.NewRequest("GET", url, req_data) req, err := http.NewRequest("GET", url, req_data)
fatalErr(t, err) fatalErr(t, err)
req.SetBasicAuth(n1.ID.String(), "BAD_PASSWORD") key_bytes, err := x509.MarshalECPrivateKey(n1.Key)
fatalErr(t, err)
req.SetBasicAuth(n1.ID.String(), string(key_bytes))
resp, err := client.Do(req) resp, err := client.Do(req)
fatalErr(t, err) fatalErr(t, err)

@ -238,6 +238,35 @@ func nodeLoop(ctx *Context, node *Node) error {
} }
} }
// Unwrap Authorized Signals
if signal.Type() == AuthorizedSignalType {
sig, ok := signal.(AuthorizedSignal)
if ok == false {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: bad cast %+v", reflect.TypeOf(signal))
} else {
// Validate
sig_data, err := sig.Signal.Serialize()
if err != nil {
} else {
sig_hash := sha512.Sum512(sig_data)
validated := ecdsa.VerifyASN1(sig.Principal, sig_hash[:], sig.Signature)
if validated == true {
err := Allowed(ctx, KeyID(sig.Principal), sig.Signal.Permission(), node)
if err != nil {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL_POLICY_ERR: %s", err)
ctx.Send(node.ID, source, NewErrorSignal(sig.ID(), err))
} else {
// Unwrap the signal without changing the source
signal = sig.Signal
}
} else {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: failed to validate")
ctx.Send(node.ID, source, NewErrorSignal(sig.ID(), fmt.Errorf("failed to validate signature")))
}
}
}
}
// Handle special signal types // Handle special signal types
if signal.Type() == StopSignalType { if signal.Type() == StopSignalType {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), nil)) ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), nil))

@ -330,8 +330,9 @@ func (ext *ACLExt) Field(name string) interface{} {
var ErrorSignalAction = Action{"ERROR_RESP"} var ErrorSignalAction = Action{"ERROR_RESP"}
var ReadResultSignalAction = Action{"READ_RESULT"} var ReadResultSignalAction = Action{"READ_RESULT"}
var AuthorizedSignalAction = Action{"AUTHORIZED_READ"}
var DefaultACLPolicies = []Policy{ var DefaultACLPolicies = []Policy{
NewAllNodesPolicy(Actions{ErrorSignalAction, ReadResultSignalAction}), NewAllNodesPolicy(Actions{ErrorSignalAction, ReadResultSignalAction, AuthorizedSignalAction}),
} }
func NewACLExt(policies ...Policy) *ACLExt { func NewACLExt(policies ...Policy) *ACLExt {

@ -23,6 +23,7 @@ const (
LinkSignalType = "LINK" LinkSignalType = "LINK"
LockSignalType = "LOCK" LockSignalType = "LOCK"
ReadSignalType = "READ" ReadSignalType = "READ"
AuthorizedSignalType = "AUTHORIZED"
ReadResultSignalType = "READ_RESULT" ReadResultSignalType = "READ_RESULT"
LinkStartSignalType = "LINK_START" LinkStartSignalType = "LINK_START"
ECDHSignalType = "ECDH" ECDHSignalType = "ECDH"
@ -48,23 +49,17 @@ type Signal interface {
func WaitForResult(listener chan Signal, timeout time.Duration, id uuid.UUID) (Signal, error) { func WaitForResult(listener chan Signal, timeout time.Duration, id uuid.UUID) (Signal, error) {
timeout_channel := time.After(timeout) timeout_channel := time.After(timeout)
var err error = nil
var result Signal = nil
run := true
for run == true {
select { select {
case result=<-listener: case result:=<-listener:
if result.ID() == id { if result.ID() == id {
run = false return result, nil
} else {
return result, fmt.Errorf("WRONG_ID: %s", result.ID())
} }
case <-timeout_channel: case <-timeout_channel:
result = nil return nil, fmt.Errorf("timeout waiting for read response to %s", id)
err = fmt.Errorf("timeout waiting for read response to %s", id)
run = false
} }
} }
return result, err
}
func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) { func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) {
var zero S var zero S
@ -257,6 +252,37 @@ type ReadSignal struct {
Extensions map[ExtType][]string `json:"extensions"` Extensions map[ExtType][]string `json:"extensions"`
} }
type AuthorizedSignal struct {
BaseSignal
Principal *ecdsa.PublicKey
Signal Signal
Signature []byte
}
func (signal AuthorizedSignal) Permission() Action {
return AuthorizedSignalAction
}
func NewAuthorizedSignal(principal *ecdsa.PrivateKey, signal Signal) (AuthorizedSignal, error) {
sig_data, err := signal.Serialize()
if err != nil {
return AuthorizedSignal{}, err
}
sig_hash := sha512.Sum512(sig_data)
sig, err := ecdsa.SignASN1(rand.Reader, principal, sig_hash[:])
if err != nil {
return AuthorizedSignal{}, err
}
return AuthorizedSignal{
BaseSignal: NewDirectSignal(AuthorizedSignalType),
Principal: &principal.PublicKey,
Signal: signal,
Signature: sig,
}, nil
}
func (signal ReadSignal) Serialize() ([]byte, error) { func (signal ReadSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal) return json.Marshal(&signal)
} }