Added Authorization to not pass node private keys

gql_cataclysm
noah metz 2023-10-14 15:05:23 -06:00
parent 8c80ec9dd6
commit c4df57a932
16 changed files with 316 additions and 154 deletions

@ -80,7 +80,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
if ext.Policies[policy_index].ContinueAllows(ctx, acl_info, response) == Allow { if ext.Policies[policy_index].ContinueAllows(ctx, acl_info, response) == Allow {
delete(ext.PendingACLs, info.ID) delete(ext.PendingACLs, info.ID)
ctx.Log.Logf("acl", "Request delayed allow") ctx.Log.Logf("acl", "Request delayed allow")
messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(info.ID), acl_info.Source) messages = messages.Add(ctx, acl_info.Source, node, nil, NewSuccessSignal(info.ID))
changes = changes.Add("acl_passed") changes = changes.Add("acl_passed")
err := node.DequeueSignal(acl_info.TimeoutID) err := node.DequeueSignal(acl_info.TimeoutID)
if err != nil { if err != nil {
@ -89,7 +89,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
} else if acl_info.Counter == 0 { } else if acl_info.Counter == 0 {
delete(ext.PendingACLs, info.ID) delete(ext.PendingACLs, info.ID)
ctx.Log.Logf("acl", "Request delayed deny") ctx.Log.Logf("acl", "Request delayed deny")
messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(info.ID, "acl_denied"), acl_info.Source) messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(info.ID, "acl_denied"))
changes = changes.Add("acl_blocked") changes = changes.Add("acl_blocked")
err := node.DequeueSignal(acl_info.TimeoutID) err := node.DequeueSignal(acl_info.TimeoutID)
if err != nil { if err != nil {
@ -133,7 +133,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
if denied == true { if denied == true {
ctx.Log.Logf("acl", "Request denied") ctx.Log.Logf("acl", "Request denied")
messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.Id, "acl_denied"), source) messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "acl_denied"))
} else if acl_messages != nil { } else if acl_messages != nil {
ctx.Log.Logf("acl", "Request pending") ctx.Log.Logf("acl", "Request pending")
changes = changes.Add("acl_pending") changes = changes.Add("acl_pending")
@ -168,7 +168,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
} }
} else { } else {
ctx.Log.Logf("acl", "Request allowed") ctx.Log.Logf("acl", "Request allowed")
messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(sig.Id), source) messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
} }
// Test an action against the policy list, sending any intermediate signals necessary and seeting Pending and PendingACLs accordingly. Add a TimeoutSignal for every message awaiting a response, and an ACLTimeoutSignal for the overall request // Test an action against the policy list, sending any intermediate signals necessary and seeting Pending and PendingACLs accordingly. Add a TimeoutSignal for every message awaiting a response, and an ACLTimeoutSignal for the overall request
case *ACLTimeoutSignal: case *ACLTimeoutSignal:
@ -176,7 +176,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
if exists == true { if exists == true {
delete(ext.PendingACLs, sig.ReqID) delete(ext.PendingACLs, sig.ReqID)
ctx.Log.Logf("acl", "Request timeout deny") ctx.Log.Logf("acl", "Request timeout deny")
messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.ReqID, "acl_timeout"), acl_info.Source) messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(sig.ReqID, "acl_timeout"))
changes = changes.Add("acl_timeout") changes = changes.Add("acl_timeout")
err := node.DequeueSignal(acl_info.TimeoutID) err := node.DequeueSignal(acl_info.TimeoutID)
if err != nil { if err != nil {
@ -210,7 +210,7 @@ func (policy ACLProxyPolicy) Allows(ctx *Context, principal_id NodeID, action Tr
messages := Messages{} messages := Messages{}
for _, proxy := range(policy.Proxies) { for _, proxy := range(policy.Proxies) {
messages = messages.Add(ctx, node.ID, node.Key, NewACLSignal(principal_id, action), proxy) messages = messages.Add(ctx, proxy, node, nil, NewACLSignal(principal_id, action))
} }
return messages, Pending return messages, Pending

@ -44,7 +44,7 @@ func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *No
fatalErr(t, err) fatalErr(t, err)
messages := Messages{} messages := Messages{}
messages = messages.Add(ctx, source.ID, source.Key, signal, destination.ID) messages = messages.Add(ctx, destination.ID, source, nil, signal)
fatalErr(t, ctx.Send(messages)) fatalErr(t, ctx.Send(messages))
response, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID()) response, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID())

@ -12,6 +12,7 @@ require (
) )
require ( require (
filippo.io/edwards25519 v1.0.0 // indirect
github.com/cespare/xxhash v1.1.0 // indirect github.com/cespare/xxhash v1.1.0 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect

@ -1,4 +1,6 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek=
filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=

267
gql.go

@ -1,35 +1,42 @@
package graphvent package graphvent
import ( import (
"time" "bytes"
"net"
"net/http"
"github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/parser"
"github.com/graphql-go/graphql/language/source"
"github.com/graphql-go/graphql/language/ast"
"context" "context"
"encoding/json" "crypto"
"crypto/aes"
"crypto/cipher"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/json"
"fmt"
"io" "io"
"net"
"net/http"
"reflect" "reflect"
"fmt" "strings"
"sync" "sync"
"time"
"filippo.io/edwards25519"
"crypto/sha512"
"github.com/gobwas/ws" "github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil" "github.com/gobwas/ws/wsutil"
"strings" "github.com/graphql-go/graphql"
"crypto/ecdsa" "github.com/graphql-go/graphql/language/ast"
"crypto/elliptic" "github.com/graphql-go/graphql/language/parser"
"crypto/ecdh" "github.com/graphql-go/graphql/language/source"
"crypto/ed25519"
"crypto/rand"
"crypto/x509"
//"crypto/tls"
"crypto/x509/pkix" "crypto/x509/pkix"
"math/big"
"encoding/pem" "encoding/pem"
"math/big"
"github.com/google/uuid" "github.com/google/uuid"
"slices"
) )
func AuthorizationHeader(node *Node) (string, error) { func AuthorizationHeader(node *Node) (string, error) {
@ -353,80 +360,194 @@ type ResolveContext struct {
// The state data for the node processing this request // The state data for the node processing this request
Ext *GQLExt Ext *GQLExt
// ID of the user that made this request
User NodeID
// Cache of resolved nodes // Cache of resolved nodes
NodeCache map[NodeID]NodeResult NodeCache map[NodeID]NodeResult
// Key for the user that made this request, to sign resolver requests // Authorization from the user that started this request
// TODO: figure out some way to use a generated key so that the server can't impersonate the user afterwards Authorization *ClientAuthorization
Key ed25519.PrivateKey }
func AuthB64(client_key ed25519.PrivateKey, server_pubkey ed25519.PublicKey) (string, error) {
token_start := time.Now()
token_start_bytes, err := token_start.MarshalBinary()
if err != nil {
return "", err
}
session_key_public, session_key_private, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return "", err
}
session_h := sha512.Sum512(session_key_private.Seed())
ecdh_client, err := ECDH.NewPrivateKey(session_h[:32])
if err != nil {
return "", err
} }
func ParseAuthB64(auth_base64 string) (ed25519.PrivateKey, NodeID, error) { server_point, err := (&edwards25519.Point{}).SetBytes(server_pubkey)
auth_bytes, err := base64.StdEncoding.DecodeString(auth_base64)
if err != nil { if err != nil {
return nil, NodeID{}, err return "", err
} }
idx := slices.Index(auth_bytes, ':') ecdh_server, err := ECDH.NewPublicKey(server_point.BytesMontgomery())
if idx == -1 { if err != nil {
return nil, NodeID{}, fmt.Errorf("No colon in auth") return "", err
} }
id_base64 := auth_bytes[:idx]
key_base64 := auth_bytes[idx+1:]
id, err := ParseIDB64(string(id_base64)) secret, err := ecdh_client.ECDH(ecdh_server)
if err != nil { if err != nil {
return nil, NodeID{}, err return "", err
}
if len(secret) != 32 {
return "", fmt.Errorf("ECDH secret not 32 bytes(for AES-256): %d bytes long", len(secret))
} }
key, err := ParseKeyB64(string(key_base64)) block, err := aes.NewCipher(secret)
if err != nil { if err != nil {
return nil, NodeID{}, err return "", err
} }
key_id := KeyID(key.Public().(ed25519.PublicKey)) iv := make([]byte, block.BlockSize())
if key_id != id { iv_len, err := rand.Reader.Read(iv)
return nil, NodeID{}, fmt.Errorf("key_id != id(%s != %s)", key_id, id) if err != nil {
return "", err
} else if iv_len != block.BlockSize() {
return "", fmt.Errorf("Not enough iv bytes read: %d", iv_len)
} }
return key, id, nil var key_encrypted bytes.Buffer
stream := cipher.NewOFB(block, iv)
writer := &cipher.StreamWriter{S: stream, W: &key_encrypted}
bytes_written, err := writer.Write(session_key_private.Seed())
if err != nil {
return "", err
} else if bytes_written != len(ecdh_client.Bytes()) {
return "", fmt.Errorf("wrong number of bytes encrypted %d/%d", bytes_written, len(ecdh_client.Bytes()))
} }
func ParseKeyB64(key_base64 string) (ed25519.PrivateKey, error) { digest := append(session_key_public, token_start_bytes...)
key_bytes, err := base64.StdEncoding.DecodeString(key_base64) signature, err := client_key.Sign(rand.Reader, digest, crypto.Hash(0))
if err != nil {
return "", err
}
start_b64 := base64.StdEncoding.EncodeToString(token_start_bytes)
iv_b64 := base64.StdEncoding.EncodeToString(iv)
encrypted_b64 := base64.StdEncoding.EncodeToString(key_encrypted.Bytes())
key_b64 := base64.StdEncoding.EncodeToString(session_key_public)
sig_b64 := base64.StdEncoding.EncodeToString(signature)
id_b64 := base64.StdEncoding.EncodeToString(client_key.Public().(ed25519.PublicKey))
return base64.StdEncoding.EncodeToString([]byte(strings.Join([]string{id_b64, iv_b64, key_b64, encrypted_b64, start_b64, sig_b64}, ":"))), nil
}
func ParseAuthB64(auth_base64 string, server_id ed25519.PrivateKey) (*ClientAuthorization, error) {
joined_b64, err := base64.StdEncoding.DecodeString(auth_base64)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key_raw, err := x509.ParsePKCS8PrivateKey([]byte(key_bytes)) auth_parts := strings.Split(string(joined_b64), ":")
if len(auth_parts) != 6 {
return nil, fmt.Errorf("Wrong number of delimited elements %d", len(auth_parts))
}
id_bytes, err := base64.StdEncoding.DecodeString(auth_parts[0])
if err != nil { if err != nil {
return nil, err return nil, err
} }
key, ok := key_raw.(ed25519.PrivateKey) iv, err := base64.StdEncoding.DecodeString(auth_parts[1])
if ok == false { if err != nil {
return nil, fmt.Errorf("parsed key wrong type: %s", reflect.TypeOf(key_raw)) return nil, err
}
public_key, err := base64.StdEncoding.DecodeString(auth_parts[2])
if err != nil {
return nil, err
}
key_encrypted, err := base64.StdEncoding.DecodeString(auth_parts[3])
if err != nil {
return nil, err
}
start_bytes, err := base64.StdEncoding.DecodeString(auth_parts[4])
if err != nil {
return nil, err
}
signature, err := base64.StdEncoding.DecodeString(auth_parts[5])
if err != nil {
return nil, err
}
var start time.Time
err = start.UnmarshalBinary(start_bytes)
if err != nil {
return nil, err
}
client_id := ed25519.PublicKey(id_bytes)
if err != nil {
return nil, err
}
client_point, err := (&edwards25519.Point{}).SetBytes(public_key)
if err != nil {
return nil, err
}
ecdh_client, err := ECDH.NewPublicKey(client_point.BytesMontgomery())
if err != nil {
return nil, err
}
h := sha512.Sum512(server_id.Seed())
ecdh_server, err := ECDH.NewPrivateKey(h[:32])
if err != nil {
return nil, err
} }
return key, nil secret, err := ecdh_server.ECDH(ecdh_client)
if err != nil {
return nil, err
} else if len(secret) != 32 {
return nil, fmt.Errorf("Secret wrong length: %d/32", len(secret))
} }
func ParseIDB64(id_base64 string) (NodeID, error) { block, err := aes.NewCipher(secret)
id_bytes, err := base64.StdEncoding.DecodeString(id_base64)
if err != nil { if err != nil {
return NodeID{}, err return nil, err
} }
auth_id, err := IDFromBytes(id_bytes) encrypted_reader := bytes.NewReader(key_encrypted)
stream := cipher.NewOFB(block, iv)
reader := cipher.StreamReader{S: stream, R: encrypted_reader}
var decrypted_key bytes.Buffer
_, err = io.Copy(&decrypted_key, reader)
if err != nil { if err != nil {
return NodeID{}, err return nil, err
}
session_key := ed25519.NewKeyFromSeed(decrypted_key.Bytes())
digest := append(session_key.Public().(ed25519.PublicKey), start_bytes...)
if ed25519.Verify(client_id, digest, signature) == false {
return nil, fmt.Errorf("Failed to verify digest/signature against client_id")
} }
return auth_id, nil return &ClientAuthorization{
AuthInfo: AuthInfo{
Identity: client_id,
Start: start,
Signature: signature,
},
Key: session_key,
}, nil
} }
func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt) (*ResolveContext, error) { func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt) (*ResolveContext, error) {
@ -438,8 +559,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt) (*ResolveCon
GQLContext: ctx.Extensions[GQLExtType].Data.(*GQLExtContext), GQLContext: ctx.Extensions[GQLExtType].Data.(*GQLExtContext),
NodeCache: map[NodeID]NodeResult{}, NodeCache: map[NodeID]NodeResult{},
Server: server, Server: server,
User: NodeID{}, Authorization: nil,
Key: nil,
}, nil }, nil
} }
@ -453,43 +573,21 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW
} }
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
id_b64, key_b64, ok := r.BasicAuth() auth, err := ParseAuthB64(r.Header.Get("Authorization"), server.Key)
if ok == false {
ctx.Log.Logf("gql", "GQL_AUTH_BASIC_MISSING")
json.NewEncoder(w).Encode(fmt.Errorf("Failed to get auth headers"))
return
}
auth_id, err := ParseIDB64(id_b64)
if err != nil { if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ID_PARSE_ERROR: %s", err) ctx.Log.Logf("gql", "GQL_AUTH_ID_PARSE_ERROR: %s", err)
json.NewEncoder(w).Encode(fmt.Errorf("Failed to parse auth_id: %s", id_b64)) json.NewEncoder(w).Encode(GQLUnauthorized(""))
return
}
key, err := ParseKeyB64(key_b64)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_KEY_PARSE_ERROR: %s", err)
json.NewEncoder(w).Encode(fmt.Errorf("Failed to parse key: %s", key_b64))
return
}
key_id := KeyID(key.Public().(ed25519.PublicKey))
if auth_id != key_id {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: key_id != auth_id: %s != %s", key_id, auth_id)
json.NewEncoder(w).Encode(fmt.Errorf("GQL_REQUEST_ERR: key_id(%s) != auth_id(%s)", auth_id, key_id))
return return
} }
resolve_context, err := NewResolveContext(ctx, server, gql_ext) resolve_context, err := NewResolveContext(ctx, server, gql_ext)
if err != nil { if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
json.NewEncoder(w).Encode(GQLUnauthorized(fmt.Sprintf("%s", err))) json.NewEncoder(w).Encode(GQLUnauthorized(""))
return return
} }
resolve_context.Key = key resolve_context.Authorization = auth
resolve_context.User = key_id
req_ctx := context.Background() req_ctx := context.Background()
req_ctx = context.WithValue(req_ctx, "resolve", resolve_context) req_ctx = context.WithValue(req_ctx, "resolve", resolve_context)
@ -634,14 +732,13 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon
break break
} }
key, key_id, err := ParseAuthB64(connection_params.Payload.Token) authorization, err := ParseAuthB64(connection_params.Payload.Token, server.Key)
if err != nil { if err != nil {
ctx.Log.Logf("gqlws", "WS_AUTH_PARSE_ERR: %s", err) ctx.Log.Logf("gqlws", "WS_AUTH_PARSE_ERR: %s", err)
break break
} }
resolve_context.User = key_id resolve_context.Authorization = authorization
resolve_context.Key = key
conn_state = "ready" conn_state = "ready"
err = wsutil.WriteServerMessage(conn, 1, []byte("{\"type\": \"connection_ack\"}")) err = wsutil.WriteServerMessage(conn, 1, []byte("{\"type\": \"connection_ack\"}"))

@ -111,7 +111,7 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
ctx.Context.Log.Logf("gql", "READ_SIGNAL for %s - %+v", id, read_signal) ctx.Context.Log.Logf("gql", "READ_SIGNAL for %s - %+v", id, read_signal)
// Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(ctx.Context, ctx.Server.ID, ctx.Key, read_signal, id) msgs = msgs.Add(ctx.Context, id, ctx.Server, ctx.Authorization, read_signal)
response_chan := ctx.Ext.GetResponseChannel(read_signal.ID()) response_chan := ctx.Ext.GetResponseChannel(read_signal.ID())
resp_channels[read_signal.ID()] = response_chan resp_channels[read_signal.ID()] = response_chan

@ -166,7 +166,7 @@ func (ext *GQLExtContext) AddSignalMutation(name string, send_id_key string, sig
return nil, err return nil, err
} }
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(ctx.Context, ctx.User, ctx.Key, signal, send_id) msgs = msgs.Add(ctx.Context, send_id, ctx.Server, ctx.Authorization, signal)
response_chan := ctx.Ext.GetResponseChannel(signal.ID()) response_chan := ctx.Ext.GetResponseChannel(signal.ID())
err = ctx.Context.Send(msgs) err = ctx.Context.Send(msgs)

@ -18,6 +18,26 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
func TestGQLAuth(t *testing.T) {
ctx := logTestContext(t, []string{"test"})
listener_1 := NewListenerExt(10)
node_1, err := NewNode(ctx, nil, BaseNodeType, 10, nil, listener_1)
fatalErr(t, err)
listener_2 := NewListenerExt(10)
node_2, err := NewNode(ctx, nil, BaseNodeType, 10, nil, listener_2)
fatalErr(t, err)
auth_header, err := AuthB64(node_1.Key, node_2.Key.Public().(ed25519.PublicKey))
fatalErr(t, err)
auth, err := ParseAuthB64(auth_header, node_2.Key)
fatalErr(t, err)
ctx.Log.Logf("test", "AUTH: %+v", auth)
}
func TestGQLServer(t *testing.T) { func TestGQLServer(t *testing.T) {
ctx := logTestContext(t, []string{"test", "gqlws"}) ctx := logTestContext(t, []string{"test", "gqlws"})
@ -196,7 +216,7 @@ func TestGQLServer(t *testing.T) {
ctx.Log.Logf("test", "SUB: %s", resp[:n]) ctx.Log.Logf("test", "SUB: %s", resp[:n])
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(ctx, gql.ID, gql.Key, NewStatusSignal(gql.ID, Changes{"test_status"}), gql.ID) msgs = msgs.Add(ctx, gql.ID, gql, nil, NewStatusSignal(gql.ID, Changes{"test_status"}))
err = ctx.Send(msgs) err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
@ -210,7 +230,7 @@ func TestGQLServer(t *testing.T) {
SubGQL(sub_1) SubGQL(sub_1)
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(ctx, gql.ID, gql.Key, NewStopSignal(), gql.ID) msgs = msgs.Add(ctx, gql.ID, gql, nil, NewStopSignal())
err = ctx.Send(msgs) err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StoppedSignal) bool { _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StoppedSignal) bool {
@ -241,7 +261,7 @@ func TestGQLDB(t *testing.T) {
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(ctx, gql.ID, gql.Key, NewStopSignal(), gql.ID) msgs = msgs.Add(ctx, gql.ID, gql, nil, NewStopSignal())
err = ctx.Send(msgs) err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StoppedSignal) bool { _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StoppedSignal) bool {
@ -256,7 +276,7 @@ func TestGQLDB(t *testing.T) {
listener_ext, err = GetExt[*ListenerExt](gql_loaded, ListenerExtType) listener_ext, err = GetExt[*ListenerExt](gql_loaded, ListenerExtType)
fatalErr(t, err) fatalErr(t, err)
msgs = Messages{} msgs = Messages{}
msgs = msgs.Add(ctx, gql_loaded.ID, gql_loaded.Key, NewStopSignal(), gql_loaded.ID) msgs = msgs.Add(ctx, gql_loaded.ID, gql_loaded, nil, NewStopSignal())
err = ctx.Send(msgs) err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StoppedSignal) bool { _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, func(sig *StoppedSignal) bool {

@ -65,19 +65,19 @@ func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Sig
switch sig := signal.(type) { switch sig := signal.(type) {
case *AddMemberSignal: case *AddMemberSignal:
if slices.Contains(ext.Members, sig.MemberID) == true { if slices.Contains(ext.Members, sig.MemberID) == true {
messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.Id, "already_member"), source) messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "already_member"))
} else { } else {
ext.Members = append(ext.Members, sig.MemberID) ext.Members = append(ext.Members, sig.MemberID)
messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(sig.Id), source) messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
changes = changes.Add("member_added") changes = changes.Add("member_added")
} }
case *RemoveMemberSignal: case *RemoveMemberSignal:
idx := slices.Index(ext.Members, sig.MemberID) idx := slices.Index(ext.Members, sig.MemberID)
if idx == -1 { if idx == -1 {
messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.Id, "not_member"), source) messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_member"))
} else { } else {
ext.Members = slices.Delete(ext.Members, idx, idx+1) ext.Members = slices.Delete(ext.Members, idx, idx+1)
messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(sig.Id), source) messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
changes = changes.Add("member_removed") changes = changes.Add("member_removed")
} }
} }

@ -16,7 +16,7 @@ func TestGroupAdd(t *testing.T) {
add_member_signal := NewAddMemberSignal(user_id) add_member_signal := NewAddMemberSignal(user_id)
messages := Messages{} messages := Messages{}
messages = messages.Add(ctx, group.ID, group.Key, add_member_signal, group.ID) messages = messages.Add(ctx, group.ID, group, nil, add_member_signal)
fatalErr(t, ctx.Send(messages)) fatalErr(t, ctx.Send(messages))
_, err = WaitForResponse(group_listener.Chan, 10*time.Millisecond, add_member_signal.Id) _, err = WaitForResponse(group_listener.Chan, 10*time.Millisecond, add_member_signal.Id)
@ -27,7 +27,7 @@ func TestGroupAdd(t *testing.T) {
}) })
messages = Messages{} messages = Messages{}
messages = messages.Add(ctx, group.ID, group.Key, read_signal, group.ID) messages = messages.Add(ctx, group.ID, group, nil, read_signal)
fatalErr(t, ctx.Send(messages)) fatalErr(t, ctx.Send(messages))
response, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, read_signal.Id) response, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, read_signal.Id)

@ -53,14 +53,14 @@ func NewLockableExt(requirements []NodeID) *LockableExt {
func UnlockLockable(ctx *Context, node *Node) (uuid.UUID, error) { func UnlockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
messages := Messages{} messages := Messages{}
signal := NewLockSignal("unlock") signal := NewLockSignal("unlock")
messages = messages.Add(ctx, node.ID, node.Key, signal, node.ID) messages = messages.Add(ctx, node.ID, node, nil, signal)
return signal.ID(), ctx.Send(messages) return signal.ID(), ctx.Send(messages)
} }
func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) { func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
messages := Messages{} messages := Messages{}
signal := NewLockSignal("lock") signal := NewLockSignal("lock")
messages = messages.Add(ctx, node.ID, node.Key, signal, node.ID) messages = messages.Add(ctx, node.ID, node, nil, signal)
return signal.ID(), ctx.Send(messages) return signal.ID(), ctx.Send(messages)
} }
@ -86,7 +86,7 @@ func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeI
ext.Requirements[id] = req_info ext.Requirements[id] = req_info
ctx.Log.Logf("lockable", "SENT_ABORT_UNLOCK: %s to %s", lock_signal.ID(), id) ctx.Log.Logf("lockable", "SENT_ABORT_UNLOCK: %s to %s", lock_signal.ID(), id)
messages = messages.Add(ctx, node.ID, node.Key, lock_signal, id) messages = messages.Add(ctx, id, node, nil, lock_signal)
} }
} }
} }
@ -106,7 +106,7 @@ func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID
case "add": case "add":
_, exists := ext.Requirements[signal.NodeID] _, exists := ext.Requirements[signal.NodeID]
if exists == true { if exists == true {
messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_requirement"), source) messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "already_requirement"))
} else { } else {
if ext.Requirements == nil { if ext.Requirements == nil {
ext.Requirements = map[NodeID]ReqInfo{} ext.Requirements = map[NodeID]ReqInfo{}
@ -116,22 +116,22 @@ func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID
uuid.UUID{}, uuid.UUID{},
} }
changes = changes.Add("requirement_added") changes = changes.Add("requirement_added")
messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(signal.ID()), source) messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(signal.ID()))
} }
case "remove": case "remove":
_, exists := ext.Requirements[signal.NodeID] _, exists := ext.Requirements[signal.NodeID]
if exists == false { if exists == false {
messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "can't link: not_requirement"), source) messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "can't link: not_requirement"))
} else { } else {
delete(ext.Requirements, signal.NodeID) delete(ext.Requirements, signal.NodeID)
changes = changes.Add("requirement_removed") changes = changes.Add("requirement_removed")
messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(signal.ID()), source) messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(signal.ID()))
} }
default: default:
messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "unknown_action"), source) messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "unknown_action"))
} }
} else { } else {
messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocked"), source) messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_unlocked"))
} }
return messages, changes return messages, changes
} }
@ -168,7 +168,7 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod
ext.State = Locked ext.State = Locked
ext.Owner = ext.PendingOwner ext.Owner = ext.PendingOwner
changes = changes.Add("locked") changes = changes.Add("locked")
messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(ext.PendingID), *ext.Owner) messages = messages.Add(ctx, *ext.Owner, node, nil, NewSuccessSignal(ext.PendingID))
} else { } else {
changes = changes.Add("partial_lock") changes = changes.Add("partial_lock")
ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, reqs) ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, reqs)
@ -178,7 +178,7 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod
info.State = Unlocking info.State = Unlocking
info.MsgID = lock_signal.ID() info.MsgID = lock_signal.ID()
ext.Requirements[source] = info ext.Requirements[source] = info
messages = messages.Add(ctx, node.ID, node.Key, lock_signal, source) messages = messages.Add(ctx, source, node, nil, lock_signal)
} }
} else if info.State == Unlocking { } else if info.State == Unlocking {
info.State = Unlocked info.State = Unlocked
@ -202,10 +202,10 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod
ext.Owner = ext.PendingOwner ext.Owner = ext.PendingOwner
ext.ReqID = nil ext.ReqID = nil
changes = changes.Add("unlocked") changes = changes.Add("unlocked")
messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(ext.PendingID), previous_owner) messages = messages.Add(ctx, previous_owner, node, nil, NewSuccessSignal(ext.PendingID))
} else if old_state == AbortingLock { } else if old_state == AbortingLock {
changes = changes.Add("lock_aborted") changes = changes.Add("lock_aborted")
messages = messages.Add(ctx ,node.ID, node.Key, NewErrorSignal(*ext.ReqID, "not_unlocked"), *ext.PendingOwner) messages = messages.Add(ctx, *ext.PendingOwner, node, nil, NewErrorSignal(*ext.ReqID, "not_unlocked"))
ext.PendingOwner = ext.Owner ext.PendingOwner = ext.Owner
} }
} else { } else {
@ -232,7 +232,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID
ext.PendingOwner = &new_owner ext.PendingOwner = &new_owner
ext.Owner = &new_owner ext.Owner = &new_owner
changes = changes.Add("locked") changes = changes.Add("locked")
messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(signal.ID()), new_owner) messages = messages.Add(ctx, new_owner, node, nil, NewSuccessSignal(signal.ID()))
} else { } else {
ext.State = Locking ext.State = Locking
id := signal.ID() id := signal.ID()
@ -249,11 +249,11 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID
info.State = Locking info.State = Locking
info.MsgID = lock_signal.ID() info.MsgID = lock_signal.ID()
ext.Requirements[id] = info ext.Requirements[id] = info
messages = messages.Add(ctx, node.ID, node.Key, lock_signal, id) messages = messages.Add(ctx, id, node, nil, lock_signal)
} }
} }
} else { } else {
messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocked"), source) messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_unlocked"))
} }
case "unlock": case "unlock":
if ext.State == Locked { if ext.State == Locked {
@ -263,7 +263,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID
ext.PendingOwner = nil ext.PendingOwner = nil
ext.Owner = nil ext.Owner = nil
changes = changes.Add("unlocked") changes = changes.Add("unlocked")
messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(signal.ID()), new_owner) messages = messages.Add(ctx, new_owner, node, nil, NewSuccessSignal(signal.ID()))
} else if source == *ext.Owner { } else if source == *ext.Owner {
ext.State = Unlocking ext.State = Unlocking
id := signal.ID() id := signal.ID()
@ -279,11 +279,11 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID
info.State = Unlocking info.State = Unlocking
info.MsgID = lock_signal.ID() info.MsgID = lock_signal.ID()
ext.Requirements[id] = info ext.Requirements[id] = info
messages = messages.Add(ctx, node.ID, node.Key, lock_signal, id) messages = messages.Add(ctx, id, node, nil, lock_signal)
} }
} }
} else { } else {
messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locked"), source) messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_locked"))
} }
default: default:
ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", signal.State) ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", signal.State)
@ -301,13 +301,13 @@ func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal
case Up: case Up:
if ext.Owner != nil { if ext.Owner != nil {
if *ext.Owner != node.ID { if *ext.Owner != node.ID {
messages = messages.Add(ctx, node.ID, node.Key, signal, *ext.Owner) messages = messages.Add(ctx, *ext.Owner, node, nil, signal)
} }
} }
case Down: case Down:
for requirement := range(ext.Requirements) { for requirement := range(ext.Requirements) {
messages = messages.Add(ctx, node.ID, node.Key, signal, requirement) messages = messages.Add(ctx, requirement, node, nil, signal)
} }
case Direct: case Direct:

@ -44,7 +44,7 @@ func TestLink(t *testing.T) {
msgs := Messages{} msgs := Messages{}
link_signal := NewLinkSignal("add", l2.ID) link_signal := NewLinkSignal("add", l2.ID)
msgs = msgs.Add(ctx, l1.ID, l1.Key, link_signal, l1.ID) msgs = msgs.Add(ctx, l1.ID, l1, nil, link_signal)
err = ctx.Send(msgs) err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
@ -60,7 +60,7 @@ func TestLink(t *testing.T) {
msgs = Messages{} msgs = Messages{}
unlink_signal := NewLinkSignal("remove", l2.ID) unlink_signal := NewLinkSignal("remove", l2.ID)
msgs = msgs.Add(ctx, l1.ID, l1.Key, unlink_signal, l1.ID) msgs = msgs.Add(ctx, l1.ID, l1, nil, unlink_signal)
err = ctx.Send(msgs) err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)

@ -282,29 +282,27 @@ func nodeLoop(ctx *Context, node *Node) error {
ctx.Log.Logf("signal", "SIGNAL_DEST_ID_SER_ERR: %e", err) ctx.Log.Logf("signal", "SIGNAL_DEST_ID_SER_ERR: %e", err)
continue continue
} }
src_id_ser, err := msg.Source.MarshalBinary() src_id_ser, err := KeyID(msg.Source).MarshalBinary()
if err != nil { if err != nil {
ctx.Log.Logf("signal", "SIGNAL_SRC_ID_SER_ERR: %e", err) ctx.Log.Logf("signal", "SIGNAL_SRC_ID_SER_ERR: %e", err)
continue continue
} }
sig_data := append(dst_id_ser, src_id_ser...) sig_data := append(dst_id_ser, src_id_ser...)
sig_data = append(sig_data, ser...) sig_data = append(sig_data, ser...)
validated := ed25519.Verify(msg.Principal, sig_data, msg.Signature) validated := ed25519.Verify(msg.Source, sig_data, msg.Signature)
if validated == false { if validated == false {
println(fmt.Sprintf("SIGNAL: %s", msg.Signal))
println(fmt.Sprintf("VERIFY_DIGEST: %+v", sig_data))
ctx.Log.Logf("signal", "SIGNAL_VERIFY_ERR: %s - %+v", node.ID, msg) ctx.Log.Logf("signal", "SIGNAL_VERIFY_ERR: %s - %+v", node.ID, msg)
continue continue
} }
princ_id := KeyID(msg.Principal) princ_id := KeyID(msg.Source)
if princ_id != node.ID { if princ_id != node.ID {
pends, resp := node.Allows(ctx, princ_id, msg.Signal.Permission()) pends, resp := node.Allows(ctx, princ_id, msg.Signal.Permission())
if resp == Deny { if resp == Deny {
ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %+v(%+s)", princ_id, node.ID, reflect.TypeOf(msg.Signal), msg.Signal) ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %+v(%+s)", princ_id, node.ID, reflect.TypeOf(msg.Signal), msg.Signal)
ctx.Log.Logf("policy", "SIGNAL_POLICY_SOURCE: %s", msg.Source) ctx.Log.Logf("policy", "SIGNAL_POLICY_SOURCE: %s", msg.Source)
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(msg.Signal.ID(), "acl denied"), msg.Source) msgs = msgs.Add(ctx, KeyID(msg.Source), node, nil, NewErrorSignal(msg.Signal.ID(), "acl denied"))
ctx.Send(msgs) ctx.Send(msgs)
continue continue
} else if resp == Pending { } else if resp == Pending {
@ -327,7 +325,7 @@ func nodeLoop(ctx *Context, node *Node) error {
Principal: princ_id, Principal: princ_id,
Responses: []ResponseSignal{}, Responses: []ResponseSignal{},
Signal: msg.Signal, Signal: msg.Signal,
Source: msg.Source, Source: KeyID(msg.Source),
} }
ctx.Log.Logf("policy", "Sending signals for pending ACL: %+v", msgs) ctx.Log.Logf("policy", "Sending signals for pending ACL: %+v", msgs)
ctx.Send(msgs) ctx.Send(msgs)
@ -340,7 +338,7 @@ func nodeLoop(ctx *Context, node *Node) error {
} }
signal = msg.Signal signal = msg.Signal
source = msg.Source source = KeyID(msg.Source)
case <-node.TimeoutChan: case <-node.TimeoutChan:
signal = node.NextSignal.Signal signal = node.NextSignal.Signal
@ -408,7 +406,7 @@ func nodeLoop(ctx *Context, node *Node) error {
ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal) ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal)
// Send the denied response // Send the denied response
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(req_info.Signal.ID(), "ACL_DENIED"), req_info.Source) msgs = msgs.Add(ctx, req_info.Source, node, nil, NewErrorSignal(req_info.Signal.ID(), "acl_denied"))
err := ctx.Send(msgs) err := ctx.Send(msgs)
if err != nil { if err != nil {
ctx.Log.Logf("signal", "SEND_ERR: %s", err) ctx.Log.Logf("signal", "SEND_ERR: %s", err)
@ -434,7 +432,7 @@ func nodeLoop(ctx *Context, node *Node) error {
node.Process(ctx, source, NewStoppedSignal(sig, node.ID)) node.Process(ctx, source, NewStoppedSignal(sig, node.ID))
} else { } else {
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(ctx, node.ID, node.Key, NewStoppedSignal(sig, node.ID), source) msgs = msgs.Add(ctx, node.ID, node, nil, NewStoppedSignal(sig, node.ID))
ctx.Send(msgs) ctx.Send(msgs)
} }
run = false run = false
@ -442,7 +440,7 @@ func nodeLoop(ctx *Context, node *Node) error {
case *ReadSignal: case *ReadSignal:
result := node.ReadFields(ctx, sig.Extensions) result := node.ReadFields(ctx, sig.Extensions)
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(ctx, node.ID, node.Key, NewReadResultSignal(sig.ID(), node.ID, node.Type, result), source) msgs = msgs.Add(ctx, source, node, nil, NewReadResultSignal(sig.ID(), node.ID, node.Type, result))
ctx.Send(msgs) ctx.Send(msgs)
default: default:
@ -460,17 +458,52 @@ func nodeLoop(ctx *Context, node *Node) error {
return nil return nil
} }
type AuthInfo struct {
// The Node that issued the authorization
Identity ed25519.PublicKey
// Time the authorization was generated
Start time.Time
// Signature of Start + Principal with Identity private key
Signature []byte
}
type AuthorizationToken struct {
AuthInfo
// The private key generated by the client, encrypted with the servers public key
KeyEncrypted []byte
}
type ClientAuthorization struct {
AuthInfo
// The private key generated by the client
Key ed25519.PrivateKey
}
// Authorization structs can be passed in a message that originated from a different node than the sender
type Authorization struct {
AuthInfo
// The public key generated for this authorization
Key ed25519.PublicKey
}
type Message struct { type Message struct {
Source NodeID
Dest NodeID Dest NodeID
Principal ed25519.PublicKey Source ed25519.PublicKey
Authorization *Authorization
Signal Signal Signal Signal
Signature []byte Signature []byte
} }
type Messages []*Message type Messages []*Message
func (msgs Messages) Add(ctx *Context, source NodeID, principal ed25519.PrivateKey, signal Signal, dest NodeID) Messages { func (msgs Messages) Add(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) Messages {
msg, err := NewMessage(ctx, dest, source, principal, signal) msg, err := NewMessage(ctx, dest, source, authorization, signal)
if err != nil { if err != nil {
panic(err) panic(err)
} else { } else {
@ -479,7 +512,7 @@ func (msgs Messages) Add(ctx *Context, source NodeID, principal ed25519.PrivateK
return msgs return msgs
} }
func NewMessage(ctx *Context, dest NodeID, source NodeID, principal ed25519.PrivateKey, signal Signal) (*Message, error) { func NewMessage(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) (*Message, error) {
signal_ser, err := SerializeAny(ctx, signal) signal_ser, err := SerializeAny(ctx, signal)
if err != nil { if err != nil {
return nil, err return nil, err
@ -494,30 +527,39 @@ func NewMessage(ctx *Context, dest NodeID, source NodeID, principal ed25519.Priv
if err != nil { if err != nil {
return nil, err return nil, err
} }
source_ser, err := source.MarshalBinary() source_ser, err := source.ID.MarshalBinary()
if err != nil { if err != nil {
return nil, err return nil, err
} }
sig_data := append(dest_ser, source_ser...) sig_data := append(dest_ser, source_ser...)
sig_data = append(sig_data, ser...) sig_data = append(sig_data, ser...)
var message_auth *Authorization = nil
if authorization != nil {
sig_data = append(sig_data, authorization.Signature...)
message_auth = &Authorization{
authorization.AuthInfo,
authorization.Key.Public().(ed25519.PublicKey),
}
}
sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0)) sig, err := source.Key.Sign(rand.Reader, sig_data, crypto.Hash(0))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Message{ return &Message{
Dest: dest, Dest: dest,
Source: source, Source: source.Key.Public().(ed25519.PublicKey),
Principal: principal.Public().(ed25519.PublicKey), Authorization: message_auth,
Signal: signal, Signal: signal,
Signature: sig, Signature: sig,
}, nil }, nil
} }
func (node *Node) Stop(ctx *Context) error { func (node *Node) Stop(ctx *Context) error {
if node.Active.Load() { if node.Active.Load() {
msg, err := NewMessage(ctx, node.ID, node.ID, node.Key, NewStopSignal()) msg, err := NewMessage(ctx, node.ID, node, nil, NewStopSignal())
if err != nil { if err != nil {
return err return err
} }

@ -23,7 +23,7 @@ func TestNodeDB(t *testing.T) {
}) })
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(ctx, node.ID, node.Key, NewStopSignal(), node.ID) msgs = msgs.Add(ctx, node.ID, node, nil, NewStopSignal())
err = ctx.Send(msgs) err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
@ -71,7 +71,7 @@ func TestNodeRead(t *testing.T) {
GroupExtType: {"members"}, GroupExtType: {"members"},
}) })
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(ctx, n2.ID, n2.Key, read_sig, n1.ID) msgs = msgs.Add(ctx, n1.ID, n2, nil, read_sig)
err = ctx.Send(msgs) err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)

@ -147,9 +147,9 @@ func (policy MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tr
} }
} }
} else { } else {
msgs = msgs.Add(ctx, node.ID, node.Key, NewReadSignal(map[ExtType][]string{ msgs = msgs.Add(ctx, id, node, nil, NewReadSignal(map[ExtType][]string{
GroupExtType: []string{"members"}, GroupExtType: []string{"members"},
}), id) }))
} }
} }
return msgs, Pending return msgs, Pending

@ -7,7 +7,7 @@ import (
) )
func TestSerializeBasic(t *testing.T) { func TestSerializeBasic(t *testing.T) {
ctx := logTestContext(t, []string{"test", "serialize"}) ctx := logTestContext(t, []string{"test"})
testSerializeComparable[string](t, ctx, "test") testSerializeComparable[string](t, ctx, "test")
testSerializeComparable[bool](t, ctx, true) testSerializeComparable[bool](t, ctx, true)
testSerializeComparable[float32](t, ctx, 0.05) testSerializeComparable[float32](t, ctx, 0.05)