diff --git a/gql.go b/gql.go index 8dddfb8..2ba6391 100644 --- a/gql.go +++ b/gql.go @@ -1,6 +1,8 @@ package graphvent import ( + "time" + "net" "net/http" "github.com/graphql-go/graphql" "github.com/graphql-go/graphql/language/parser" @@ -15,8 +17,179 @@ import ( "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" "strings" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/sha512" + "crypto/rand" + "crypto/x509" ) +type AuthReqJSON struct { + Time time.Time `json:"time"` + Pubkey []byte `json:"pubkey"` + ECDHPubkey []byte `json:"ecdh_client"` + Signature []byte `json:"signature"` +} + +func NewAuthReqJSON(curve ecdh.Curve, id *ecdsa.PrivateKey) (AuthReqJSON, *ecdh.PrivateKey, error) { + ec_key, err := curve.GenerateKey(rand.Reader) + if err != nil { + return AuthReqJSON{}, nil, err + } + now := time.Now() + time_bytes, err := now.MarshalJSON() + if err != nil { + return AuthReqJSON{}, nil, err + } + sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...) + sig_hash := sha512.Sum512(sig_data) + sig, err := ecdsa.SignASN1(rand.Reader, id, sig_hash[:]) + + id_ecdh, err := id.ECDH() + if err != nil { + return AuthReqJSON{}, nil, err + } + + return AuthReqJSON{ + Time: now, + Pubkey: id_ecdh.PublicKey().Bytes(), + ECDHPubkey: ec_key.PublicKey().Bytes(), + Signature: sig, + }, ec_key, nil +} + +type AuthRespJSON struct { + Granted time.Time `json:"granted"` + ECDHPubkey []byte `json:"echd_server"` +} + +func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, []byte, error) { + // Check if req.Time is within +- 1 second of now + now := time.Now() + earliest := now.Add(-1 * time.Second) + latest := now.Add(1 * time.Second) + // If req.Time is before the earliest acceptable time, or after the latest acceptible time + if req.Time.Compare(earliest) == -1 { + return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_LATE: %s", req.Time) + } else if req.Time.Compare(latest) == 1 { + return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_EARLY: %s", req.Time) + } + + x, y := elliptic.Unmarshal(thread.Key.Curve, req.Pubkey) + if x == nil { + return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_UNMARSHAL_FAIL: %+v", req.Pubkey) + } + + remote, err := thread.ECDH.NewPublicKey(req.ECDHPubkey) + if err != nil { + return AuthRespJSON{}, nil, err + } + + // Verify the signature + time_bytes, _ := req.Time.MarshalJSON() + sig_data := append(req.ECDHPubkey, time_bytes...) + sig_hash := sha512.Sum512(sig_data) + + verified := ecdsa.VerifyASN1( + &ecdsa.PublicKey{ + Curve: thread.Key.Curve, + X: x, + Y: y, + }, + sig_hash[:], + req.Signature, + ) + + if verified == false { + return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_VERIFY_FAIL: %+v", req) + } + + ec_key, err := thread.ECDH.GenerateKey(rand.Reader) + if err != nil { + return AuthRespJSON{}, nil, err + } + + shared_secret, err := ec_key.ECDH(remote) + if err != nil { + return AuthRespJSON{}, nil, err + } + + return AuthRespJSON{ + Granted: time.Now(), + ECDHPubkey: ec_key.PublicKey().Bytes(), + }, shared_secret, nil +} + +type AuthData struct { + Granted time.Time + Pubkey ecdh.PublicKey + ECDHClient ecdh.PublicKey +} + +type AuthDataJSON struct { + Granted time.Time `json:"granted"` + Pubkey []byte `json:"pbkey"` + ECDHClient []byte `json:"ecdh_client"` +} + +func HashKey(pub []byte) uint64 { + return 0 +} + +func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + ctx.Log.Logf("gql", "GQL_AUTH_REQUEST: %s", r.RemoteAddr) + enableCORS(&w) + + str, err := io.ReadAll(r.Body) + if err != nil { + ctx.Log.Logf("gql", "GQL_AUTH_READ_ERR: %e", err) + return + } + + var req AuthReqJSON + err = json.Unmarshal([]byte(str), &req) + if err != nil { + ctx.Log.Logf("gql", "GQL_AUTH_UNMARHSHAL_ERR: %e", err) + return + } + + resp, _, err := NewAuthRespJSON(server, req) + if err != nil { + ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_ERROR: %e", err) + return + } + + ser, err := json.Marshal(resp) + if err != nil { + ctx.Log.Logf("gql", "GQL_AUTH_RESP_MARSHAL_ERR: %e", err) + return + } + + wrote, err := w.Write(ser) + if err != nil { + ctx.Log.Logf("gql", "GQL_AUTH_RESP_ERR: %e", err) + return + } else if wrote != len(ser) { + ctx.Log.Logf("gql", "GQL_AUTH_RESP_BAD_LENGTH: %d/%d", wrote, len(ser)) + return + } + + ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_SUCCESS: %s", str) + + key_hash := HashKey(req.Pubkey) + + _, exists := server.AuthMap[key_hash] + if exists { + // New user + } else { + // Existing user + } + + } +} + func GraphiQLHandler() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r * http.Request) { graphiql_string := fmt.Sprintf(` @@ -166,7 +339,6 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") auth, ok := checkForAuthHeader(r.Header) if ok == false { - ctx.Log.Logf("gql", "GQL_REQUEST_ERR: no auth header included in request header") json.NewEncoder(w).Encode(GQLUnauthorized("No TM Auth header provided")) return @@ -388,9 +560,13 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, * type GQLThread struct { SimpleThread + tcp_listener net.Listener http_server *http.Server http_done *sync.WaitGroup Listen string + AuthMap map[uint64]AuthData + Key *ecdsa.PrivateKey + ECDH ecdh.Curve } func (thread * GQLThread) Type() NodeType { @@ -414,14 +590,41 @@ func (thread * GQLThread) DeserializeInfo(ctx *Context, data []byte) (ThreadInfo type GQLThreadJSON struct { SimpleThreadJSON Listen string `json:"listen"` + AuthMap map[uint64]AuthData `json:"auth_map"` + Key []byte `json:"key"` + ECDH uint8 `json:"ecdh_curve"` +} + +var ecdsa_curves = map[uint8]elliptic.Curve{ + 0: elliptic.P256(), +} + +var ecdsa_curve_ids = map[elliptic.Curve]uint8{ + elliptic.P256(): 0, +} + +var ecdh_curves = map[uint8]ecdh.Curve{ + 0: ecdh.P256(), +} + +var ecdh_curve_ids = map[ecdh.Curve]uint8{ + ecdh.P256(): 0, } func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON { thread_json := NewSimpleThreadJSON(&thread.SimpleThread) + ser_key, err := x509.MarshalECPrivateKey(thread.Key) + if err != nil { + panic(err) + } + return GQLThreadJSON{ SimpleThreadJSON: thread_json, Listen: thread.Listen, + AuthMap: thread.AuthMap, + Key: ser_key, + ECDH: ecdh_curve_ids[thread.ECDH], } } @@ -432,7 +635,18 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e return nil, err } - thread := NewGQLThread(id, j.Name, j.StateName, j.Listen) + ecdh_curve, ok := ecdh_curves[j.ECDH] + if ok == false { + return nil, fmt.Errorf("%d is not a known ECDH curve ID", j.ECDH) + } + + key, err := x509.ParseECPrivateKey(j.Key) + if err != nil { + return nil, err + } + + thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key) + thread.AuthMap = j.AuthMap nodes[id] = &thread err = RestoreSimpleThread(ctx, &thread, j.SimpleThreadJSON, nodes) @@ -443,11 +657,14 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e return &thread, nil } -func NewGQLThread(id NodeID, name string, state_name string, listen string) GQLThread { +func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey) GQLThread { return GQLThread{ SimpleThread: NewSimpleThread(id, name, state_name, reflect.TypeOf((*ParentThreadInfo)(nil)), gql_actions, gql_handlers), Listen: listen, + AuthMap: map[uint64]AuthData{}, http_done: &sync.WaitGroup{}, + Key: key, + ECDH: ecdh_curve, } } @@ -477,6 +694,7 @@ var gql_actions ThreadActions = ThreadActions{ ctx.Log.Logf("gql", "GQL_START_SERVER") // Serve the GQL http and ws handlers mux := http.NewServeMux() + mux.HandleFunc("/auth", AuthHandler(ctx, server)) mux.HandleFunc("/gql", GQLHandler(ctx, server)) mux.HandleFunc("/gqlws", GQLWSHandler(ctx, server)) @@ -487,23 +705,34 @@ var gql_actions ThreadActions = ThreadActions{ fs := http.FileServer(http.Dir("./site")) mux.Handle("/site/", http.StripPrefix("/site", fs)) - UseStates(ctx, []Node{server}, func(nodes NodeMap)(error){ - server.http_server = &http.Server{ - Addr: server.Listen, - Handler: mux, - } - return nil - }) + http_server := &http.Server{ + Addr: server.Listen, + Handler: mux, + } + + listener, err := net.Listen("tcp", http_server.Addr) + if err != nil { + return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr) + + } server.http_done.Add(1) go func(server *GQLThread) { defer server.http_done.Done() - err := server.http_server.ListenAndServe() + + err = http_server.Serve(listener) if err != http.ErrServerClosed { - panic(fmt.Sprintf("Failed to start gql server: %s", err)) + panic(fmt.Sprintf("Failed to start gql server: %s", err)) } }(server) + + UseStates(ctx, []Node{server}, func(nodes NodeMap)(error){ + server.tcp_listener = listener + server.http_server = http_server + return server.Signal(ctx, NewSignal(server, "server_started"), nodes) + }) + return "wait", nil }, } diff --git a/gql_test.go b/gql_test.go index f9ca443..a1e997c 100644 --- a/gql_test.go +++ b/gql_test.go @@ -3,13 +3,24 @@ package graphvent import ( "testing" "time" - "fmt" "errors" + "net" + "net/http" + "io" + "fmt" + "encoding/json" + "bytes" + "crypto/rand" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" ) func TestGQLThread(t * testing.T) { ctx := logTestContext(t, []string{}) - gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0") + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + fatalErr(t, err) + gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key) gql_t := &gql_t_r t1_r := NewSimpleThread(RandID(), "Test thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) @@ -17,7 +28,7 @@ func TestGQLThread(t * testing.T) { t2_r := NewSimpleThread(RandID(), "Test thread 2", "init", nil, BaseThreadActions, BaseThreadHandlers) t2 := &t2_r - err := UpdateStates(ctx, []Node{gql_t, t1, t2}, func(nodes NodeMap) error { + err = UpdateStates(ctx, []Node{gql_t, t1, t2}, func(nodes NodeMap) error { i1 := NewParentThreadInfo(true, "start", "restore") err := LinkThreads(ctx, gql_t, t1, &i1, nodes) if err != nil { @@ -42,7 +53,7 @@ func TestGQLThread(t * testing.T) { } func TestGQLDBLoad(t * testing.T) { - ctx := logTestContext(t, []string{"thread", "signal", "gql", "test"}) + ctx := logTestContext(t, []string{}) l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") l1 := &l1_r @@ -50,11 +61,13 @@ func TestGQLDBLoad(t * testing.T) { t1 := &t1_r update_channel := UpdateChannel(t1, 10, "test") - gql_r := NewGQLThread(RandID(), "GQL Thread", "init", ":8080") + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + fatalErr(t, err) + gql_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key) gql := &gql_r info := NewParentThreadInfo(true, "start", "restore") - err := UpdateStates(ctx, []Node{gql, t1, l1}, func(nodes NodeMap) error { + err = UpdateStates(ctx, []Node{gql, t1, l1}, func(nodes NodeMap) error { err := LinkLockables(ctx, gql, []Lockable{l1}, nodes) if err != nil { return err @@ -84,8 +97,8 @@ func TestGQLDBLoad(t * testing.T) { err = UseStates(ctx, []Node{gql, t1}, func(nodes NodeMap) error { ser1, err := gql.Serialize() ser2, err := t1.Serialize() - fmt.Printf("\n%s\n\n", ser1) - fmt.Printf("\n%s\n\n", ser2) + ctx.Log.Logf("thread", "\n%s\n\n", ser1) + ctx.Log.Logf("thread", "\n%s\n\n", ser2) return err }) @@ -96,13 +109,13 @@ func TestGQLDBLoad(t * testing.T) { var update_channel_2 chan GraphSignal err = UseStates(ctx, []Node{gql_loaded}, func(nodes NodeMap) error { ser, err := gql_loaded.Serialize() - fmt.Printf("\n%s\n\n", ser) + ctx.Log.Logf("test", "\n%s\n\n", ser) child := gql_loaded.(Thread).Children()[0].(*SimpleThread) t1_loaded = child update_channel_2 = UpdateChannel(t1_loaded, 10, "test") err = UseMoreStates(ctx, []Node{child}, nodes, func(nodes NodeMap) error { ser, err := child.Serialize() - fmt.Printf("\n%s\n\n", ser) + ctx.Log.Logf("test", "\n%s\n\n", ser) return err }) gql_loaded.Signal(ctx, AbortSignal(nil), nodes) @@ -118,3 +131,65 @@ func TestGQLDBLoad(t * testing.T) { (*GraphTester)(t).WaitForValue(ctx, update_channel_2, "thread_aborted", t1_loaded, 100*time.Millisecond, "Didn't received thread_aborted on t1_loaded from t1_loaded") } + +func TestGQLAuth(t * testing.T) { + ctx := logTestContext(t, []string{"test", "gql"}) + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + fatalErr(t, err) + gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key) + gql_t := &gql_t_r + + done := make(chan error, 1) + + var update_channel chan GraphSignal + err = UseStates(ctx, []Node{gql_t}, func(nodes NodeMap) error { + update_channel = UpdateChannel(gql_t, 10, "test") + return nil + }) + fatalErr(t, err) + + go func(done chan error, thread Thread) { + timeout := time.After(2*time.Second) + select { + case <-timeout: + ctx.Log.Logf("test", "TIMEOUT") + case <-done: + ctx.Log.Logf("test", "DONE") + } + err := UseStates(ctx, []Node{gql_t}, func(nodes NodeMap) error { + return thread.Signal(ctx, CancelSignal(nil), nodes) + }) + fatalErr(t, err) + }(done, gql_t) + + go func(thread Thread){ + (*GraphTester)(t).WaitForValue(ctx, update_channel, "server_started", gql_t, 100*time.Millisecond, "Server didn't start") + port := gql_t.tcp_listener.Addr().(*net.TCPAddr).Port + ctx.Log.Logf("test", "GQL_PORT: %d", port) + client := &http.Client{} + url := fmt.Sprintf("http://localhost:%d/auth", port) + + id, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + fatalErr(t, err) + + auth_req, _, err := NewAuthReqJSON(ecdh.P256(), id) + fatalErr(t, err) + + str, err := json.Marshal(auth_req) + fatalErr(t, err) + b := bytes.NewBuffer(str) + req, err := http.NewRequest("PUT", url, b) + fatalErr(t, err) + req.Header.Add("Authorization", "TM baddata") + resp, err := client.Do(req) + fatalErr(t, err) + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + fatalErr(t, err) + ctx.Log.Logf("test", "RESP_BODY: %s", body) + done <- nil + }(gql_t) + + err = ThreadLoop(ctx, gql_t, "start") + fatalErr(t, err) +} diff --git a/lockable_test.go b/lockable_test.go index 27ffef5..1fb6726 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -298,7 +298,7 @@ func TestLockableLockTieredConflict(t * testing.T) { } func TestLockableSimpleUpdate(t * testing.T) { - ctx := logTestContext(t, []string{"test", "update", "lockable"}) + ctx := logTestContext(t, []string{}) l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") l1 := &l1_r @@ -316,7 +316,7 @@ func TestLockableSimpleUpdate(t * testing.T) { } func TestLockableDownUpdate(t * testing.T) { - ctx := logTestContext(t, []string{"test", "update", "lockable"}) + ctx := logTestContext(t, []string{}) l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") l1 := &l1_r @@ -345,7 +345,7 @@ func TestLockableDownUpdate(t * testing.T) { } func TestLockableUpUpdate(t * testing.T) { - ctx := logTestContext(t, []string{"test", "update", "lockable"}) + ctx := logTestContext(t, []string{}) l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") l1 := &l1_r @@ -374,7 +374,7 @@ func TestLockableUpUpdate(t * testing.T) { } func TestOwnerNotUpdatedTwice(t * testing.T) { - ctx := logTestContext(t, []string{"test", "signal", "lockable", "listeners"}) + ctx := logTestContext(t, []string{}) l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") l1 := &l1_r @@ -461,7 +461,7 @@ func TestLockableDBLoad(t * testing.T){ err = UseStates(ctx, []Node{l3}, func(nodes NodeMap) error { ser, err := l3.Serialize() - fmt.Printf("\n%s\n\n", ser) + ctx.Log.Logf("test", "\n%s\n\n", ser) return err }) fatalErr(t, err) @@ -472,14 +472,14 @@ func TestLockableDBLoad(t * testing.T){ // TODO: add more equivalence checks err = UseStates(ctx, []Node{l3_loaded}, func(nodes NodeMap) error { ser, err := l3_loaded.Serialize() - fmt.Printf("\n%s\n\n", ser) + ctx.Log.Logf("test", "\n%s\n\n", ser) return err }) fatalErr(t, err) } func TestLockableUnlink(t * testing.T){ - ctx := logTestContext(t, []string{"lockable"}) + ctx := logTestContext(t, []string{}) l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") l1 := &l1_r l2_r := NewSimpleLockable(RandID(), "Test Lockable 2") diff --git a/thread_test.go b/thread_test.go index ad87732..0e96c0a 100644 --- a/thread_test.go +++ b/thread_test.go @@ -87,7 +87,7 @@ func TestThreadDBLoad(t * testing.T) { err = UseStates(ctx, []Node{t1}, func(nodes NodeMap) error { ser, err := t1.Serialize() - fmt.Printf("\n%s\n\n", ser) + ctx.Log.Logf("test", "\n%s\n\n", ser) return err }) @@ -96,7 +96,7 @@ func TestThreadDBLoad(t * testing.T) { err = UseStates(ctx, []Node{t1_loaded}, func(nodes NodeMap) error { ser, err := t1_loaded.Serialize() - fmt.Printf("\n%s\n\n", ser) + ctx.Log.Logf("test", "\n%s\n\n", ser) return err }) }