diff --git a/gql.go b/gql.go index c99fc2d..c3f28f3 100644 --- a/gql.go +++ b/gql.go @@ -25,6 +25,9 @@ import ( "crypto/rand" "crypto/x509" "crypto/tls" + "crypto/x509/pkix" + "math/big" + "encoding/pem" ) type AuthReqJSON struct { @@ -609,8 +612,8 @@ type GQLThread struct { tcp_listener net.Listener http_server *http.Server http_done *sync.WaitGroup - ssl_key []byte - ssl_cert []byte + tls_key []byte + tls_cert []byte Listen string Users map[NodeID]*User Key *ecdsa.PrivateKey @@ -641,8 +644,8 @@ type GQLThreadJSON struct { Users []string `json:"users"` Key []byte `json:"key"` ECDH uint8 `json:"ecdh_curve"` - SSLKey []byte `json:"ssl_key"` - SSLCert []byte `json:"ssl_cert"` + TLSKey []byte `json:"ssl_key"` + TLSCert []byte `json:"ssl_cert"` } var ecdsa_curves = map[uint8]elliptic.Curve{ @@ -682,8 +685,8 @@ func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON { Users: users, Key: ser_key, ECDH: ecdh_curve_ids[thread.ECDH], - SSLKey: thread.ssl_key, - SSLCert: thread.ssl_cert, + TLSKey: thread.tls_key, + TLSCert: thread.tls_cert, } } @@ -704,7 +707,7 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e return nil, err } - thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key, j.SSLCert, j.SSLKey) + thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key, j.TLSCert, j.TLSKey) thread.Users = map[NodeID]*User{} for _, id_str := range(j.Users) { id, err := ParseID(id_str) @@ -727,7 +730,46 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e return &thread, nil } -func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, ssl_cert []byte, ssl_key []byte) GQLThread { +func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) GQLThread { + if tls_cert == nil || tls_key == nil { + ssl_key, err := ecdsa.GenerateKey(key.Curve, rand.Reader) + if err != nil { + panic(err) + } + + ssl_key_bytes, err := x509.MarshalECPrivateKey(ssl_key) + if err != nil { + panic(err) + } + + ssl_key_pem := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: ssl_key_bytes}) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, _ := rand.Int(rand.Reader, serialNumberLimit) + notBefore := time.Now() + notAfter := notBefore.Add(365*24*time.Hour) + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"mekkanized"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + ssl_cert, err := x509.CreateCertificate(rand.Reader, &template, &template, &ssl_key.PublicKey, ssl_key) + if err != nil { + panic(err) + } + + ssl_cert_pem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ssl_cert}) + + tls_cert = ssl_cert_pem + tls_key = ssl_key_pem + } return GQLThread{ SimpleThread: NewSimpleThread(id, name, state_name, reflect.TypeOf((*ParentThreadInfo)(nil)), gql_actions, gql_handlers), Listen: listen, @@ -735,8 +777,8 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh http_done: &sync.WaitGroup{}, Key: key, ECDH: ecdh_curve, - ssl_cert: ssl_cert, - ssl_key: ssl_key, + tls_cert: tls_cert, + tls_key: tls_key, } } @@ -782,29 +824,22 @@ var gql_actions ThreadActions = ThreadActions{ Handler: mux, } - var listener net.Listener l, err := net.Listen("tcp", http_server.Addr) if err != nil { return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr) } - listener = l - if server.ssl_cert != nil { - ser, _ := json.Marshal(server.ssl_cert) - ctx.Log.Logf("gql", "SSL_CERT: %s", ser) - cert, err := tls.X509KeyPair(server.ssl_cert, server.ssl_key) - if err != nil { - return "", err - } - - config := tls.Config{ - Certificates: []tls.Certificate{cert}, - NextProtos: []string{"http/1.1"}, - } + cert, err := tls.X509KeyPair(server.tls_cert, server.tls_key) + if err != nil { + return "", err + } - listener = tls.NewListener(l, &config) + config := tls.Config{ + Certificates: []tls.Certificate{cert}, + NextProtos: []string{"http/1.1"}, } + listener := tls.NewListener(l, &config) server.http_done.Add(1) go func(server *GQLThread) { diff --git a/gql_test.go b/gql_test.go index 1ff4026..7944f2c 100644 --- a/gql_test.go +++ b/gql_test.go @@ -9,17 +9,13 @@ import ( "io" "fmt" "encoding/json" - "encoding/pem" "bytes" "crypto/rand" "crypto/ecdh" "crypto/ecdsa" "crypto/elliptic" - "crypto/x509" - "crypto/x509/pkix" "crypto/tls" "encoding/base64" - "math/big" ) func TestGQLThread(t * testing.T) { @@ -160,37 +156,7 @@ func TestGQLAuth(t * testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - fatalErr(t, err) - - notBefore := time.Now() - notAfter := notBefore.Add(365*24*time.Hour) - - - template := x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - Organization: []string{"mekkanized"}, - }, - NotBefore: notBefore, - NotAfter: notAfter, - KeyUsage: x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - - ssl_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - fatalErr(t, err) - ssl_cert, err := x509.CreateCertificate(rand.Reader, &template, &template, &ssl_key.PublicKey, ssl_key) - fatalErr(t, err) - ssl_cert_bytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ssl_cert}) - - ssl_key_bytes, err := x509.MarshalECPrivateKey(ssl_key) - fatalErr(t, err) - ssl_key_pem := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: ssl_key_bytes}) - - gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, ssl_cert_bytes, ssl_key_pem) + gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) gql_t := &gql_t_r done := make(chan error, 1)