Moved test 'WaitForX' functions to a generic function in signal.go that can be used to wait for arbitrary signals

gql_cataclysm
noah metz 2023-07-28 19:32:27 -06:00
parent 5fb1cb6d17
commit f87571edcf
9 changed files with 423 additions and 386 deletions

@ -8,6 +8,7 @@ import (
"runtime"
"crypto/sha512"
"crypto/elliptic"
"crypto/ecdh"
"encoding/binary"
)
@ -80,6 +81,8 @@ type Context struct {
Types map[uint64]*NodeInfo
// Curve used for signature operations
ECDSA elliptic.Curve
// Curve used for ecdh operations
ECDH ecdh.Curve
// Routing map to all the nodes local to this context
NodesLock sync.RWMutex
Nodes map[NodeID]*Node
@ -197,6 +200,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
Extensions: map[uint64]ExtensionInfo{},
Types: map[uint64]*NodeInfo{},
Nodes: map[NodeID]*Node{},
ECDH: ecdh.P256(),
ECDSA: elliptic.P256(),
}

181
gql.go

@ -10,7 +10,6 @@ import (
"github.com/graphql-go/graphql/language/ast"
"context"
"encoding/json"
"encoding/base64"
"io"
"reflect"
"fmt"
@ -21,7 +20,6 @@ import (
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/sha512"
"crypto/rand"
"crypto/x509"
"crypto/tls"
@ -30,173 +28,6 @@ import (
"encoding/pem"
)
type AuthReqJSON struct {
Time time.Time `json:"time"`
Pubkey []byte `json:"pubkey"`
ECDHPubkey []byte `json:"ecdh_client"`
Signature []byte `json:"signature"`
}
func NewAuthReqJSON(curve ecdh.Curve, id *ecdsa.PrivateKey) (AuthReqJSON, *ecdh.PrivateKey, error) {
ec_key, err := curve.GenerateKey(rand.Reader)
if err != nil {
return AuthReqJSON{}, nil, err
}
now := time.Now()
time_bytes, err := now.MarshalJSON()
if err != nil {
return AuthReqJSON{}, nil, err
}
sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...)
sig_hash := sha512.Sum512(sig_data)
sig, err := ecdsa.SignASN1(rand.Reader, id, sig_hash[:])
id_ecdh, err := id.ECDH()
if err != nil {
return AuthReqJSON{}, nil, err
}
return AuthReqJSON{
Time: now,
Pubkey: id_ecdh.PublicKey().Bytes(),
ECDHPubkey: ec_key.PublicKey().Bytes(),
Signature: sig,
}, ec_key, nil
}
type AuthRespJSON struct {
Granted time.Time `json:"granted"`
ECDHPubkey []byte `json:"echd_server"`
Signature []byte `json:"signature"`
}
func NewAuthRespJSON(gql_ext *GQLExt, req AuthReqJSON) (AuthRespJSON, *ecdsa.PublicKey, []byte, error) {
// Check if req.Time is within +- 1 second of now
now := time.Now()
earliest := now.Add(-1 * time.Second)
latest := now.Add(1 * time.Second)
// If req.Time is before the earliest acceptable time, or after the latest acceptible time
if req.Time.Compare(earliest) == -1 {
return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_LATE: %s", req.Time)
} else if req.Time.Compare(latest) == 1 {
return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_EARLY: %s", req.Time)
}
x, y := elliptic.Unmarshal(gql_ext.Key.Curve, req.Pubkey)
if x == nil {
return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_UNMARSHAL_FAIL: %+v", req.Pubkey)
}
remote, err := gql_ext.ECDH.NewPublicKey(req.ECDHPubkey)
if err != nil {
return AuthRespJSON{}, nil, nil, err
}
// Verify the signature
time_bytes, _ := req.Time.MarshalJSON()
sig_data := append(req.ECDHPubkey, time_bytes...)
sig_hash := sha512.Sum512(sig_data)
remote_key := &ecdsa.PublicKey{
Curve: gql_ext.Key.Curve,
X: x,
Y: y,
}
verified := ecdsa.VerifyASN1(
remote_key,
sig_hash[:],
req.Signature,
)
if verified == false {
return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_VERIFY_FAIL: %+v", req)
}
ec_key, err := gql_ext.ECDH.GenerateKey(rand.Reader)
if err != nil {
return AuthRespJSON{}, nil, nil, err
}
ec_key_pub := ec_key.PublicKey().Bytes()
granted := time.Now()
time_ser, _ := granted.MarshalJSON()
resp_sig_data := append(ec_key_pub, time_ser...)
resp_sig_hash := sha512.Sum512(resp_sig_data)
resp_sig, err := ecdsa.SignASN1(rand.Reader, gql_ext.Key, resp_sig_hash[:])
if err != nil {
return AuthRespJSON{}, nil, nil, err
}
shared_secret, err := ec_key.ECDH(remote)
if err != nil {
return AuthRespJSON{}, nil, nil, err
}
return AuthRespJSON{
Granted: granted,
ECDHPubkey: ec_key_pub,
Signature: resp_sig,
}, remote_key, shared_secret, nil
}
func ParseAuthRespJSON(resp AuthRespJSON, ecdsa_curve elliptic.Curve, ecdh_curve ecdh.Curve, ec_key *ecdh.PrivateKey) ([]byte, error) {
remote, err := ecdh_curve.NewPublicKey(resp.ECDHPubkey)
if err != nil {
return nil, err
}
shared_secret, err := ec_key.ECDH(remote)
if err != nil {
return nil, err
}
return shared_secret, nil
}
func AuthHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
ctx.Log.Logf("gql", "GQL_AUTH_REQUEST: %s", r.RemoteAddr)
enableCORS(&w)
str, err := io.ReadAll(r.Body)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_READ_ERR: %e", err)
return
}
var req AuthReqJSON
err = json.Unmarshal([]byte(str), &req)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_UNMARHSHAL_ERR: %e", err)
return
}
resp, _, _, err := NewAuthRespJSON(gql_ext, req)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_ERROR: %s", err)
return
}
ser, err := json.Marshal(resp)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_RESP_MARSHAL_ERR: %e", err)
return
}
wrote, err := w.Write(ser)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_RESP_ERR: %e", err)
return
} else if wrote != len(ser) {
ctx.Log.Logf("gql", "GQL_AUTH_RESP_BAD_LENGTH: %d/%d", wrote, len(ser))
return
}
}
}
func GraphiQLHandler() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r * http.Request) {
graphiql_string := fmt.Sprintf(`
@ -340,7 +171,7 @@ type ResolveContext struct {
}
func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) {
username, password, ok := r.BasicAuth()
username, _, ok := r.BasicAuth()
if ok == false {
return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header")
}
@ -355,15 +186,6 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ
return nil, fmt.Errorf("GQL_REQUEST_ERR: no existing authorization for client %s", auth_id)
}
user_ext, err := GetExt[*ECDHExt](user)
if err != nil {
return nil, err
}
if base64.StdEncoding.EncodeToString(user_ext.Shared) != password {
return nil, fmt.Errorf("GQL_AUTH_FAIL")
}
return &ResolveContext{
Context: ctx,
GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext),
@ -942,7 +764,6 @@ func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_
func StartGQLServer(ctx *Context, node *Node, gql_ext *GQLExt) error {
mux := http.NewServeMux()
mux.HandleFunc("/auth", AuthHandler(ctx, node, gql_ext))
mux.HandleFunc("/gql", GQLHandler(ctx, node, gql_ext))
mux.HandleFunc("/gqlws", GQLWSHandler(ctx, node, gql_ext))

@ -34,7 +34,10 @@ func TestGQLDB(t * testing.T) {
err = ctx.Send(gql.ID, gql.ID, StopSignal)
fatalErr(t, err)
(*GraphTester)(t).WaitForStatus(ctx, listener_ext, "stopped", 100*time.Millisecond, "Didn't receive stopped on listener")
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig IDStateSignal) bool {
return sig.State == "stopped" && sig.ID == gql.ID
})
fatalErr(t, err)
ser1, err := gql.Serialize()
ser2, err := u1.Serialize()
@ -49,7 +52,10 @@ func TestGQLDB(t * testing.T) {
fatalErr(t, err)
err = ctx.Send(gql_loaded.ID, gql_loaded.ID, StopSignal)
fatalErr(t, err)
(*GraphTester)(t).WaitForStatus(ctx, listener_ext, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2")
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig IDStateSignal) bool {
return sig.State == "stopped" && sig.ID == gql_loaded.ID
})
fatalErr(t, err)
}

@ -2,111 +2,10 @@ package graphvent
import (
"testing"
"fmt"
"time"
"runtime/pprof"
"runtime/debug"
"os"
badger "github.com/dgraph-io/badger/v3"
)
type GraphTester testing.T
const listner_timeout = 50 * time.Millisecond
func (t *GraphTester) WaitForReadResult(ctx *Context, listener *ListenerExt, timeout time.Duration, str string) map[ExtType]map[string]interface{} {
timeout_channel := time.After(timeout)
for true {
select {
case signal := <- listener.Chan:
ctx.Log.Logf("test", "SIGNAL %+v", signal)
if signal == nil {
ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener)
t.Fatal(str)
}
if signal.Type() == ReadResultSignalType {
result_signal, ok := signal.(ReadResultSignal)
if ok == false {
ctx.Log.Logf("test", "SIGNAL_CHANNEL_BAD_CAST: %+v", signal)
t.Fatal(str)
}
return result_signal.Extensions
}
case <-timeout_channel:
ctx.Log.Logf("test", "SIGNAL_CHANNEL_TIMEOUT: %+v", listener)
t.Fatal(str)
}
}
return nil
}
func (t *GraphTester) WaitForState(ctx * Context, listener *ListenerExt, stype SignalType, state string, timeout time.Duration, str string) Signal {
timeout_channel := time.After(timeout)
for true {
select {
case signal := <- listener.Chan:
if signal == nil {
ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener)
t.Fatal(str)
}
if signal.Type() == stype {
sig, ok := signal.(StateSignal)
if ok == true {
ctx.Log.Logf("test", "%s state received: %s", stype, sig.State)
if sig.State == state {
return signal
}
}
}
case <-timeout_channel:
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
t.Fatal(str)
return nil
}
}
return nil
}
func (t * GraphTester) WaitForStatus(ctx * Context, listener *ListenerExt, status string, timeout time.Duration, str string) Signal {
timeout_channel := time.After(timeout)
for true {
select {
case signal := <- listener.Chan:
if signal == nil {
ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener)
t.Fatal(str)
}
if signal.Type() == StatusSignalType {
sig, ok := signal.(StatusSignal)
if ok == true {
ctx.Log.Logf("test", "Status received: %s", sig.Status)
if sig.Status == status {
return signal
}
} else {
ctx.Log.Logf("test", "Failed to cast status to StatusSignal: %+v", signal)
}
}
case <-timeout_channel:
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
t.Fatal(str)
return nil
}
}
return nil
}
func (t * GraphTester) CheckForNone(listener *ListenerExt, str string) {
timeout := time.After(listner_timeout)
select {
case sig := <- listener.Chan:
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
t.Fatal(fmt.Sprintf("%s : %+v", str, sig))
case <-timeout:
}
}
const SimpleListenerNodeType = NodeType("SIMPLE_LISTENER")
func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) {

@ -420,8 +420,6 @@ func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node
// LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes
// LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state
func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) {
ctx.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID)
switch signal.Direction() {
case Up:
owner_sent := false

@ -3,7 +3,6 @@ package graphvent
import (
"testing"
"time"
"fmt"
)
const TestLockableType = NodeType("TEST_LOCKABLE")
@ -41,18 +40,31 @@ func TestLink(t *testing.T) {
err := LinkRequirement(ctx, l1.ID, l2.ID)
fatalErr(t, err)
(*GraphTester)(t).WaitForState(ctx, l1_listener, LinkSignalType, "linked_as_req", time.Millisecond*10, "No linked_as_req")
(*GraphTester)(t).WaitForState(ctx, l2_listener, LinkSignalType, "linked_as_dep", time.Millisecond*10, "No req_linked")
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, func(sig StateSignal) bool {
return sig.State == "linked_as_req"
})
fatalErr(t, err)
_, err = WaitForSignal(ctx, l2_listener, time.Millisecond*10, LinkSignalType, func(sig StateSignal) bool {
return sig.State == "linked_as_dep"
})
fatalErr(t, err)
err = ctx.Send(l2.ID, l2.ID, NewStatusSignal("TEST", l2.ID))
fatalErr(t, err)
(*GraphTester)(t).WaitForStatus(ctx, l1_listener, "TEST", time.Millisecond*10, "No TEST on l1")
(*GraphTester)(t).WaitForStatus(ctx, l2_listener, "TEST", time.Millisecond*10, "No TEST on l2")
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, StatusSignalType, func(sig IDStateSignal) bool {
return sig.State == "TEST"
})
fatalErr(t, err)
_, err = WaitForSignal(ctx, l2_listener, time.Millisecond*10, StatusSignalType, func(sig IDStateSignal) bool {
return sig.State == "TEST"
})
fatalErr(t, err)
}
func TestLink10K(t *testing.T) {
ctx := lockableTestContext(t, []string{"test"})
ctx := lockableTestContext(t, []string{})
NewLockable := func()(*Node) {
l := NewNode(ctx, nil, TestLockableType, 10, nil,
@ -82,8 +94,11 @@ func TestLink10K(t *testing.T) {
ctx.Log.Logf("test", "CREATED_10K")
for i, _ := range(lockables) {
(*GraphTester)(t).WaitForState(ctx, l0_listener, LinkSignalType, "linked_as_req", time.Millisecond*1000, fmt.Sprintf("No linked_as_req for %d", i))
for range(lockables) {
_, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, func(sig StateSignal) bool {
return sig.State == "linked_as_req"
})
fatalErr(t, err)
}
ctx.Log.Logf("test", "LINKED_10K")
@ -129,23 +144,44 @@ func TestLock(t *testing.T) {
err = LinkRequirement(ctx, l0.ID, l5.ID)
fatalErr(t, err)
(*GraphTester)(t).WaitForState(ctx, l1_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req")
(*GraphTester)(t).WaitForState(ctx, l1_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req")
(*GraphTester)(t).WaitForState(ctx, l1_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req")
(*GraphTester)(t).WaitForState(ctx, l1_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req")
linked_as_req := func(sig StateSignal) bool {
return sig.State == "linked_as_req"
}
(*GraphTester)(t).WaitForState(ctx, l0_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req")
(*GraphTester)(t).WaitForState(ctx, l0_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req")
(*GraphTester)(t).WaitForState(ctx, l0_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req")
(*GraphTester)(t).WaitForState(ctx, l0_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req")
locked := func(sig StateSignal) bool {
return sig.State == "locked"
}
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
err = LockLockable(ctx, l1)
fatalErr(t, err)
(*GraphTester)(t).WaitForState(ctx, l1_listener, LockSignalType, "locked", time.Millisecond*10, "No locked")
(*GraphTester)(t).WaitForState(ctx, l1_listener, LockSignalType, "locked", time.Millisecond*10, "No locked")
(*GraphTester)(t).WaitForState(ctx, l1_listener, LockSignalType, "locked", time.Millisecond*10, "No locked")
(*GraphTester)(t).WaitForState(ctx, l1_listener, LockSignalType, "locked", time.Millisecond*10, "No locked")
(*GraphTester)(t).WaitForState(ctx, l1_listener, LockSignalType, "locked", time.Millisecond*10, "No locked")
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked)
fatalErr(t, err)
err = UnlockLockable(ctx, l1)
fatalErr(t, err)

@ -21,9 +21,9 @@ func TestNodeDB(t *testing.T) {
}
func TestNodeRead(t *testing.T) {
ctx := logTestContext(t, []string{"test", "read", "signal", "policy", "node", "loop"})
ctx := logTestContext(t, []string{})
node_type := NodeType("TEST")
err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, GroupExtType})
err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, GroupExtType, ECDHExtType})
fatalErr(t, err)
n1_key, err := ecdsa.GenerateKey(ctx.ECDSA, rand.Reader)
@ -41,17 +41,54 @@ func TestNodeRead(t *testing.T) {
n1_id: Actions{MakeAction(ReadResultSignalType, "+")},
})
n2_listener := NewListenerExt(10)
n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewACLExt(n2_policy), NewGroupExt(nil), n2_listener)
n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewACLExt(n2_policy), NewGroupExt(nil), NewECDHExt(), n2_listener)
n1_policy := NewPerNodePolicy(map[NodeID]Actions{
n2_id: Actions{MakeAction(ReadSignalType, "+")},
})
n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewACLExt(n1_policy), NewGroupExt(nil))
n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewACLExt(n1_policy), NewGroupExt(nil), NewECDHExt())
ctx.Send(n2.ID, n1.ID, NewReadSignal(map[ExtType][]string{
GroupExtType: []string{"members"},
}))
res := (*GraphTester)(t).WaitForReadResult(ctx, n2_listener, 10*time.Millisecond, "No read_result")
res, err := WaitForSignal(ctx, n2_listener, 10*time.Millisecond, ReadResultSignalType, func(sig ReadResultSignal) bool {
return true
})
fatalErr(t, err)
ctx.Log.Logf("test", "READ_RESULT: %+v", res)
}
func TestECDH(t *testing.T) {
ctx := logTestContext(t, []string{"test", "ecdh", "policy"})
node_type := NodeType("TEST")
err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, ECDHExtType})
fatalErr(t, err)
n1_listener := NewListenerExt(10)
ecdh_policy := NewAllNodesPolicy(Actions{MakeAction(ECDHSignalType, "+")})
n1 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt(), n1_listener)
n2 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt())
ctx.Log.Logf("test", "N1: %s", n1.ID)
ctx.Log.Logf("test", "N2: %s", n2.ID)
ecdh_req, n1_ec, err := NewECDHReqSignal(ctx, n1)
ecdh_ext, err := GetExt[*ECDHExt](n1)
fatalErr(t, err)
ecdh_ext.ECDHStates[n2.ID] = ECDHState{
ECKey: n1_ec,
SharedSecret: nil,
}
fatalErr(t, err)
ctx.Log.Logf("test", "N1_EC: %+v", n1_ec)
err = ctx.Send(n1.ID, n2.ID, ecdh_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, n1_listener, 100*time.Millisecond, ECDHSignalType, func(sig ECDHSignal) bool {
return sig.State == "resp"
})
fatalErr(t, err)
}

@ -1,21 +1,26 @@
package graphvent
import (
"time"
"fmt"
"encoding/json"
)
const (
StopSignalType = SignalType("STOP")
StatusSignalType = SignalType("STATUS")
LinkSignalType = SignalType("LINK")
LockSignalType = SignalType("LOCK")
ReadSignalType = SignalType("READ")
ReadResultSignalType = SignalType("READ_RESULT")
LinkStartSignalType = SignalType("LINK_START")
"crypto/sha512"
"crypto/ecdsa"
"crypto/ecdh"
"crypto/rand"
)
type SignalDirection int
const (
StopSignalType SignalType = "STOP"
StatusSignalType = "STATUS"
LinkSignalType = "LINK"
LockSignalType = "LOCK"
ReadSignalType = "READ"
ReadResultSignalType = "READ_RESULT"
LinkStartSignalType = "LINK_START"
ECDHSignalType = "ECDH"
Up SignalDirection = iota
Down
Direct
@ -29,10 +34,35 @@ func (signal_type SignalType) String() string {
type Signal interface {
Serializable[SignalType]
Direction() SignalDirection
MarshalJSON() ([]byte, error)
Permission() Action
}
func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) {
var zero S
timeout_channel := time.After(timeout)
for true {
select {
case signal := <- listener.Chan:
if signal == nil {
return zero, fmt.Errorf("LISTENER_CLOSED: %s", signal_type)
}
if signal.Type() == signal_type {
sig, ok := signal.(S)
if ok == true {
ctx.Log.Logf("test", "received: %+v", sig)
if check(sig) == true {
return sig, nil
}
}
}
case <-timeout_channel:
return zero, fmt.Errorf("LISTENER_TIMEOUT: %s", signal_type)
}
}
return zero, fmt.Errorf("LOOP_ENDED")
}
type BaseSignal struct {
SignalDirection SignalDirection `json:"direction"`
SignalType SignalType `json:"type"`
@ -50,7 +80,7 @@ func (signal BaseSignal) Direction() SignalDirection {
return signal.SignalDirection
}
func (signal BaseSignal) MarshalJSON() ([]byte, error) {
func (signal *BaseSignal) MarshalJSON() ([]byte, error) {
return json.Marshal(signal)
}
@ -100,12 +130,18 @@ func NewIDSignal(signal_type SignalType, direction SignalDirection, id NodeID) I
}
}
type StatusSignal struct {
IDSignal
Status string `json:"status"`
type StateSignal struct {
BaseSignal
State string `json:"state"`
}
func (signal StatusSignal) String() string {
type IDStateSignal struct {
BaseSignal
ID NodeID `json:"id"`
State string `json:"status"`
}
func (signal IDStateSignal) String() string {
ser, err := json.Marshal(signal)
if err != nil {
return "STATE_SER_ERR"
@ -113,18 +149,14 @@ func (signal StatusSignal) String() string {
return string(ser)
}
func NewStatusSignal(status string, source NodeID) StatusSignal {
return StatusSignal{
IDSignal: NewIDSignal(StatusSignalType, Up, source),
Status: status,
func NewStatusSignal(status string, source NodeID) IDStateSignal {
return IDStateSignal{
BaseSignal: NewUpSignal(StatusSignalType),
ID: source,
State: status,
}
}
type StateSignal struct {
BaseSignal
State string `json:"state"`
}
func (signal StateSignal) Serialize() ([]byte, error) {
return json.MarshalIndent(signal, "", " ")
}
@ -188,4 +220,118 @@ func NewReadResultSignal(exts map[ExtType]map[string]interface{}) ReadResultSign
}
}
type ECDHSignal struct {
StateSignal
Time time.Time
ECDSA *ecdsa.PublicKey
ECDH *ecdh.PublicKey
Signature []byte
}
func keyHash(now time.Time, ec_key *ecdh.PublicKey) ([]byte, error) {
time_bytes, err := now.MarshalJSON()
if err != nil {
return nil, err
}
sig_data := append(ec_key.Bytes(), time_bytes...)
sig_hash := sha512.Sum512(sig_data)
return sig_hash[:], nil
}
func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, error) {
ec_key, err := ctx.ECDH.GenerateKey(rand.Reader)
if err != nil {
return ECDHSignal{}, nil, err
}
now := time.Now()
sig_hash, err := keyHash(now, ec_key.PublicKey())
if err != nil {
return ECDHSignal{}, nil, err
}
sig, err := ecdsa.SignASN1(rand.Reader, node.Key, sig_hash)
if err != nil {
return ECDHSignal{}, nil, err
}
return ECDHSignal{
StateSignal: StateSignal{
BaseSignal: NewDirectSignal(ECDHSignalType),
State: "req",
},
Time: now,
ECDSA: &node.Key.PublicKey,
ECDH: ec_key.PublicKey(),
Signature: sig,
}, ec_key, nil
}
const DEFAULT_ECDH_WINDOW = time.Second
func NewECDHRespSignal(ctx *Context, node *Node, req ECDHSignal) (ECDHSignal, []byte, error) {
now := time.Now()
err := VerifyECDHSignal(now, req, DEFAULT_ECDH_WINDOW)
if err != nil {
return ECDHSignal{}, nil, err
}
ec_key, err := ctx.ECDH.GenerateKey(rand.Reader)
if err != nil {
return ECDHSignal{}, nil, err
}
shared_secret, err := ec_key.ECDH(req.ECDH)
if err != nil {
return ECDHSignal{}, nil, err
}
key_hash, err := keyHash(now, ec_key.PublicKey())
if err != nil {
return ECDHSignal{}, nil, err
}
sig, err := ecdsa.SignASN1(rand.Reader, node.Key, key_hash)
if err != nil {
return ECDHSignal{}, nil, err
}
return ECDHSignal{
StateSignal: StateSignal{
BaseSignal: NewDirectSignal(ECDHSignalType),
State: "resp",
},
Time: now,
ECDSA: &node.Key.PublicKey,
ECDH: ec_key.PublicKey(),
Signature: sig,
}, shared_secret, nil
}
func VerifyECDHSignal(now time.Time, sig ECDHSignal, window time.Duration) error {
earliest := now.Add(-window)
latest := now.Add(window)
if sig.Time.Compare(earliest) == -1 {
return fmt.Errorf("TIME_TOO_LATE: %+v", sig.Time)
} else if sig.Time.Compare(latest) == 1 {
return fmt.Errorf("TIME_TOO_EARLY: %+v", sig.Time)
}
sig_hash, err := keyHash(sig.Time, sig.ECDH)
if err != nil {
return err
}
verified := ecdsa.VerifyASN1(sig.ECDSA, sig_hash, sig.Signature)
if verified == false {
return fmt.Errorf("VERIFY_FAIL")
}
return nil
}

@ -1,17 +1,89 @@
package graphvent
import (
"time"
"fmt"
"time"
"encoding/json"
"crypto/ecdsa"
"crypto/x509"
"crypto/ecdh"
)
type ECDHState struct {
ECKey *ecdh.PrivateKey
SharedSecret []byte
}
type ECDHStateJSON struct {
ECKey []byte `json:"ec_key"`
SharedSecret []byte `json:"shared_secret"`
}
func (state *ECDHState) MarshalJSON() ([]byte, error) {
var key_bytes []byte
var err error
if state.ECKey != nil {
key_bytes, err = x509.MarshalPKCS8PrivateKey(state.ECKey)
if err != nil {
return nil, err
}
}
return json.Marshal(&ECDHStateJSON{
ECKey: key_bytes,
SharedSecret: state.SharedSecret,
})
}
func (state *ECDHState) UnmarshalJSON(data []byte) error {
var j ECDHStateJSON
err := json.Unmarshal(data, &j)
if err != nil {
return err
}
state.SharedSecret = j.SharedSecret
if len(j.ECKey) == 0 {
state.ECKey = nil
} else {
tmp_key, err := x509.ParsePKCS8PrivateKey(j.ECKey)
if err != nil {
return err
}
ecdsa_key, ok := tmp_key.(*ecdsa.PrivateKey)
if ok == false {
return fmt.Errorf("Parsed wrong key type from DB for ECDHState")
}
state.ECKey, err = ecdsa_key.ECDH()
if err != nil {
return err
}
}
return nil
}
type ECDHMap map[NodeID]ECDHState
func (m ECDHMap) MarshalJSON() ([]byte, error) {
tmp := map[string]ECDHState{}
for id, state := range(m) {
tmp[id.String()] = state
}
return json.Marshal(tmp)
}
type ECDHExt struct {
Granted time.Time
Pubkey *ecdsa.PublicKey
Shared []byte
ECDHStates ECDHMap
}
func NewECDHExt() *ECDHExt {
return &ECDHExt{
ECDHStates: ECDHMap{},
}
}
func ResolveFields[T Extension](t T, name string, field_funcs map[string]func(T)interface{})interface{} {
@ -25,26 +97,72 @@ func ResolveFields[T Extension](t T, name string, field_funcs map[string]func(T)
func (ext *ECDHExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*ECDHExt)interface{}{
"granted": func(ext *ECDHExt) interface{} {
return ext.Granted
},
"pubkey": func(ext *ECDHExt) interface{} {
return ext.Pubkey
},
"shared": func(ext *ECDHExt) interface{} {
return ext.Shared
"ecdh_states": func(ext *ECDHExt) interface{} {
return ext.ECDHStates
},
})
}
type ECDHExtJSON struct {
Granted time.Time `json:"granted"`
Pubkey []byte `json:"pubkey"`
Shared []byte `json:"shared"`
func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, signal ECDHSignal) {
ctx.Log.Logf("ecdh", "ECDH_SIGNAL: %s->%s - %+v", source, node, signal)
switch signal.State {
case "req":
state, exists := ext.ECDHStates[source]
if exists == false {
state = ECDHState{nil, nil}
}
resp, shared_secret, err := NewECDHRespSignal(ctx, node, signal)
if err == nil {
state.SharedSecret = shared_secret
ext.ECDHStates[source] = state
ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
ctx.Send(node.ID, source, resp)
} else {
ctx.Log.Logf("ecdh", "ECDH_REQ_ERR: %s", err)
// TODO: send error response
}
case "resp":
state, exists := ext.ECDHStates[source]
if exists == false || state.ECKey == nil {
ctx.Send(node.ID, source, StateSignal{NewDirectSignal(ECDHSignalType), "no_req"})
} else {
err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW)
if err == nil {
shared_secret, err := state.ECKey.ECDH(signal.ECDH)
if err == nil {
state.SharedSecret = shared_secret
state.ECKey = nil
ext.ECDHStates[source] = state
ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
}
}
}
default:
ctx.Log.Logf("ecdh", "unknown echd state %s", signal.State)
}
}
func (ext *ECDHExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) {
return
func (ext *ECDHExt) HandleStateSignal(ctx *Context, source NodeID, node *Node, signal StateSignal) {
}
func (ext *ECDHExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) {
switch signal.Direction() {
case Direct:
switch signal.Type() {
case ECDHSignalType:
switch ecdh_signal := signal.(type) {
case ECDHSignal:
ext.HandleECDHSignal(ctx, source, node, ecdh_signal)
case StateSignal:
ext.HandleStateSignal(ctx, source, node, ecdh_signal)
default:
ctx.Log.Logf("ecdh", "BAD_SIGNAL_CAST: %+v", signal)
}
default:
}
default:
}
}
func (ext *ECDHExt) Type() ExtType {
@ -52,45 +170,17 @@ func (ext *ECDHExt) Type() ExtType {
}
func (ext *ECDHExt) Serialize() ([]byte, error) {
pubkey, err := x509.MarshalPKIXPublicKey(ext.Pubkey)
if err != nil {
return nil, err
}
return json.MarshalIndent(&ECDHExtJSON{
Granted: ext.Granted,
Pubkey: pubkey,
Shared: ext.Shared,
}, "", " ")
return json.MarshalIndent(ext, "", " ")
}
func LoadECDHExt(ctx *Context, data []byte) (Extension, error) {
var j ECDHExtJSON
err := json.Unmarshal(data, &j)
var ext ECDHExt
err := json.Unmarshal(data, &ext)
if err != nil {
return nil, err
}
pub, err := x509.ParsePKIXPublicKey(j.Pubkey)
if err != nil {
return nil, err
}
var pubkey *ecdsa.PublicKey
switch pub.(type) {
case *ecdsa.PublicKey:
pubkey = pub.(*ecdsa.PublicKey)
default:
return nil, fmt.Errorf("Invalid key type: %+v", pub)
}
extension := ECDHExt{
Granted: j.Granted,
Pubkey: pubkey,
Shared: j.Shared,
}
return &extension, nil
return &ext, nil
}
type GroupExt struct {