From c4df57a9325ac7260c46e15f341a0712beaa89ad Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sat, 14 Oct 2023 15:05:23 -0600 Subject: [PATCH] Added Authorization to not pass node private keys --- acl.go | 12 +-- acl_test.go | 2 +- go.mod | 1 + go.sum | 2 + gql.go | 269 +++++++++++++++++++++++++++++++--------------- gql_node.go | 2 +- gql_signal.go | 2 +- gql_test.go | 28 ++++- group.go | 8 +- group_test.go | 4 +- lockable.go | 42 ++++---- lockable_test.go | 4 +- node.go | 84 +++++++++++---- node_test.go | 4 +- policy.go | 4 +- serialize_test.go | 2 +- 16 files changed, 316 insertions(+), 154 deletions(-) diff --git a/acl.go b/acl.go index 677fb34..945ca60 100644 --- a/acl.go +++ b/acl.go @@ -80,7 +80,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa if ext.Policies[policy_index].ContinueAllows(ctx, acl_info, response) == Allow { delete(ext.PendingACLs, info.ID) ctx.Log.Logf("acl", "Request delayed allow") - messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(info.ID), acl_info.Source) + messages = messages.Add(ctx, acl_info.Source, node, nil, NewSuccessSignal(info.ID)) changes = changes.Add("acl_passed") err := node.DequeueSignal(acl_info.TimeoutID) if err != nil { @@ -89,7 +89,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa } else if acl_info.Counter == 0 { delete(ext.PendingACLs, info.ID) ctx.Log.Logf("acl", "Request delayed deny") - messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(info.ID, "acl_denied"), acl_info.Source) + messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(info.ID, "acl_denied")) changes = changes.Add("acl_blocked") err := node.DequeueSignal(acl_info.TimeoutID) if err != nil { @@ -133,7 +133,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa if denied == true { ctx.Log.Logf("acl", "Request denied") - messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.Id, "acl_denied"), source) + messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "acl_denied")) } else if acl_messages != nil { ctx.Log.Logf("acl", "Request pending") changes = changes.Add("acl_pending") @@ -168,7 +168,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa } } else { ctx.Log.Logf("acl", "Request allowed") - messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(sig.Id), source) + messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) } // Test an action against the policy list, sending any intermediate signals necessary and seeting Pending and PendingACLs accordingly. Add a TimeoutSignal for every message awaiting a response, and an ACLTimeoutSignal for the overall request case *ACLTimeoutSignal: @@ -176,7 +176,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa if exists == true { delete(ext.PendingACLs, sig.ReqID) ctx.Log.Logf("acl", "Request timeout deny") - messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.ReqID, "acl_timeout"), acl_info.Source) + messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(sig.ReqID, "acl_timeout")) changes = changes.Add("acl_timeout") err := node.DequeueSignal(acl_info.TimeoutID) if err != nil { @@ -210,7 +210,7 @@ func (policy ACLProxyPolicy) Allows(ctx *Context, principal_id NodeID, action Tr messages := Messages{} for _, proxy := range(policy.Proxies) { - messages = messages.Add(ctx, node.ID, node.Key, NewACLSignal(principal_id, action), proxy) + messages = messages.Add(ctx, proxy, node, nil, NewACLSignal(principal_id, action)) } return messages, Pending diff --git a/acl_test.go b/acl_test.go index dda44b7..713207d 100644 --- a/acl_test.go +++ b/acl_test.go @@ -44,7 +44,7 @@ func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *No fatalErr(t, err) messages := Messages{} - messages = messages.Add(ctx, source.ID, source.Key, signal, destination.ID) + messages = messages.Add(ctx, destination.ID, source, nil, signal) fatalErr(t, ctx.Send(messages)) response, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID()) diff --git a/go.mod b/go.mod index 7ffa66c..f56e843 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( ) require ( + filippo.io/edwards25519 v1.0.0 // indirect github.com/cespare/xxhash v1.1.0 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect diff --git a/go.sum b/go.sum index d7bad6d..97a2b05 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek= +filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= diff --git a/gql.go b/gql.go index c921700..3cb51dc 100644 --- a/gql.go +++ b/gql.go @@ -1,35 +1,42 @@ package graphvent import ( - "time" - "net" - "net/http" - "github.com/graphql-go/graphql" - "github.com/graphql-go/graphql/language/parser" - "github.com/graphql-go/graphql/language/source" - "github.com/graphql-go/graphql/language/ast" + "bytes" "context" - "encoding/json" + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" "encoding/base64" + "encoding/json" + "fmt" "io" + "net" + "net/http" "reflect" - "fmt" + "strings" "sync" + "time" + + "filippo.io/edwards25519" + "crypto/sha512" "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" - "strings" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/ecdh" - "crypto/ed25519" - "crypto/rand" - "crypto/x509" - //"crypto/tls" + "github.com/graphql-go/graphql" + "github.com/graphql-go/graphql/language/ast" + "github.com/graphql-go/graphql/language/parser" + "github.com/graphql-go/graphql/language/source" + "crypto/x509/pkix" - "math/big" "encoding/pem" + "math/big" + "github.com/google/uuid" - "slices" ) func AuthorizationHeader(node *Node) (string, error) { @@ -353,80 +360,194 @@ type ResolveContext struct { // The state data for the node processing this request Ext *GQLExt - // ID of the user that made this request - User NodeID - // Cache of resolved nodes NodeCache map[NodeID]NodeResult - // Key for the user that made this request, to sign resolver requests - // TODO: figure out some way to use a generated key so that the server can't impersonate the user afterwards - Key ed25519.PrivateKey + // Authorization from the user that started this request + Authorization *ClientAuthorization } -func ParseAuthB64(auth_base64 string) (ed25519.PrivateKey, NodeID, error) { - auth_bytes, err := base64.StdEncoding.DecodeString(auth_base64) +func AuthB64(client_key ed25519.PrivateKey, server_pubkey ed25519.PublicKey) (string, error) { + token_start := time.Now() + token_start_bytes, err := token_start.MarshalBinary() if err != nil { - return nil, NodeID{}, err + return "", err } - idx := slices.Index(auth_bytes, ':') - if idx == -1 { - return nil, NodeID{}, fmt.Errorf("No colon in auth") + session_key_public, session_key_private, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return "", err } - id_base64 := auth_bytes[:idx] - key_base64 := auth_bytes[idx+1:] + session_h := sha512.Sum512(session_key_private.Seed()) + ecdh_client, err := ECDH.NewPrivateKey(session_h[:32]) + if err != nil { + return "", err + } - id, err := ParseIDB64(string(id_base64)) + server_point, err := (&edwards25519.Point{}).SetBytes(server_pubkey) if err != nil { - return nil, NodeID{}, err + return "", err } - key, err := ParseKeyB64(string(key_base64)) + ecdh_server, err := ECDH.NewPublicKey(server_point.BytesMontgomery()) if err != nil { - return nil, NodeID{}, err + return "", err } - key_id := KeyID(key.Public().(ed25519.PublicKey)) - if key_id != id { - return nil, NodeID{}, fmt.Errorf("key_id != id(%s != %s)", key_id, id) + + secret, err := ecdh_client.ECDH(ecdh_server) + if err != nil { + return "", err } - return key, id, nil + if len(secret) != 32 { + return "", fmt.Errorf("ECDH secret not 32 bytes(for AES-256): %d bytes long", len(secret)) + } + + block, err := aes.NewCipher(secret) + if err != nil { + return "", err + } + + iv := make([]byte, block.BlockSize()) + iv_len, err := rand.Reader.Read(iv) + if err != nil { + return "", err + } else if iv_len != block.BlockSize() { + return "", fmt.Errorf("Not enough iv bytes read: %d", iv_len) + } + + var key_encrypted bytes.Buffer + stream := cipher.NewOFB(block, iv) + writer := &cipher.StreamWriter{S: stream, W: &key_encrypted} + + bytes_written, err := writer.Write(session_key_private.Seed()) + if err != nil { + return "", err + } else if bytes_written != len(ecdh_client.Bytes()) { + return "", fmt.Errorf("wrong number of bytes encrypted %d/%d", bytes_written, len(ecdh_client.Bytes())) + } + + digest := append(session_key_public, token_start_bytes...) + signature, err := client_key.Sign(rand.Reader, digest, crypto.Hash(0)) + if err != nil { + return "", err + } + + start_b64 := base64.StdEncoding.EncodeToString(token_start_bytes) + iv_b64 := base64.StdEncoding.EncodeToString(iv) + encrypted_b64 := base64.StdEncoding.EncodeToString(key_encrypted.Bytes()) + key_b64 := base64.StdEncoding.EncodeToString(session_key_public) + sig_b64 := base64.StdEncoding.EncodeToString(signature) + id_b64 := base64.StdEncoding.EncodeToString(client_key.Public().(ed25519.PublicKey)) + + return base64.StdEncoding.EncodeToString([]byte(strings.Join([]string{id_b64, iv_b64, key_b64, encrypted_b64, start_b64, sig_b64}, ":"))), nil } -func ParseKeyB64(key_base64 string) (ed25519.PrivateKey, error) { - key_bytes, err := base64.StdEncoding.DecodeString(key_base64) +func ParseAuthB64(auth_base64 string, server_id ed25519.PrivateKey) (*ClientAuthorization, error) { + joined_b64, err := base64.StdEncoding.DecodeString(auth_base64) if err != nil { return nil, err } - key_raw, err := x509.ParsePKCS8PrivateKey([]byte(key_bytes)) + auth_parts := strings.Split(string(joined_b64), ":") + if len(auth_parts) != 6 { + return nil, fmt.Errorf("Wrong number of delimited elements %d", len(auth_parts)) + } + + id_bytes, err := base64.StdEncoding.DecodeString(auth_parts[0]) if err != nil { return nil, err } - key, ok := key_raw.(ed25519.PrivateKey) - if ok == false { - return nil, fmt.Errorf("parsed key wrong type: %s", reflect.TypeOf(key_raw)) + iv, err := base64.StdEncoding.DecodeString(auth_parts[1]) + if err != nil { + return nil, err } - return key, nil -} + public_key, err := base64.StdEncoding.DecodeString(auth_parts[2]) + if err != nil { + return nil, err + } -func ParseIDB64(id_base64 string) (NodeID, error) { - id_bytes, err := base64.StdEncoding.DecodeString(id_base64) + key_encrypted, err := base64.StdEncoding.DecodeString(auth_parts[3]) if err != nil { - return NodeID{}, err + return nil, err } - auth_id, err := IDFromBytes(id_bytes) + start_bytes, err := base64.StdEncoding.DecodeString(auth_parts[4]) if err != nil { - return NodeID{}, err + return nil, err } - return auth_id, nil + signature, err := base64.StdEncoding.DecodeString(auth_parts[5]) + if err != nil { + return nil, err + } + + var start time.Time + err = start.UnmarshalBinary(start_bytes) + if err != nil { + return nil, err + } + + client_id := ed25519.PublicKey(id_bytes) + if err != nil { + return nil, err + } + + client_point, err := (&edwards25519.Point{}).SetBytes(public_key) + if err != nil { + return nil, err + } + + ecdh_client, err := ECDH.NewPublicKey(client_point.BytesMontgomery()) + if err != nil { + return nil, err + } + + h := sha512.Sum512(server_id.Seed()) + ecdh_server, err := ECDH.NewPrivateKey(h[:32]) + if err != nil { + return nil, err + } + + secret, err := ecdh_server.ECDH(ecdh_client) + if err != nil { + return nil, err + } else if len(secret) != 32 { + return nil, fmt.Errorf("Secret wrong length: %d/32", len(secret)) + } + + block, err := aes.NewCipher(secret) + if err != nil { + return nil, err + } + + encrypted_reader := bytes.NewReader(key_encrypted) + stream := cipher.NewOFB(block, iv) + reader := cipher.StreamReader{S: stream, R: encrypted_reader} + var decrypted_key bytes.Buffer + _, err = io.Copy(&decrypted_key, reader) + if err != nil { + return nil, err + } + + session_key := ed25519.NewKeyFromSeed(decrypted_key.Bytes()) + digest := append(session_key.Public().(ed25519.PublicKey), start_bytes...) + if ed25519.Verify(client_id, digest, signature) == false { + return nil, fmt.Errorf("Failed to verify digest/signature against client_id") + } + + return &ClientAuthorization{ + AuthInfo: AuthInfo{ + Identity: client_id, + Start: start, + Signature: signature, + }, + Key: session_key, + }, nil } func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt) (*ResolveContext, error) { @@ -438,8 +559,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt) (*ResolveCon GQLContext: ctx.Extensions[GQLExtType].Data.(*GQLExtContext), NodeCache: map[NodeID]NodeResult{}, Server: server, - User: NodeID{}, - Key: nil, + Authorization: nil, }, nil } @@ -453,43 +573,21 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW } ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") - id_b64, key_b64, ok := r.BasicAuth() - if ok == false { - ctx.Log.Logf("gql", "GQL_AUTH_BASIC_MISSING") - json.NewEncoder(w).Encode(fmt.Errorf("Failed to get auth headers")) - return - } - - auth_id, err := ParseIDB64(id_b64) + auth, err := ParseAuthB64(r.Header.Get("Authorization"), server.Key) if err != nil { ctx.Log.Logf("gql", "GQL_AUTH_ID_PARSE_ERROR: %s", err) - json.NewEncoder(w).Encode(fmt.Errorf("Failed to parse auth_id: %s", id_b64)) - return - } - - key, err := ParseKeyB64(key_b64) - if err != nil { - ctx.Log.Logf("gql", "GQL_AUTH_KEY_PARSE_ERROR: %s", err) - json.NewEncoder(w).Encode(fmt.Errorf("Failed to parse key: %s", key_b64)) - return - } - - key_id := KeyID(key.Public().(ed25519.PublicKey)) - if auth_id != key_id { - ctx.Log.Logf("gql", "GQL_AUTH_ERR: key_id != auth_id: %s != %s", key_id, auth_id) - json.NewEncoder(w).Encode(fmt.Errorf("GQL_REQUEST_ERR: key_id(%s) != auth_id(%s)", auth_id, key_id)) + json.NewEncoder(w).Encode(GQLUnauthorized("")) return } resolve_context, err := NewResolveContext(ctx, server, gql_ext) if err != nil { ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) - json.NewEncoder(w).Encode(GQLUnauthorized(fmt.Sprintf("%s", err))) + json.NewEncoder(w).Encode(GQLUnauthorized("")) return } - resolve_context.Key = key - resolve_context.User = key_id + resolve_context.Authorization = auth req_ctx := context.Background() req_ctx = context.WithValue(req_ctx, "resolve", resolve_context) @@ -634,14 +732,13 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon break } - key, key_id, err := ParseAuthB64(connection_params.Payload.Token) + authorization, err := ParseAuthB64(connection_params.Payload.Token, server.Key) if err != nil { ctx.Log.Logf("gqlws", "WS_AUTH_PARSE_ERR: %s", err) break } - resolve_context.User = key_id - resolve_context.Key = key + resolve_context.Authorization = authorization conn_state = "ready" err = wsutil.WriteServerMessage(conn, 1, []byte("{\"type\": \"connection_ack\"}")) diff --git a/gql_node.go b/gql_node.go index 2dcf2b1..1e2b426 100644 --- a/gql_node.go +++ b/gql_node.go @@ -111,7 +111,7 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ ctx.Context.Log.Logf("gql", "READ_SIGNAL for %s - %+v", id, read_signal) // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error msgs := Messages{} - msgs = msgs.Add(ctx.Context, ctx.Server.ID, ctx.Key, read_signal, id) + msgs = msgs.Add(ctx.Context, id, ctx.Server, ctx.Authorization, read_signal) response_chan := ctx.Ext.GetResponseChannel(read_signal.ID()) resp_channels[read_signal.ID()] = response_chan diff --git a/gql_signal.go b/gql_signal.go index c4fe9d2..a5607f2 100644 --- a/gql_signal.go +++ b/gql_signal.go @@ -166,7 +166,7 @@ func (ext *GQLExtContext) AddSignalMutation(name string, send_id_key string, sig return nil, err } msgs := Messages{} - msgs = msgs.Add(ctx.Context, ctx.User, ctx.Key, signal, send_id) + msgs = msgs.Add(ctx.Context, send_id, ctx.Server, ctx.Authorization, signal) response_chan := ctx.Ext.GetResponseChannel(signal.ID()) err = ctx.Context.Send(msgs) diff --git a/gql_test.go b/gql_test.go index 466da79..a137d7c 100644 --- a/gql_test.go +++ b/gql_test.go @@ -18,6 +18,26 @@ import ( "github.com/google/uuid" ) +func TestGQLAuth(t *testing.T) { + ctx := logTestContext(t, []string{"test"}) + + listener_1 := NewListenerExt(10) + node_1, err := NewNode(ctx, nil, BaseNodeType, 10, nil, listener_1) + fatalErr(t, err) + + listener_2 := NewListenerExt(10) + node_2, err := NewNode(ctx, nil, BaseNodeType, 10, nil, listener_2) + fatalErr(t, err) + + auth_header, err := AuthB64(node_1.Key, node_2.Key.Public().(ed25519.PublicKey)) + fatalErr(t, err) + + auth, err := ParseAuthB64(auth_header, node_2.Key) + fatalErr(t, err) + + ctx.Log.Logf("test", "AUTH: %+v", auth) +} + func TestGQLServer(t *testing.T) { ctx := logTestContext(t, []string{"test", "gqlws"}) @@ -196,7 +216,7 @@ func TestGQLServer(t *testing.T) { ctx.Log.Logf("test", "SUB: %s", resp[:n]) msgs := Messages{} - msgs = msgs.Add(ctx, gql.ID, gql.Key, NewStatusSignal(gql.ID, Changes{"test_status"}), gql.ID) + msgs = msgs.Add(ctx, gql.ID, gql, nil, NewStatusSignal(gql.ID, Changes{"test_status"})) err = ctx.Send(msgs) fatalErr(t, err) @@ -210,7 +230,7 @@ func TestGQLServer(t *testing.T) { SubGQL(sub_1) msgs := Messages{} - msgs = msgs.Add(ctx, gql.ID, gql.Key, NewStopSignal(), gql.ID) + msgs = msgs.Add(ctx, gql.ID, gql, nil, NewStopSignal()) err = ctx.Send(msgs) fatalErr(t, err) _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StoppedSignal) bool { @@ -241,7 +261,7 @@ func TestGQLDB(t *testing.T) { ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) msgs := Messages{} - msgs = msgs.Add(ctx, gql.ID, gql.Key, NewStopSignal(), gql.ID) + msgs = msgs.Add(ctx, gql.ID, gql, nil, NewStopSignal()) err = ctx.Send(msgs) fatalErr(t, err) _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StoppedSignal) bool { @@ -256,7 +276,7 @@ func TestGQLDB(t *testing.T) { listener_ext, err = GetExt[*ListenerExt](gql_loaded, ListenerExtType) fatalErr(t, err) msgs = Messages{} - msgs = msgs.Add(ctx, gql_loaded.ID, gql_loaded.Key, NewStopSignal(), gql_loaded.ID) + msgs = msgs.Add(ctx, gql_loaded.ID, gql_loaded, nil, NewStopSignal()) err = ctx.Send(msgs) fatalErr(t, err) _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StoppedSignal) bool { diff --git a/group.go b/group.go index 77d81c6..9bb613c 100644 --- a/group.go +++ b/group.go @@ -65,19 +65,19 @@ func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Sig switch sig := signal.(type) { case *AddMemberSignal: if slices.Contains(ext.Members, sig.MemberID) == true { - messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.Id, "already_member"), source) + messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "already_member")) } else { ext.Members = append(ext.Members, sig.MemberID) - messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(sig.Id), source) + messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) changes = changes.Add("member_added") } case *RemoveMemberSignal: idx := slices.Index(ext.Members, sig.MemberID) if idx == -1 { - messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.Id, "not_member"), source) + messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_member")) } else { ext.Members = slices.Delete(ext.Members, idx, idx+1) - messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(sig.Id), source) + messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) changes = changes.Add("member_removed") } } diff --git a/group_test.go b/group_test.go index 9b7b623..b09f15f 100644 --- a/group_test.go +++ b/group_test.go @@ -16,7 +16,7 @@ func TestGroupAdd(t *testing.T) { add_member_signal := NewAddMemberSignal(user_id) messages := Messages{} - messages = messages.Add(ctx, group.ID, group.Key, add_member_signal, group.ID) + messages = messages.Add(ctx, group.ID, group, nil, add_member_signal) fatalErr(t, ctx.Send(messages)) _, err = WaitForResponse(group_listener.Chan, 10*time.Millisecond, add_member_signal.Id) @@ -27,7 +27,7 @@ func TestGroupAdd(t *testing.T) { }) messages = Messages{} - messages = messages.Add(ctx, group.ID, group.Key, read_signal, group.ID) + messages = messages.Add(ctx, group.ID, group, nil, read_signal) fatalErr(t, ctx.Send(messages)) response, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, read_signal.Id) diff --git a/lockable.go b/lockable.go index 2a62f16..75efb2d 100644 --- a/lockable.go +++ b/lockable.go @@ -53,14 +53,14 @@ func NewLockableExt(requirements []NodeID) *LockableExt { func UnlockLockable(ctx *Context, node *Node) (uuid.UUID, error) { messages := Messages{} signal := NewLockSignal("unlock") - messages = messages.Add(ctx, node.ID, node.Key, signal, node.ID) + messages = messages.Add(ctx, node.ID, node, nil, signal) return signal.ID(), ctx.Send(messages) } func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) { messages := Messages{} signal := NewLockSignal("lock") - messages = messages.Add(ctx, node.ID, node.Key, signal, node.ID) + messages = messages.Add(ctx, node.ID, node, nil, signal) return signal.ID(), ctx.Send(messages) } @@ -86,7 +86,7 @@ func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeI ext.Requirements[id] = req_info ctx.Log.Logf("lockable", "SENT_ABORT_UNLOCK: %s to %s", lock_signal.ID(), id) - messages = messages.Add(ctx, node.ID, node.Key, lock_signal, id) + messages = messages.Add(ctx, id, node, nil, lock_signal) } } } @@ -106,7 +106,7 @@ func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID case "add": _, exists := ext.Requirements[signal.NodeID] if exists == true { - messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_requirement"), source) + messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "already_requirement")) } else { if ext.Requirements == nil { ext.Requirements = map[NodeID]ReqInfo{} @@ -116,22 +116,22 @@ func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID uuid.UUID{}, } changes = changes.Add("requirement_added") - messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(signal.ID()), source) + messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(signal.ID())) } case "remove": _, exists := ext.Requirements[signal.NodeID] if exists == false { - messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "can't link: not_requirement"), source) + messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "can't link: not_requirement")) } else { delete(ext.Requirements, signal.NodeID) changes = changes.Add("requirement_removed") - messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(signal.ID()), source) + messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(signal.ID())) } default: - messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "unknown_action"), source) + messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "unknown_action")) } } else { - messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocked"), source) + messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_unlocked")) } return messages, changes } @@ -168,7 +168,7 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod ext.State = Locked ext.Owner = ext.PendingOwner changes = changes.Add("locked") - messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(ext.PendingID), *ext.Owner) + messages = messages.Add(ctx, *ext.Owner, node, nil, NewSuccessSignal(ext.PendingID)) } else { changes = changes.Add("partial_lock") ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, reqs) @@ -178,7 +178,7 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod info.State = Unlocking info.MsgID = lock_signal.ID() ext.Requirements[source] = info - messages = messages.Add(ctx, node.ID, node.Key, lock_signal, source) + messages = messages.Add(ctx, source, node, nil, lock_signal) } } else if info.State == Unlocking { info.State = Unlocked @@ -202,10 +202,10 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod ext.Owner = ext.PendingOwner ext.ReqID = nil changes = changes.Add("unlocked") - messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(ext.PendingID), previous_owner) + messages = messages.Add(ctx, previous_owner, node, nil, NewSuccessSignal(ext.PendingID)) } else if old_state == AbortingLock { changes = changes.Add("lock_aborted") - messages = messages.Add(ctx ,node.ID, node.Key, NewErrorSignal(*ext.ReqID, "not_unlocked"), *ext.PendingOwner) + messages = messages.Add(ctx, *ext.PendingOwner, node, nil, NewErrorSignal(*ext.ReqID, "not_unlocked")) ext.PendingOwner = ext.Owner } } else { @@ -232,7 +232,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID ext.PendingOwner = &new_owner ext.Owner = &new_owner changes = changes.Add("locked") - messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(signal.ID()), new_owner) + messages = messages.Add(ctx, new_owner, node, nil, NewSuccessSignal(signal.ID())) } else { ext.State = Locking id := signal.ID() @@ -249,11 +249,11 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID info.State = Locking info.MsgID = lock_signal.ID() ext.Requirements[id] = info - messages = messages.Add(ctx, node.ID, node.Key, lock_signal, id) + messages = messages.Add(ctx, id, node, nil, lock_signal) } } } else { - messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocked"), source) + messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_unlocked")) } case "unlock": if ext.State == Locked { @@ -263,7 +263,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID ext.PendingOwner = nil ext.Owner = nil changes = changes.Add("unlocked") - messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(signal.ID()), new_owner) + messages = messages.Add(ctx, new_owner, node, nil, NewSuccessSignal(signal.ID())) } else if source == *ext.Owner { ext.State = Unlocking id := signal.ID() @@ -279,11 +279,11 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID info.State = Unlocking info.MsgID = lock_signal.ID() ext.Requirements[id] = info - messages = messages.Add(ctx, node.ID, node.Key, lock_signal, id) + messages = messages.Add(ctx, id, node, nil, lock_signal) } } } else { - messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locked"), source) + messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_locked")) } default: ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", signal.State) @@ -301,13 +301,13 @@ func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal case Up: if ext.Owner != nil { if *ext.Owner != node.ID { - messages = messages.Add(ctx, node.ID, node.Key, signal, *ext.Owner) + messages = messages.Add(ctx, *ext.Owner, node, nil, signal) } } case Down: for requirement := range(ext.Requirements) { - messages = messages.Add(ctx, node.ID, node.Key, signal, requirement) + messages = messages.Add(ctx, requirement, node, nil, signal) } case Direct: diff --git a/lockable_test.go b/lockable_test.go index fc10e16..de6e1cb 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -44,7 +44,7 @@ func TestLink(t *testing.T) { msgs := Messages{} link_signal := NewLinkSignal("add", l2.ID) - msgs = msgs.Add(ctx, l1.ID, l1.Key, link_signal, l1.ID) + msgs = msgs.Add(ctx, l1.ID, l1, nil, link_signal) err = ctx.Send(msgs) fatalErr(t, err) @@ -60,7 +60,7 @@ func TestLink(t *testing.T) { msgs = Messages{} unlink_signal := NewLinkSignal("remove", l2.ID) - msgs = msgs.Add(ctx, l1.ID, l1.Key, unlink_signal, l1.ID) + msgs = msgs.Add(ctx, l1.ID, l1, nil, unlink_signal) err = ctx.Send(msgs) fatalErr(t, err) diff --git a/node.go b/node.go index 6e28d42..d20edca 100644 --- a/node.go +++ b/node.go @@ -282,29 +282,27 @@ func nodeLoop(ctx *Context, node *Node) error { ctx.Log.Logf("signal", "SIGNAL_DEST_ID_SER_ERR: %e", err) continue } - src_id_ser, err := msg.Source.MarshalBinary() + src_id_ser, err := KeyID(msg.Source).MarshalBinary() if err != nil { ctx.Log.Logf("signal", "SIGNAL_SRC_ID_SER_ERR: %e", err) continue } sig_data := append(dst_id_ser, src_id_ser...) sig_data = append(sig_data, ser...) - validated := ed25519.Verify(msg.Principal, sig_data, msg.Signature) + validated := ed25519.Verify(msg.Source, sig_data, msg.Signature) if validated == false { - println(fmt.Sprintf("SIGNAL: %s", msg.Signal)) - println(fmt.Sprintf("VERIFY_DIGEST: %+v", sig_data)) ctx.Log.Logf("signal", "SIGNAL_VERIFY_ERR: %s - %+v", node.ID, msg) continue } - princ_id := KeyID(msg.Principal) + princ_id := KeyID(msg.Source) if princ_id != node.ID { pends, resp := node.Allows(ctx, princ_id, msg.Signal.Permission()) if resp == Deny { ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %+v(%+s)", princ_id, node.ID, reflect.TypeOf(msg.Signal), msg.Signal) ctx.Log.Logf("policy", "SIGNAL_POLICY_SOURCE: %s", msg.Source) msgs := Messages{} - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(msg.Signal.ID(), "acl denied"), msg.Source) + msgs = msgs.Add(ctx, KeyID(msg.Source), node, nil, NewErrorSignal(msg.Signal.ID(), "acl denied")) ctx.Send(msgs) continue } else if resp == Pending { @@ -327,7 +325,7 @@ func nodeLoop(ctx *Context, node *Node) error { Principal: princ_id, Responses: []ResponseSignal{}, Signal: msg.Signal, - Source: msg.Source, + Source: KeyID(msg.Source), } ctx.Log.Logf("policy", "Sending signals for pending ACL: %+v", msgs) ctx.Send(msgs) @@ -340,7 +338,7 @@ func nodeLoop(ctx *Context, node *Node) error { } signal = msg.Signal - source = msg.Source + source = KeyID(msg.Source) case <-node.TimeoutChan: signal = node.NextSignal.Signal @@ -408,7 +406,7 @@ func nodeLoop(ctx *Context, node *Node) error { ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal) // Send the denied response msgs := Messages{} - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(req_info.Signal.ID(), "ACL_DENIED"), req_info.Source) + msgs = msgs.Add(ctx, req_info.Source, node, nil, NewErrorSignal(req_info.Signal.ID(), "acl_denied")) err := ctx.Send(msgs) if err != nil { ctx.Log.Logf("signal", "SEND_ERR: %s", err) @@ -434,7 +432,7 @@ func nodeLoop(ctx *Context, node *Node) error { node.Process(ctx, source, NewStoppedSignal(sig, node.ID)) } else { msgs := Messages{} - msgs = msgs.Add(ctx, node.ID, node.Key, NewStoppedSignal(sig, node.ID), source) + msgs = msgs.Add(ctx, node.ID, node, nil, NewStoppedSignal(sig, node.ID)) ctx.Send(msgs) } run = false @@ -442,7 +440,7 @@ func nodeLoop(ctx *Context, node *Node) error { case *ReadSignal: result := node.ReadFields(ctx, sig.Extensions) msgs := Messages{} - msgs = msgs.Add(ctx, node.ID, node.Key, NewReadResultSignal(sig.ID(), node.ID, node.Type, result), source) + msgs = msgs.Add(ctx, source, node, nil, NewReadResultSignal(sig.ID(), node.ID, node.Type, result)) ctx.Send(msgs) default: @@ -460,17 +458,52 @@ func nodeLoop(ctx *Context, node *Node) error { return nil } +type AuthInfo struct { + // The Node that issued the authorization + Identity ed25519.PublicKey + + // Time the authorization was generated + Start time.Time + + // Signature of Start + Principal with Identity private key + Signature []byte +} + +type AuthorizationToken struct { + AuthInfo + + // The private key generated by the client, encrypted with the servers public key + KeyEncrypted []byte +} + +type ClientAuthorization struct { + AuthInfo + + // The private key generated by the client + Key ed25519.PrivateKey +} + +// Authorization structs can be passed in a message that originated from a different node than the sender +type Authorization struct { + AuthInfo + + // The public key generated for this authorization + Key ed25519.PublicKey +} + type Message struct { - Source NodeID Dest NodeID - Principal ed25519.PublicKey + Source ed25519.PublicKey + + Authorization *Authorization + Signal Signal Signature []byte } type Messages []*Message -func (msgs Messages) Add(ctx *Context, source NodeID, principal ed25519.PrivateKey, signal Signal, dest NodeID) Messages { - msg, err := NewMessage(ctx, dest, source, principal, signal) +func (msgs Messages) Add(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) Messages { + msg, err := NewMessage(ctx, dest, source, authorization, signal) if err != nil { panic(err) } else { @@ -479,7 +512,7 @@ func (msgs Messages) Add(ctx *Context, source NodeID, principal ed25519.PrivateK return msgs } -func NewMessage(ctx *Context, dest NodeID, source NodeID, principal ed25519.PrivateKey, signal Signal) (*Message, error) { +func NewMessage(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) (*Message, error) { signal_ser, err := SerializeAny(ctx, signal) if err != nil { return nil, err @@ -494,30 +527,39 @@ func NewMessage(ctx *Context, dest NodeID, source NodeID, principal ed25519.Priv if err != nil { return nil, err } - source_ser, err := source.MarshalBinary() + source_ser, err := source.ID.MarshalBinary() if err != nil { return nil, err } sig_data := append(dest_ser, source_ser...) sig_data = append(sig_data, ser...) + var message_auth *Authorization = nil + if authorization != nil { + sig_data = append(sig_data, authorization.Signature...) + message_auth = &Authorization{ + authorization.AuthInfo, + authorization.Key.Public().(ed25519.PublicKey), + } + } - sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0)) + sig, err := source.Key.Sign(rand.Reader, sig_data, crypto.Hash(0)) if err != nil { return nil, err } return &Message{ Dest: dest, - Source: source, - Principal: principal.Public().(ed25519.PublicKey), + Source: source.Key.Public().(ed25519.PublicKey), + Authorization: message_auth, Signal: signal, Signature: sig, }, nil } + func (node *Node) Stop(ctx *Context) error { if node.Active.Load() { - msg, err := NewMessage(ctx, node.ID, node.ID, node.Key, NewStopSignal()) + msg, err := NewMessage(ctx, node.ID, node, nil, NewStopSignal()) if err != nil { return err } diff --git a/node_test.go b/node_test.go index 996a1a9..e2e51c4 100644 --- a/node_test.go +++ b/node_test.go @@ -23,7 +23,7 @@ func TestNodeDB(t *testing.T) { }) msgs := Messages{} - msgs = msgs.Add(ctx, node.ID, node.Key, NewStopSignal(), node.ID) + msgs = msgs.Add(ctx, node.ID, node, nil, NewStopSignal()) err = ctx.Send(msgs) fatalErr(t, err) @@ -71,7 +71,7 @@ func TestNodeRead(t *testing.T) { GroupExtType: {"members"}, }) msgs := Messages{} - msgs = msgs.Add(ctx, n2.ID, n2.Key, read_sig, n1.ID) + msgs = msgs.Add(ctx, n1.ID, n2, nil, read_sig) err = ctx.Send(msgs) fatalErr(t, err) diff --git a/policy.go b/policy.go index 69bd5f9..2fbeb79 100644 --- a/policy.go +++ b/policy.go @@ -147,9 +147,9 @@ func (policy MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tr } } } else { - msgs = msgs.Add(ctx, node.ID, node.Key, NewReadSignal(map[ExtType][]string{ + msgs = msgs.Add(ctx, id, node, nil, NewReadSignal(map[ExtType][]string{ GroupExtType: []string{"members"}, - }), id) + })) } } return msgs, Pending diff --git a/serialize_test.go b/serialize_test.go index 701222a..68cd492 100644 --- a/serialize_test.go +++ b/serialize_test.go @@ -7,7 +7,7 @@ import ( ) func TestSerializeBasic(t *testing.T) { - ctx := logTestContext(t, []string{"test", "serialize"}) + ctx := logTestContext(t, []string{"test"}) testSerializeComparable[string](t, ctx, "test") testSerializeComparable[bool](t, ctx, true) testSerializeComparable[float32](t, ctx, 0.05)