Added TLS to GQLThread

graph-rework-2
noah metz 2023-07-21 01:05:24 -06:00
parent 7435728321
commit d4fb5a0922
2 changed files with 79 additions and 10 deletions

@ -24,6 +24,7 @@ import (
"crypto/sha512" "crypto/sha512"
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"crypto/tls"
) )
type AuthReqJSON struct { type AuthReqJSON struct {
@ -608,6 +609,8 @@ type GQLThread struct {
tcp_listener net.Listener tcp_listener net.Listener
http_server *http.Server http_server *http.Server
http_done *sync.WaitGroup http_done *sync.WaitGroup
ssl_key []byte
ssl_cert []byte
Listen string Listen string
Users map[NodeID]*User Users map[NodeID]*User
Key *ecdsa.PrivateKey Key *ecdsa.PrivateKey
@ -697,7 +700,7 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e
return nil, err return nil, err
} }
thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key) thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key, nil, nil)
thread.Users = map[NodeID]*User{} thread.Users = map[NodeID]*User{}
for _, id_str := range(j.Users) { for _, id_str := range(j.Users) {
id, err := ParseID(id_str) id, err := ParseID(id_str)
@ -720,7 +723,7 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e
return &thread, nil return &thread, nil
} }
func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey) GQLThread { func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, ssl_cert []byte, ssl_key []byte) GQLThread {
return GQLThread{ return GQLThread{
SimpleThread: NewSimpleThread(id, name, state_name, reflect.TypeOf((*ParentThreadInfo)(nil)), gql_actions, gql_handlers), SimpleThread: NewSimpleThread(id, name, state_name, reflect.TypeOf((*ParentThreadInfo)(nil)), gql_actions, gql_handlers),
Listen: listen, Listen: listen,
@ -728,6 +731,8 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh
http_done: &sync.WaitGroup{}, http_done: &sync.WaitGroup{},
Key: key, Key: key,
ECDH: ecdh_curve, ECDH: ecdh_curve,
ssl_cert: ssl_cert,
ssl_key: ssl_key,
} }
} }
@ -773,17 +778,35 @@ var gql_actions ThreadActions = ThreadActions{
Handler: mux, Handler: mux,
} }
listener, err := net.Listen("tcp", http_server.Addr) var listener net.Listener
l, err := net.Listen("tcp", http_server.Addr)
if err != nil { if err != nil {
return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr) return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr)
}
listener = l
if server.ssl_cert != nil {
ser, _ := json.Marshal(server.ssl_cert)
ctx.Log.Logf("gql", "SSL_CERT: %s", ser)
cert, err := tls.X509KeyPair(server.ssl_cert, server.ssl_key)
if err != nil {
return "", err
}
config := tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: []string{"http/1.1"},
}
listener = tls.NewListener(l, &config)
} }
server.http_done.Add(1) server.http_done.Add(1)
go func(server *GQLThread) { go func(server *GQLThread) {
defer server.http_done.Done() defer server.http_done.Done()
err = http_server.Serve(listener) err := http_server.Serve(listener)
if err != http.ErrServerClosed { if err != http.ErrServerClosed {
panic(fmt.Sprintf("Failed to start gql server: %s", err)) panic(fmt.Sprintf("Failed to start gql server: %s", err))
} }

@ -9,19 +9,24 @@ import (
"io" "io"
"fmt" "fmt"
"encoding/json" "encoding/json"
"encoding/pem"
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"crypto/ecdh" "crypto/ecdh"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
"crypto/x509"
"crypto/x509/pkix"
"crypto/tls"
"encoding/base64" "encoding/base64"
"math/big"
) )
func TestGQLThread(t * testing.T) { func TestGQLThread(t * testing.T) {
ctx := logTestContext(t, []string{}) ctx := logTestContext(t, []string{})
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err) fatalErr(t, err)
gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key) gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil)
gql_t := &gql_t_r gql_t := &gql_t_r
t1_r := NewSimpleThread(RandID(), "Test thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) t1_r := NewSimpleThread(RandID(), "Test thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
@ -75,7 +80,7 @@ func TestGQLDBLoad(t * testing.T) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err) fatalErr(t, err)
gql_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key) gql_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil)
gql := &gql_r gql := &gql_r
info := NewParentThreadInfo(true, "start", "restore") info := NewParentThreadInfo(true, "start", "restore")
@ -154,7 +159,38 @@ func TestGQLAuth(t * testing.T) {
ctx := logTestContext(t, []string{"test", "gql"}) ctx := logTestContext(t, []string{"test", "gql"})
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err) fatalErr(t, err)
gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
fatalErr(t, err)
notBefore := time.Now()
notAfter := notBefore.Add(365*24*time.Hour)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"mekkanized"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
ssl_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
ssl_cert, err := x509.CreateCertificate(rand.Reader, &template, &template, &ssl_key.PublicKey, ssl_key)
fatalErr(t, err)
ssl_cert_bytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ssl_cert})
ssl_key_bytes, err := x509.MarshalECPrivateKey(ssl_key)
fatalErr(t, err)
ssl_key_pem := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: ssl_key_bytes})
gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, ssl_cert_bytes, ssl_key_pem)
gql_t := &gql_t_r gql_t := &gql_t_r
done := make(chan error, 1) done := make(chan error, 1)
@ -184,8 +220,18 @@ func TestGQLAuth(t * testing.T) {
(*GraphTester)(t).WaitForValue(ctx, update_channel, "server_started", gql_t, 100*time.Millisecond, "Server didn't start") (*GraphTester)(t).WaitForValue(ctx, update_channel, "server_started", gql_t, 100*time.Millisecond, "Server didn't start")
port := gql_t.tcp_listener.Addr().(*net.TCPAddr).Port port := gql_t.tcp_listener.Addr().(*net.TCPAddr).Port
ctx.Log.Logf("test", "GQL_PORT: %d", port) ctx.Log.Logf("test", "GQL_PORT: %d", port)
client := &http.Client{}
url := fmt.Sprintf("http://localhost:%d/auth", port) customTransport := &http.Transport{
Proxy: http.DefaultTransport.(*http.Transport).Proxy,
DialContext: http.DefaultTransport.(*http.Transport).DialContext,
MaxIdleConns: http.DefaultTransport.(*http.Transport).MaxIdleConns,
IdleConnTimeout: http.DefaultTransport.(*http.Transport).IdleConnTimeout,
ExpectContinueTimeout: http.DefaultTransport.(*http.Transport).ExpectContinueTimeout,
TLSHandshakeTimeout: http.DefaultTransport.(*http.Transport).TLSHandshakeTimeout,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client := &http.Client{Transport: customTransport}
url := fmt.Sprintf("https://localhost:%d/auth", port)
id, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) id, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err) fatalErr(t, err)
@ -215,7 +261,7 @@ func TestGQLAuth(t * testing.T) {
shared, err := ParseAuthRespJSON(j, elliptic.P256(), ecdh.P256(), ec_key) shared, err := ParseAuthRespJSON(j, elliptic.P256(), ecdh.P256(), ec_key)
fatalErr(t, err) fatalErr(t, err)
url = fmt.Sprintf("http://localhost:%d/gql", port) url = fmt.Sprintf("https://localhost:%d/gql", port)
ser, err := json.MarshalIndent(&GQLPayload{ ser, err := json.MarshalIndent(&GQLPayload{
Query: "query { Self { Users { ID } } }", Query: "query { Self { Users { ID } } }",
}, "", " ") }, "", " ")