From d4fb5a0922dfbad8f92caaa4e28c88f44960beca Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Fri, 21 Jul 2023 01:05:24 -0600 Subject: [PATCH] Added TLS to GQLThread --- gql.go | 31 ++++++++++++++++++++++++---- gql_test.go | 58 +++++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 79 insertions(+), 10 deletions(-) diff --git a/gql.go b/gql.go index 3934d26..3856be2 100644 --- a/gql.go +++ b/gql.go @@ -24,6 +24,7 @@ import ( "crypto/sha512" "crypto/rand" "crypto/x509" + "crypto/tls" ) type AuthReqJSON struct { @@ -608,6 +609,8 @@ type GQLThread struct { tcp_listener net.Listener http_server *http.Server http_done *sync.WaitGroup + ssl_key []byte + ssl_cert []byte Listen string Users map[NodeID]*User Key *ecdsa.PrivateKey @@ -697,7 +700,7 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e return nil, err } - thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key) + thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key, nil, nil) thread.Users = map[NodeID]*User{} for _, id_str := range(j.Users) { id, err := ParseID(id_str) @@ -720,7 +723,7 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e return &thread, nil } -func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey) GQLThread { +func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, ssl_cert []byte, ssl_key []byte) GQLThread { return GQLThread{ SimpleThread: NewSimpleThread(id, name, state_name, reflect.TypeOf((*ParentThreadInfo)(nil)), gql_actions, gql_handlers), Listen: listen, @@ -728,6 +731,8 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh http_done: &sync.WaitGroup{}, Key: key, ECDH: ecdh_curve, + ssl_cert: ssl_cert, + ssl_key: ssl_key, } } @@ -773,17 +778,35 @@ var gql_actions ThreadActions = ThreadActions{ Handler: mux, } - listener, err := net.Listen("tcp", http_server.Addr) + var listener net.Listener + l, err := net.Listen("tcp", http_server.Addr) if err != nil { return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr) + } + listener = l + + if server.ssl_cert != nil { + ser, _ := json.Marshal(server.ssl_cert) + ctx.Log.Logf("gql", "SSL_CERT: %s", ser) + cert, err := tls.X509KeyPair(server.ssl_cert, server.ssl_key) + if err != nil { + return "", err + } + config := tls.Config{ + Certificates: []tls.Certificate{cert}, + NextProtos: []string{"http/1.1"}, + } + + listener = tls.NewListener(l, &config) } + server.http_done.Add(1) go func(server *GQLThread) { defer server.http_done.Done() - err = http_server.Serve(listener) + err := http_server.Serve(listener) if err != http.ErrServerClosed { panic(fmt.Sprintf("Failed to start gql server: %s", err)) } diff --git a/gql_test.go b/gql_test.go index 29dfb03..1ff4026 100644 --- a/gql_test.go +++ b/gql_test.go @@ -9,19 +9,24 @@ import ( "io" "fmt" "encoding/json" + "encoding/pem" "bytes" "crypto/rand" "crypto/ecdh" "crypto/ecdsa" "crypto/elliptic" + "crypto/x509" + "crypto/x509/pkix" + "crypto/tls" "encoding/base64" + "math/big" ) func TestGQLThread(t * testing.T) { ctx := logTestContext(t, []string{}) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) - gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key) + gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) gql_t := &gql_t_r t1_r := NewSimpleThread(RandID(), "Test thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) @@ -75,7 +80,7 @@ func TestGQLDBLoad(t * testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) - gql_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key) + gql_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) gql := &gql_r info := NewParentThreadInfo(true, "start", "restore") @@ -154,7 +159,38 @@ func TestGQLAuth(t * testing.T) { ctx := logTestContext(t, []string{"test", "gql"}) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) - gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + fatalErr(t, err) + + notBefore := time.Now() + notAfter := notBefore.Add(365*24*time.Hour) + + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"mekkanized"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + ssl_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + fatalErr(t, err) + ssl_cert, err := x509.CreateCertificate(rand.Reader, &template, &template, &ssl_key.PublicKey, ssl_key) + fatalErr(t, err) + ssl_cert_bytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ssl_cert}) + + ssl_key_bytes, err := x509.MarshalECPrivateKey(ssl_key) + fatalErr(t, err) + ssl_key_pem := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: ssl_key_bytes}) + + gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, ssl_cert_bytes, ssl_key_pem) gql_t := &gql_t_r done := make(chan error, 1) @@ -184,8 +220,18 @@ func TestGQLAuth(t * testing.T) { (*GraphTester)(t).WaitForValue(ctx, update_channel, "server_started", gql_t, 100*time.Millisecond, "Server didn't start") port := gql_t.tcp_listener.Addr().(*net.TCPAddr).Port ctx.Log.Logf("test", "GQL_PORT: %d", port) - client := &http.Client{} - url := fmt.Sprintf("http://localhost:%d/auth", port) + + customTransport := &http.Transport{ + Proxy: http.DefaultTransport.(*http.Transport).Proxy, + DialContext: http.DefaultTransport.(*http.Transport).DialContext, + MaxIdleConns: http.DefaultTransport.(*http.Transport).MaxIdleConns, + IdleConnTimeout: http.DefaultTransport.(*http.Transport).IdleConnTimeout, + ExpectContinueTimeout: http.DefaultTransport.(*http.Transport).ExpectContinueTimeout, + TLSHandshakeTimeout: http.DefaultTransport.(*http.Transport).TLSHandshakeTimeout, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + client := &http.Client{Transport: customTransport} + url := fmt.Sprintf("https://localhost:%d/auth", port) id, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) @@ -215,7 +261,7 @@ func TestGQLAuth(t * testing.T) { shared, err := ParseAuthRespJSON(j, elliptic.P256(), ecdh.P256(), ec_key) fatalErr(t, err) - url = fmt.Sprintf("http://localhost:%d/gql", port) + url = fmt.Sprintf("https://localhost:%d/gql", port) ser, err := json.MarshalIndent(&GQLPayload{ Query: "query { Self { Users { ID } } }", }, "", " ")