From f87571edcf6a717227462fef7e33d723407fee6b Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Fri, 28 Jul 2023 19:32:27 -0600 Subject: [PATCH] Moved test 'WaitForX' functions to a generic function in signal.go that can be used to wait for arbitrary signals --- context.go | 4 + gql.go | 181 +------------------------------------------ gql_test.go | 10 ++- graph_test.go | 101 ------------------------ lockable.go | 2 - lockable_test.go | 78 ++++++++++++++----- node_test.go | 47 ++++++++++-- signal.go | 196 +++++++++++++++++++++++++++++++++++++++++------ user.go | 190 +++++++++++++++++++++++++++++++++------------ 9 files changed, 423 insertions(+), 386 deletions(-) diff --git a/context.go b/context.go index 141d0df..b903844 100644 --- a/context.go +++ b/context.go @@ -8,6 +8,7 @@ import ( "runtime" "crypto/sha512" "crypto/elliptic" + "crypto/ecdh" "encoding/binary" ) @@ -80,6 +81,8 @@ type Context struct { Types map[uint64]*NodeInfo // Curve used for signature operations ECDSA elliptic.Curve + // Curve used for ecdh operations + ECDH ecdh.Curve // Routing map to all the nodes local to this context NodesLock sync.RWMutex Nodes map[NodeID]*Node @@ -197,6 +200,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { Extensions: map[uint64]ExtensionInfo{}, Types: map[uint64]*NodeInfo{}, Nodes: map[NodeID]*Node{}, + ECDH: ecdh.P256(), ECDSA: elliptic.P256(), } diff --git a/gql.go b/gql.go index 0ba4777..0b0d887 100644 --- a/gql.go +++ b/gql.go @@ -10,7 +10,6 @@ import ( "github.com/graphql-go/graphql/language/ast" "context" "encoding/json" - "encoding/base64" "io" "reflect" "fmt" @@ -21,7 +20,6 @@ import ( "crypto/ecdh" "crypto/ecdsa" "crypto/elliptic" - "crypto/sha512" "crypto/rand" "crypto/x509" "crypto/tls" @@ -30,173 +28,6 @@ import ( "encoding/pem" ) -type AuthReqJSON struct { - Time time.Time `json:"time"` - Pubkey []byte `json:"pubkey"` - ECDHPubkey []byte `json:"ecdh_client"` - Signature []byte `json:"signature"` -} - -func NewAuthReqJSON(curve ecdh.Curve, id *ecdsa.PrivateKey) (AuthReqJSON, *ecdh.PrivateKey, error) { - ec_key, err := curve.GenerateKey(rand.Reader) - if err != nil { - return AuthReqJSON{}, nil, err - } - now := time.Now() - time_bytes, err := now.MarshalJSON() - if err != nil { - return AuthReqJSON{}, nil, err - } - sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...) - sig_hash := sha512.Sum512(sig_data) - sig, err := ecdsa.SignASN1(rand.Reader, id, sig_hash[:]) - - id_ecdh, err := id.ECDH() - if err != nil { - return AuthReqJSON{}, nil, err - } - - return AuthReqJSON{ - Time: now, - Pubkey: id_ecdh.PublicKey().Bytes(), - ECDHPubkey: ec_key.PublicKey().Bytes(), - Signature: sig, - }, ec_key, nil -} - -type AuthRespJSON struct { - Granted time.Time `json:"granted"` - ECDHPubkey []byte `json:"echd_server"` - Signature []byte `json:"signature"` -} - -func NewAuthRespJSON(gql_ext *GQLExt, req AuthReqJSON) (AuthRespJSON, *ecdsa.PublicKey, []byte, error) { - // Check if req.Time is within +- 1 second of now - now := time.Now() - earliest := now.Add(-1 * time.Second) - latest := now.Add(1 * time.Second) - // If req.Time is before the earliest acceptable time, or after the latest acceptible time - if req.Time.Compare(earliest) == -1 { - return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_LATE: %s", req.Time) - } else if req.Time.Compare(latest) == 1 { - return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_EARLY: %s", req.Time) - } - - x, y := elliptic.Unmarshal(gql_ext.Key.Curve, req.Pubkey) - if x == nil { - return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_UNMARSHAL_FAIL: %+v", req.Pubkey) - } - - remote, err := gql_ext.ECDH.NewPublicKey(req.ECDHPubkey) - if err != nil { - return AuthRespJSON{}, nil, nil, err - } - - // Verify the signature - time_bytes, _ := req.Time.MarshalJSON() - sig_data := append(req.ECDHPubkey, time_bytes...) - sig_hash := sha512.Sum512(sig_data) - - remote_key := &ecdsa.PublicKey{ - Curve: gql_ext.Key.Curve, - X: x, - Y: y, - } - - verified := ecdsa.VerifyASN1( - remote_key, - sig_hash[:], - req.Signature, - ) - - if verified == false { - return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_VERIFY_FAIL: %+v", req) - } - - ec_key, err := gql_ext.ECDH.GenerateKey(rand.Reader) - if err != nil { - return AuthRespJSON{}, nil, nil, err - } - - ec_key_pub := ec_key.PublicKey().Bytes() - - granted := time.Now() - time_ser, _ := granted.MarshalJSON() - resp_sig_data := append(ec_key_pub, time_ser...) - resp_sig_hash := sha512.Sum512(resp_sig_data) - - resp_sig, err := ecdsa.SignASN1(rand.Reader, gql_ext.Key, resp_sig_hash[:]) - if err != nil { - return AuthRespJSON{}, nil, nil, err - } - - shared_secret, err := ec_key.ECDH(remote) - if err != nil { - return AuthRespJSON{}, nil, nil, err - } - - return AuthRespJSON{ - Granted: granted, - ECDHPubkey: ec_key_pub, - Signature: resp_sig, - }, remote_key, shared_secret, nil -} - -func ParseAuthRespJSON(resp AuthRespJSON, ecdsa_curve elliptic.Curve, ecdh_curve ecdh.Curve, ec_key *ecdh.PrivateKey) ([]byte, error) { - remote, err := ecdh_curve.NewPublicKey(resp.ECDHPubkey) - if err != nil { - return nil, err - } - - shared_secret, err := ec_key.ECDH(remote) - if err != nil { - return nil, err - } - - return shared_secret, nil -} - -func AuthHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - ctx.Log.Logf("gql", "GQL_AUTH_REQUEST: %s", r.RemoteAddr) - enableCORS(&w) - - str, err := io.ReadAll(r.Body) - if err != nil { - ctx.Log.Logf("gql", "GQL_AUTH_READ_ERR: %e", err) - return - } - - var req AuthReqJSON - err = json.Unmarshal([]byte(str), &req) - if err != nil { - ctx.Log.Logf("gql", "GQL_AUTH_UNMARHSHAL_ERR: %e", err) - return - } - - resp, _, _, err := NewAuthRespJSON(gql_ext, req) - if err != nil { - ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_ERROR: %s", err) - return - } - - ser, err := json.Marshal(resp) - if err != nil { - ctx.Log.Logf("gql", "GQL_AUTH_RESP_MARSHAL_ERR: %e", err) - return - } - - wrote, err := w.Write(ser) - if err != nil { - ctx.Log.Logf("gql", "GQL_AUTH_RESP_ERR: %e", err) - return - } else if wrote != len(ser) { - ctx.Log.Logf("gql", "GQL_AUTH_RESP_BAD_LENGTH: %d/%d", wrote, len(ser)) - return - } - } -} - func GraphiQLHandler() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r * http.Request) { graphiql_string := fmt.Sprintf(` @@ -340,7 +171,7 @@ type ResolveContext struct { } func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) { - username, password, ok := r.BasicAuth() + username, _, ok := r.BasicAuth() if ok == false { return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header") } @@ -355,15 +186,6 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ return nil, fmt.Errorf("GQL_REQUEST_ERR: no existing authorization for client %s", auth_id) } - user_ext, err := GetExt[*ECDHExt](user) - if err != nil { - return nil, err - } - - if base64.StdEncoding.EncodeToString(user_ext.Shared) != password { - return nil, fmt.Errorf("GQL_AUTH_FAIL") - } - return &ResolveContext{ Context: ctx, GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext), @@ -942,7 +764,6 @@ func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_ func StartGQLServer(ctx *Context, node *Node, gql_ext *GQLExt) error { mux := http.NewServeMux() - mux.HandleFunc("/auth", AuthHandler(ctx, node, gql_ext)) mux.HandleFunc("/gql", GQLHandler(ctx, node, gql_ext)) mux.HandleFunc("/gqlws", GQLWSHandler(ctx, node, gql_ext)) diff --git a/gql_test.go b/gql_test.go index a18b1ba..f6817dd 100644 --- a/gql_test.go +++ b/gql_test.go @@ -34,7 +34,10 @@ func TestGQLDB(t * testing.T) { err = ctx.Send(gql.ID, gql.ID, StopSignal) fatalErr(t, err) - (*GraphTester)(t).WaitForStatus(ctx, listener_ext, "stopped", 100*time.Millisecond, "Didn't receive stopped on listener") + _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig IDStateSignal) bool { + return sig.State == "stopped" && sig.ID == gql.ID + }) + fatalErr(t, err) ser1, err := gql.Serialize() ser2, err := u1.Serialize() @@ -49,7 +52,10 @@ func TestGQLDB(t * testing.T) { fatalErr(t, err) err = ctx.Send(gql_loaded.ID, gql_loaded.ID, StopSignal) fatalErr(t, err) - (*GraphTester)(t).WaitForStatus(ctx, listener_ext, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2") + _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig IDStateSignal) bool { + return sig.State == "stopped" && sig.ID == gql_loaded.ID + }) + fatalErr(t, err) } diff --git a/graph_test.go b/graph_test.go index 973a694..d1c7168 100644 --- a/graph_test.go +++ b/graph_test.go @@ -2,111 +2,10 @@ package graphvent import ( "testing" - "fmt" - "time" - "runtime/pprof" "runtime/debug" - "os" badger "github.com/dgraph-io/badger/v3" ) -type GraphTester testing.T -const listner_timeout = 50 * time.Millisecond - -func (t *GraphTester) WaitForReadResult(ctx *Context, listener *ListenerExt, timeout time.Duration, str string) map[ExtType]map[string]interface{} { - timeout_channel := time.After(timeout) - for true { - select { - case signal := <- listener.Chan: - ctx.Log.Logf("test", "SIGNAL %+v", signal) - if signal == nil { - ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener) - t.Fatal(str) - } - if signal.Type() == ReadResultSignalType { - result_signal, ok := signal.(ReadResultSignal) - if ok == false { - ctx.Log.Logf("test", "SIGNAL_CHANNEL_BAD_CAST: %+v", signal) - t.Fatal(str) - } - return result_signal.Extensions - } - case <-timeout_channel: - ctx.Log.Logf("test", "SIGNAL_CHANNEL_TIMEOUT: %+v", listener) - t.Fatal(str) - } - } - return nil -} - -func (t *GraphTester) WaitForState(ctx * Context, listener *ListenerExt, stype SignalType, state string, timeout time.Duration, str string) Signal { - timeout_channel := time.After(timeout) - for true { - select { - case signal := <- listener.Chan: - if signal == nil { - ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener) - t.Fatal(str) - } - if signal.Type() == stype { - sig, ok := signal.(StateSignal) - if ok == true { - ctx.Log.Logf("test", "%s state received: %s", stype, sig.State) - if sig.State == state { - return signal - } - } - } - case <-timeout_channel: - pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) - t.Fatal(str) - return nil - } - } - return nil -} - -func (t * GraphTester) WaitForStatus(ctx * Context, listener *ListenerExt, status string, timeout time.Duration, str string) Signal { - timeout_channel := time.After(timeout) - for true { - select { - case signal := <- listener.Chan: - if signal == nil { - ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener) - t.Fatal(str) - } - if signal.Type() == StatusSignalType { - sig, ok := signal.(StatusSignal) - if ok == true { - - - ctx.Log.Logf("test", "Status received: %s", sig.Status) - if sig.Status == status { - return signal - } - } else { - ctx.Log.Logf("test", "Failed to cast status to StatusSignal: %+v", signal) - } - } - case <-timeout_channel: - pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) - t.Fatal(str) - return nil - } - } - return nil -} - -func (t * GraphTester) CheckForNone(listener *ListenerExt, str string) { - timeout := time.After(listner_timeout) - select { - case sig := <- listener.Chan: - pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) - t.Fatal(fmt.Sprintf("%s : %+v", str, sig)) - case <-timeout: - } -} - const SimpleListenerNodeType = NodeType("SIMPLE_LISTENER") func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) { diff --git a/lockable.go b/lockable.go index 5368b71..1dc063e 100644 --- a/lockable.go +++ b/lockable.go @@ -420,8 +420,6 @@ func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node // LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes // LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { - ctx.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID) - switch signal.Direction() { case Up: owner_sent := false diff --git a/lockable_test.go b/lockable_test.go index f773d34..dd288e6 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -3,7 +3,6 @@ package graphvent import ( "testing" "time" - "fmt" ) const TestLockableType = NodeType("TEST_LOCKABLE") @@ -41,18 +40,31 @@ func TestLink(t *testing.T) { err := LinkRequirement(ctx, l1.ID, l2.ID) fatalErr(t, err) - (*GraphTester)(t).WaitForState(ctx, l1_listener, LinkSignalType, "linked_as_req", time.Millisecond*10, "No linked_as_req") - (*GraphTester)(t).WaitForState(ctx, l2_listener, LinkSignalType, "linked_as_dep", time.Millisecond*10, "No req_linked") + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, func(sig StateSignal) bool { + return sig.State == "linked_as_req" + }) + fatalErr(t, err) + _, err = WaitForSignal(ctx, l2_listener, time.Millisecond*10, LinkSignalType, func(sig StateSignal) bool { + return sig.State == "linked_as_dep" + }) + fatalErr(t, err) err = ctx.Send(l2.ID, l2.ID, NewStatusSignal("TEST", l2.ID)) fatalErr(t, err) - (*GraphTester)(t).WaitForStatus(ctx, l1_listener, "TEST", time.Millisecond*10, "No TEST on l1") - (*GraphTester)(t).WaitForStatus(ctx, l2_listener, "TEST", time.Millisecond*10, "No TEST on l2") + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, StatusSignalType, func(sig IDStateSignal) bool { + return sig.State == "TEST" + }) + fatalErr(t, err) + + _, err = WaitForSignal(ctx, l2_listener, time.Millisecond*10, StatusSignalType, func(sig IDStateSignal) bool { + return sig.State == "TEST" + }) + fatalErr(t, err) } func TestLink10K(t *testing.T) { - ctx := lockableTestContext(t, []string{"test"}) + ctx := lockableTestContext(t, []string{}) NewLockable := func()(*Node) { l := NewNode(ctx, nil, TestLockableType, 10, nil, @@ -82,8 +94,11 @@ func TestLink10K(t *testing.T) { ctx.Log.Logf("test", "CREATED_10K") - for i, _ := range(lockables) { - (*GraphTester)(t).WaitForState(ctx, l0_listener, LinkSignalType, "linked_as_req", time.Millisecond*1000, fmt.Sprintf("No linked_as_req for %d", i)) + for range(lockables) { + _, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, func(sig StateSignal) bool { + return sig.State == "linked_as_req" + }) + fatalErr(t, err) } ctx.Log.Logf("test", "LINKED_10K") @@ -129,23 +144,44 @@ func TestLock(t *testing.T) { err = LinkRequirement(ctx, l0.ID, l5.ID) fatalErr(t, err) - (*GraphTester)(t).WaitForState(ctx, l1_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req") - (*GraphTester)(t).WaitForState(ctx, l1_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req") - (*GraphTester)(t).WaitForState(ctx, l1_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req") - (*GraphTester)(t).WaitForState(ctx, l1_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req") + linked_as_req := func(sig StateSignal) bool { + return sig.State == "linked_as_req" + } + + locked := func(sig StateSignal) bool { + return sig.State == "locked" + } + + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + fatalErr(t, err) + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + fatalErr(t, err) + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + fatalErr(t, err) + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + fatalErr(t, err) - (*GraphTester)(t).WaitForState(ctx, l0_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req") - (*GraphTester)(t).WaitForState(ctx, l0_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req") - (*GraphTester)(t).WaitForState(ctx, l0_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req") - (*GraphTester)(t).WaitForState(ctx, l0_listener, LinkSignalType, "linked_as_req", time.Millisecond*100, "No linked_as_req") + _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + fatalErr(t, err) + _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + fatalErr(t, err) + _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + fatalErr(t, err) + _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + fatalErr(t, err) err = LockLockable(ctx, l1) fatalErr(t, err) - (*GraphTester)(t).WaitForState(ctx, l1_listener, LockSignalType, "locked", time.Millisecond*10, "No locked") - (*GraphTester)(t).WaitForState(ctx, l1_listener, LockSignalType, "locked", time.Millisecond*10, "No locked") - (*GraphTester)(t).WaitForState(ctx, l1_listener, LockSignalType, "locked", time.Millisecond*10, "No locked") - (*GraphTester)(t).WaitForState(ctx, l1_listener, LockSignalType, "locked", time.Millisecond*10, "No locked") - (*GraphTester)(t).WaitForState(ctx, l1_listener, LockSignalType, "locked", time.Millisecond*10, "No locked") + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked) + fatalErr(t, err) + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked) + fatalErr(t, err) + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked) + fatalErr(t, err) + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked) + fatalErr(t, err) + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked) + fatalErr(t, err) err = UnlockLockable(ctx, l1) fatalErr(t, err) diff --git a/node_test.go b/node_test.go index a1f0ada..64f8317 100644 --- a/node_test.go +++ b/node_test.go @@ -21,9 +21,9 @@ func TestNodeDB(t *testing.T) { } func TestNodeRead(t *testing.T) { - ctx := logTestContext(t, []string{"test", "read", "signal", "policy", "node", "loop"}) + ctx := logTestContext(t, []string{}) node_type := NodeType("TEST") - err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, GroupExtType}) + err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, GroupExtType, ECDHExtType}) fatalErr(t, err) n1_key, err := ecdsa.GenerateKey(ctx.ECDSA, rand.Reader) @@ -41,17 +41,54 @@ func TestNodeRead(t *testing.T) { n1_id: Actions{MakeAction(ReadResultSignalType, "+")}, }) n2_listener := NewListenerExt(10) - n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewACLExt(n2_policy), NewGroupExt(nil), n2_listener) + n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewACLExt(n2_policy), NewGroupExt(nil), NewECDHExt(), n2_listener) n1_policy := NewPerNodePolicy(map[NodeID]Actions{ n2_id: Actions{MakeAction(ReadSignalType, "+")}, }) - n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewACLExt(n1_policy), NewGroupExt(nil)) + n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewACLExt(n1_policy), NewGroupExt(nil), NewECDHExt()) ctx.Send(n2.ID, n1.ID, NewReadSignal(map[ExtType][]string{ GroupExtType: []string{"members"}, })) - res := (*GraphTester)(t).WaitForReadResult(ctx, n2_listener, 10*time.Millisecond, "No read_result") + res, err := WaitForSignal(ctx, n2_listener, 10*time.Millisecond, ReadResultSignalType, func(sig ReadResultSignal) bool { + return true + }) + fatalErr(t, err) ctx.Log.Logf("test", "READ_RESULT: %+v", res) } + +func TestECDH(t *testing.T) { + ctx := logTestContext(t, []string{"test", "ecdh", "policy"}) + + node_type := NodeType("TEST") + err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, ECDHExtType}) + fatalErr(t, err) + + n1_listener := NewListenerExt(10) + ecdh_policy := NewAllNodesPolicy(Actions{MakeAction(ECDHSignalType, "+")}) + n1 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt(), n1_listener) + n2 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt()) + + ctx.Log.Logf("test", "N1: %s", n1.ID) + ctx.Log.Logf("test", "N2: %s", n2.ID) + + + ecdh_req, n1_ec, err := NewECDHReqSignal(ctx, n1) + ecdh_ext, err := GetExt[*ECDHExt](n1) + fatalErr(t, err) + ecdh_ext.ECDHStates[n2.ID] = ECDHState{ + ECKey: n1_ec, + SharedSecret: nil, + } + fatalErr(t, err) + ctx.Log.Logf("test", "N1_EC: %+v", n1_ec) + err = ctx.Send(n1.ID, n2.ID, ecdh_req) + fatalErr(t, err) + + _, err = WaitForSignal(ctx, n1_listener, 100*time.Millisecond, ECDHSignalType, func(sig ECDHSignal) bool { + return sig.State == "resp" + }) + fatalErr(t, err) +} diff --git a/signal.go b/signal.go index a1ba3fb..c031603 100644 --- a/signal.go +++ b/signal.go @@ -1,21 +1,26 @@ package graphvent import ( + "time" + "fmt" "encoding/json" -) - -const ( - StopSignalType = SignalType("STOP") - StatusSignalType = SignalType("STATUS") - LinkSignalType = SignalType("LINK") - LockSignalType = SignalType("LOCK") - ReadSignalType = SignalType("READ") - ReadResultSignalType = SignalType("READ_RESULT") - LinkStartSignalType = SignalType("LINK_START") + "crypto/sha512" + "crypto/ecdsa" + "crypto/ecdh" + "crypto/rand" ) type SignalDirection int const ( + StopSignalType SignalType = "STOP" + StatusSignalType = "STATUS" + LinkSignalType = "LINK" + LockSignalType = "LOCK" + ReadSignalType = "READ" + ReadResultSignalType = "READ_RESULT" + LinkStartSignalType = "LINK_START" + ECDHSignalType = "ECDH" + Up SignalDirection = iota Down Direct @@ -29,10 +34,35 @@ func (signal_type SignalType) String() string { type Signal interface { Serializable[SignalType] Direction() SignalDirection - MarshalJSON() ([]byte, error) Permission() Action } +func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) { + var zero S + timeout_channel := time.After(timeout) + for true { + select { + case signal := <- listener.Chan: + if signal == nil { + return zero, fmt.Errorf("LISTENER_CLOSED: %s", signal_type) + } + if signal.Type() == signal_type { + sig, ok := signal.(S) + if ok == true { + ctx.Log.Logf("test", "received: %+v", sig) + if check(sig) == true { + return sig, nil + } + } + } + case <-timeout_channel: + return zero, fmt.Errorf("LISTENER_TIMEOUT: %s", signal_type) + } + } + return zero, fmt.Errorf("LOOP_ENDED") +} + + type BaseSignal struct { SignalDirection SignalDirection `json:"direction"` SignalType SignalType `json:"type"` @@ -50,7 +80,7 @@ func (signal BaseSignal) Direction() SignalDirection { return signal.SignalDirection } -func (signal BaseSignal) MarshalJSON() ([]byte, error) { +func (signal *BaseSignal) MarshalJSON() ([]byte, error) { return json.Marshal(signal) } @@ -100,12 +130,18 @@ func NewIDSignal(signal_type SignalType, direction SignalDirection, id NodeID) I } } -type StatusSignal struct { - IDSignal - Status string `json:"status"` +type StateSignal struct { + BaseSignal + State string `json:"state"` } -func (signal StatusSignal) String() string { +type IDStateSignal struct { + BaseSignal + ID NodeID `json:"id"` + State string `json:"status"` +} + +func (signal IDStateSignal) String() string { ser, err := json.Marshal(signal) if err != nil { return "STATE_SER_ERR" @@ -113,18 +149,14 @@ func (signal StatusSignal) String() string { return string(ser) } -func NewStatusSignal(status string, source NodeID) StatusSignal { - return StatusSignal{ - IDSignal: NewIDSignal(StatusSignalType, Up, source), - Status: status, +func NewStatusSignal(status string, source NodeID) IDStateSignal { + return IDStateSignal{ + BaseSignal: NewUpSignal(StatusSignalType), + ID: source, + State: status, } } -type StateSignal struct { - BaseSignal - State string `json:"state"` -} - func (signal StateSignal) Serialize() ([]byte, error) { return json.MarshalIndent(signal, "", " ") } @@ -188,4 +220,118 @@ func NewReadResultSignal(exts map[ExtType]map[string]interface{}) ReadResultSign } } +type ECDHSignal struct { + StateSignal + Time time.Time + ECDSA *ecdsa.PublicKey + ECDH *ecdh.PublicKey + Signature []byte +} + +func keyHash(now time.Time, ec_key *ecdh.PublicKey) ([]byte, error) { + time_bytes, err := now.MarshalJSON() + if err != nil { + return nil, err + } + + sig_data := append(ec_key.Bytes(), time_bytes...) + sig_hash := sha512.Sum512(sig_data) + + return sig_hash[:], nil +} + +func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, error) { + ec_key, err := ctx.ECDH.GenerateKey(rand.Reader) + if err != nil { + return ECDHSignal{}, nil, err + } + + now := time.Now() + + sig_hash, err := keyHash(now, ec_key.PublicKey()) + if err != nil { + return ECDHSignal{}, nil, err + } + + sig, err := ecdsa.SignASN1(rand.Reader, node.Key, sig_hash) + if err != nil { + return ECDHSignal{}, nil, err + } + + return ECDHSignal{ + StateSignal: StateSignal{ + BaseSignal: NewDirectSignal(ECDHSignalType), + State: "req", + }, + Time: now, + ECDSA: &node.Key.PublicKey, + ECDH: ec_key.PublicKey(), + Signature: sig, + }, ec_key, nil +} + +const DEFAULT_ECDH_WINDOW = time.Second + +func NewECDHRespSignal(ctx *Context, node *Node, req ECDHSignal) (ECDHSignal, []byte, error) { + now := time.Now() + + err := VerifyECDHSignal(now, req, DEFAULT_ECDH_WINDOW) + if err != nil { + return ECDHSignal{}, nil, err + } + + ec_key, err := ctx.ECDH.GenerateKey(rand.Reader) + if err != nil { + return ECDHSignal{}, nil, err + } + + shared_secret, err := ec_key.ECDH(req.ECDH) + if err != nil { + return ECDHSignal{}, nil, err + } + + key_hash, err := keyHash(now, ec_key.PublicKey()) + if err != nil { + return ECDHSignal{}, nil, err + } + + sig, err := ecdsa.SignASN1(rand.Reader, node.Key, key_hash) + if err != nil { + return ECDHSignal{}, nil, err + } + + return ECDHSignal{ + StateSignal: StateSignal{ + BaseSignal: NewDirectSignal(ECDHSignalType), + State: "resp", + }, + Time: now, + ECDSA: &node.Key.PublicKey, + ECDH: ec_key.PublicKey(), + Signature: sig, + }, shared_secret, nil +} + +func VerifyECDHSignal(now time.Time, sig ECDHSignal, window time.Duration) error { + earliest := now.Add(-window) + latest := now.Add(window) + + if sig.Time.Compare(earliest) == -1 { + return fmt.Errorf("TIME_TOO_LATE: %+v", sig.Time) + } else if sig.Time.Compare(latest) == 1 { + return fmt.Errorf("TIME_TOO_EARLY: %+v", sig.Time) + } + + sig_hash, err := keyHash(sig.Time, sig.ECDH) + if err != nil { + return err + } + + verified := ecdsa.VerifyASN1(sig.ECDSA, sig_hash, sig.Signature) + if verified == false { + return fmt.Errorf("VERIFY_FAIL") + } + + return nil +} diff --git a/user.go b/user.go index f2dc16f..f8f87ba 100644 --- a/user.go +++ b/user.go @@ -1,17 +1,89 @@ package graphvent import ( - "time" "fmt" + "time" "encoding/json" "crypto/ecdsa" "crypto/x509" + "crypto/ecdh" ) +type ECDHState struct { + ECKey *ecdh.PrivateKey + SharedSecret []byte +} + +type ECDHStateJSON struct { + ECKey []byte `json:"ec_key"` + SharedSecret []byte `json:"shared_secret"` +} + +func (state *ECDHState) MarshalJSON() ([]byte, error) { + var key_bytes []byte + var err error + if state.ECKey != nil { + key_bytes, err = x509.MarshalPKCS8PrivateKey(state.ECKey) + if err != nil { + return nil, err + } + } + + return json.Marshal(&ECDHStateJSON{ + ECKey: key_bytes, + SharedSecret: state.SharedSecret, + }) +} + +func (state *ECDHState) UnmarshalJSON(data []byte) error { + var j ECDHStateJSON + err := json.Unmarshal(data, &j) + if err != nil { + return err + } + + state.SharedSecret = j.SharedSecret + if len(j.ECKey) == 0 { + state.ECKey = nil + } else { + tmp_key, err := x509.ParsePKCS8PrivateKey(j.ECKey) + if err != nil { + return err + } + + ecdsa_key, ok := tmp_key.(*ecdsa.PrivateKey) + if ok == false { + return fmt.Errorf("Parsed wrong key type from DB for ECDHState") + } + + state.ECKey, err = ecdsa_key.ECDH() + if err != nil { + return err + } + } + + return nil +} + +type ECDHMap map[NodeID]ECDHState + +func (m ECDHMap) MarshalJSON() ([]byte, error) { + tmp := map[string]ECDHState{} + for id, state := range(m) { + tmp[id.String()] = state + } + + return json.Marshal(tmp) +} + type ECDHExt struct { - Granted time.Time - Pubkey *ecdsa.PublicKey - Shared []byte + ECDHStates ECDHMap +} + +func NewECDHExt() *ECDHExt { + return &ECDHExt{ + ECDHStates: ECDHMap{}, + } } func ResolveFields[T Extension](t T, name string, field_funcs map[string]func(T)interface{})interface{} { @@ -25,26 +97,72 @@ func ResolveFields[T Extension](t T, name string, field_funcs map[string]func(T) func (ext *ECDHExt) Field(name string) interface{} { return ResolveFields(ext, name, map[string]func(*ECDHExt)interface{}{ - "granted": func(ext *ECDHExt) interface{} { - return ext.Granted - }, - "pubkey": func(ext *ECDHExt) interface{} { - return ext.Pubkey - }, - "shared": func(ext *ECDHExt) interface{} { - return ext.Shared + "ecdh_states": func(ext *ECDHExt) interface{} { + return ext.ECDHStates }, }) } -type ECDHExtJSON struct { - Granted time.Time `json:"granted"` - Pubkey []byte `json:"pubkey"` - Shared []byte `json:"shared"` +func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, signal ECDHSignal) { + ctx.Log.Logf("ecdh", "ECDH_SIGNAL: %s->%s - %+v", source, node, signal) + switch signal.State { + case "req": + state, exists := ext.ECDHStates[source] + if exists == false { + state = ECDHState{nil, nil} + } + resp, shared_secret, err := NewECDHRespSignal(ctx, node, signal) + if err == nil { + state.SharedSecret = shared_secret + ext.ECDHStates[source] = state + ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) + ctx.Send(node.ID, source, resp) + } else { + ctx.Log.Logf("ecdh", "ECDH_REQ_ERR: %s", err) + // TODO: send error response + } + case "resp": + state, exists := ext.ECDHStates[source] + if exists == false || state.ECKey == nil { + ctx.Send(node.ID, source, StateSignal{NewDirectSignal(ECDHSignalType), "no_req"}) + } else { + err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW) + if err == nil { + shared_secret, err := state.ECKey.ECDH(signal.ECDH) + if err == nil { + state.SharedSecret = shared_secret + state.ECKey = nil + ext.ECDHStates[source] = state + ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) + } + } + } + default: + ctx.Log.Logf("ecdh", "unknown echd state %s", signal.State) + } } -func (ext *ECDHExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) { - return +func (ext *ECDHExt) HandleStateSignal(ctx *Context, source NodeID, node *Node, signal StateSignal) { + +} + +func (ext *ECDHExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { + switch signal.Direction() { + case Direct: + switch signal.Type() { + case ECDHSignalType: + switch ecdh_signal := signal.(type) { + case ECDHSignal: + ext.HandleECDHSignal(ctx, source, node, ecdh_signal) + case StateSignal: + ext.HandleStateSignal(ctx, source, node, ecdh_signal) + default: + ctx.Log.Logf("ecdh", "BAD_SIGNAL_CAST: %+v", signal) + } + default: + } + default: + } } func (ext *ECDHExt) Type() ExtType { @@ -52,45 +170,17 @@ func (ext *ECDHExt) Type() ExtType { } func (ext *ECDHExt) Serialize() ([]byte, error) { - pubkey, err := x509.MarshalPKIXPublicKey(ext.Pubkey) - if err != nil { - return nil, err - } - - return json.MarshalIndent(&ECDHExtJSON{ - Granted: ext.Granted, - Pubkey: pubkey, - Shared: ext.Shared, - }, "", " ") + return json.MarshalIndent(ext, "", " ") } func LoadECDHExt(ctx *Context, data []byte) (Extension, error) { - var j ECDHExtJSON - err := json.Unmarshal(data, &j) + var ext ECDHExt + err := json.Unmarshal(data, &ext) if err != nil { return nil, err } - pub, err := x509.ParsePKIXPublicKey(j.Pubkey) - if err != nil { - return nil, err - } - - var pubkey *ecdsa.PublicKey - switch pub.(type) { - case *ecdsa.PublicKey: - pubkey = pub.(*ecdsa.PublicKey) - default: - return nil, fmt.Errorf("Invalid key type: %+v", pub) - } - - extension := ECDHExt{ - Granted: j.Granted, - Pubkey: pubkey, - Shared: j.Shared, - } - - return &extension, nil + return &ext, nil } type GroupExt struct {