Started to add authentication for GQL endpoint

graph-rework-2
noah metz 2023-07-19 14:45:05 -06:00
parent d227331fe8
commit 84af718071
4 changed files with 335 additions and 31 deletions

245
gql.go

@ -1,6 +1,8 @@
package graphvent
import (
"time"
"net"
"net/http"
"github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/parser"
@ -15,8 +17,179 @@ import (
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"strings"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/sha512"
"crypto/rand"
"crypto/x509"
)
type AuthReqJSON struct {
Time time.Time `json:"time"`
Pubkey []byte `json:"pubkey"`
ECDHPubkey []byte `json:"ecdh_client"`
Signature []byte `json:"signature"`
}
func NewAuthReqJSON(curve ecdh.Curve, id *ecdsa.PrivateKey) (AuthReqJSON, *ecdh.PrivateKey, error) {
ec_key, err := curve.GenerateKey(rand.Reader)
if err != nil {
return AuthReqJSON{}, nil, err
}
now := time.Now()
time_bytes, err := now.MarshalJSON()
if err != nil {
return AuthReqJSON{}, nil, err
}
sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...)
sig_hash := sha512.Sum512(sig_data)
sig, err := ecdsa.SignASN1(rand.Reader, id, sig_hash[:])
id_ecdh, err := id.ECDH()
if err != nil {
return AuthReqJSON{}, nil, err
}
return AuthReqJSON{
Time: now,
Pubkey: id_ecdh.PublicKey().Bytes(),
ECDHPubkey: ec_key.PublicKey().Bytes(),
Signature: sig,
}, ec_key, nil
}
type AuthRespJSON struct {
Granted time.Time `json:"granted"`
ECDHPubkey []byte `json:"echd_server"`
}
func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, []byte, error) {
// Check if req.Time is within +- 1 second of now
now := time.Now()
earliest := now.Add(-1 * time.Second)
latest := now.Add(1 * time.Second)
// If req.Time is before the earliest acceptable time, or after the latest acceptible time
if req.Time.Compare(earliest) == -1 {
return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_LATE: %s", req.Time)
} else if req.Time.Compare(latest) == 1 {
return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_EARLY: %s", req.Time)
}
x, y := elliptic.Unmarshal(thread.Key.Curve, req.Pubkey)
if x == nil {
return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_UNMARSHAL_FAIL: %+v", req.Pubkey)
}
remote, err := thread.ECDH.NewPublicKey(req.ECDHPubkey)
if err != nil {
return AuthRespJSON{}, nil, err
}
// Verify the signature
time_bytes, _ := req.Time.MarshalJSON()
sig_data := append(req.ECDHPubkey, time_bytes...)
sig_hash := sha512.Sum512(sig_data)
verified := ecdsa.VerifyASN1(
&ecdsa.PublicKey{
Curve: thread.Key.Curve,
X: x,
Y: y,
},
sig_hash[:],
req.Signature,
)
if verified == false {
return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_VERIFY_FAIL: %+v", req)
}
ec_key, err := thread.ECDH.GenerateKey(rand.Reader)
if err != nil {
return AuthRespJSON{}, nil, err
}
shared_secret, err := ec_key.ECDH(remote)
if err != nil {
return AuthRespJSON{}, nil, err
}
return AuthRespJSON{
Granted: time.Now(),
ECDHPubkey: ec_key.PublicKey().Bytes(),
}, shared_secret, nil
}
type AuthData struct {
Granted time.Time
Pubkey ecdh.PublicKey
ECDHClient ecdh.PublicKey
}
type AuthDataJSON struct {
Granted time.Time `json:"granted"`
Pubkey []byte `json:"pbkey"`
ECDHClient []byte `json:"ecdh_client"`
}
func HashKey(pub []byte) uint64 {
return 0
}
func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
ctx.Log.Logf("gql", "GQL_AUTH_REQUEST: %s", r.RemoteAddr)
enableCORS(&w)
str, err := io.ReadAll(r.Body)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_READ_ERR: %e", err)
return
}
var req AuthReqJSON
err = json.Unmarshal([]byte(str), &req)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_UNMARHSHAL_ERR: %e", err)
return
}
resp, _, err := NewAuthRespJSON(server, req)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_ERROR: %e", err)
return
}
ser, err := json.Marshal(resp)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_RESP_MARSHAL_ERR: %e", err)
return
}
wrote, err := w.Write(ser)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_RESP_ERR: %e", err)
return
} else if wrote != len(ser) {
ctx.Log.Logf("gql", "GQL_AUTH_RESP_BAD_LENGTH: %d/%d", wrote, len(ser))
return
}
ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_SUCCESS: %s", str)
key_hash := HashKey(req.Pubkey)
_, exists := server.AuthMap[key_hash]
if exists {
// New user
} else {
// Existing user
}
}
}
func GraphiQLHandler() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r * http.Request) {
graphiql_string := fmt.Sprintf(`
@ -166,7 +339,6 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
auth, ok := checkForAuthHeader(r.Header)
if ok == false {
ctx.Log.Logf("gql", "GQL_REQUEST_ERR: no auth header included in request header")
json.NewEncoder(w).Encode(GQLUnauthorized("No TM Auth header provided"))
return
@ -388,9 +560,13 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *
type GQLThread struct {
SimpleThread
tcp_listener net.Listener
http_server *http.Server
http_done *sync.WaitGroup
Listen string
AuthMap map[uint64]AuthData
Key *ecdsa.PrivateKey
ECDH ecdh.Curve
}
func (thread * GQLThread) Type() NodeType {
@ -414,14 +590,41 @@ func (thread * GQLThread) DeserializeInfo(ctx *Context, data []byte) (ThreadInfo
type GQLThreadJSON struct {
SimpleThreadJSON
Listen string `json:"listen"`
AuthMap map[uint64]AuthData `json:"auth_map"`
Key []byte `json:"key"`
ECDH uint8 `json:"ecdh_curve"`
}
var ecdsa_curves = map[uint8]elliptic.Curve{
0: elliptic.P256(),
}
var ecdsa_curve_ids = map[elliptic.Curve]uint8{
elliptic.P256(): 0,
}
var ecdh_curves = map[uint8]ecdh.Curve{
0: ecdh.P256(),
}
var ecdh_curve_ids = map[ecdh.Curve]uint8{
ecdh.P256(): 0,
}
func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON {
thread_json := NewSimpleThreadJSON(&thread.SimpleThread)
ser_key, err := x509.MarshalECPrivateKey(thread.Key)
if err != nil {
panic(err)
}
return GQLThreadJSON{
SimpleThreadJSON: thread_json,
Listen: thread.Listen,
AuthMap: thread.AuthMap,
Key: ser_key,
ECDH: ecdh_curve_ids[thread.ECDH],
}
}
@ -432,7 +635,18 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e
return nil, err
}
thread := NewGQLThread(id, j.Name, j.StateName, j.Listen)
ecdh_curve, ok := ecdh_curves[j.ECDH]
if ok == false {
return nil, fmt.Errorf("%d is not a known ECDH curve ID", j.ECDH)
}
key, err := x509.ParseECPrivateKey(j.Key)
if err != nil {
return nil, err
}
thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key)
thread.AuthMap = j.AuthMap
nodes[id] = &thread
err = RestoreSimpleThread(ctx, &thread, j.SimpleThreadJSON, nodes)
@ -443,11 +657,14 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e
return &thread, nil
}
func NewGQLThread(id NodeID, name string, state_name string, listen string) GQLThread {
func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey) GQLThread {
return GQLThread{
SimpleThread: NewSimpleThread(id, name, state_name, reflect.TypeOf((*ParentThreadInfo)(nil)), gql_actions, gql_handlers),
Listen: listen,
AuthMap: map[uint64]AuthData{},
http_done: &sync.WaitGroup{},
Key: key,
ECDH: ecdh_curve,
}
}
@ -477,6 +694,7 @@ var gql_actions ThreadActions = ThreadActions{
ctx.Log.Logf("gql", "GQL_START_SERVER")
// Serve the GQL http and ws handlers
mux := http.NewServeMux()
mux.HandleFunc("/auth", AuthHandler(ctx, server))
mux.HandleFunc("/gql", GQLHandler(ctx, server))
mux.HandleFunc("/gqlws", GQLWSHandler(ctx, server))
@ -487,23 +705,34 @@ var gql_actions ThreadActions = ThreadActions{
fs := http.FileServer(http.Dir("./site"))
mux.Handle("/site/", http.StripPrefix("/site", fs))
UseStates(ctx, []Node{server}, func(nodes NodeMap)(error){
server.http_server = &http.Server{
http_server := &http.Server{
Addr: server.Listen,
Handler: mux,
}
return nil
})
listener, err := net.Listen("tcp", http_server.Addr)
if err != nil {
return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr)
}
server.http_done.Add(1)
go func(server *GQLThread) {
defer server.http_done.Done()
err := server.http_server.ListenAndServe()
err = http_server.Serve(listener)
if err != http.ErrServerClosed {
panic(fmt.Sprintf("Failed to start gql server: %s", err))
}
}(server)
UseStates(ctx, []Node{server}, func(nodes NodeMap)(error){
server.tcp_listener = listener
server.http_server = http_server
return server.Signal(ctx, NewSignal(server, "server_started"), nodes)
})
return "wait", nil
},
}

@ -3,13 +3,24 @@ package graphvent
import (
"testing"
"time"
"fmt"
"errors"
"net"
"net/http"
"io"
"fmt"
"encoding/json"
"bytes"
"crypto/rand"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
)
func TestGQLThread(t * testing.T) {
ctx := logTestContext(t, []string{})
gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0")
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key)
gql_t := &gql_t_r
t1_r := NewSimpleThread(RandID(), "Test thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
@ -17,7 +28,7 @@ func TestGQLThread(t * testing.T) {
t2_r := NewSimpleThread(RandID(), "Test thread 2", "init", nil, BaseThreadActions, BaseThreadHandlers)
t2 := &t2_r
err := UpdateStates(ctx, []Node{gql_t, t1, t2}, func(nodes NodeMap) error {
err = UpdateStates(ctx, []Node{gql_t, t1, t2}, func(nodes NodeMap) error {
i1 := NewParentThreadInfo(true, "start", "restore")
err := LinkThreads(ctx, gql_t, t1, &i1, nodes)
if err != nil {
@ -42,7 +53,7 @@ func TestGQLThread(t * testing.T) {
}
func TestGQLDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{"thread", "signal", "gql", "test"})
ctx := logTestContext(t, []string{})
l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
l1 := &l1_r
@ -50,11 +61,13 @@ func TestGQLDBLoad(t * testing.T) {
t1 := &t1_r
update_channel := UpdateChannel(t1, 10, "test")
gql_r := NewGQLThread(RandID(), "GQL Thread", "init", ":8080")
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
gql_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key)
gql := &gql_r
info := NewParentThreadInfo(true, "start", "restore")
err := UpdateStates(ctx, []Node{gql, t1, l1}, func(nodes NodeMap) error {
err = UpdateStates(ctx, []Node{gql, t1, l1}, func(nodes NodeMap) error {
err := LinkLockables(ctx, gql, []Lockable{l1}, nodes)
if err != nil {
return err
@ -84,8 +97,8 @@ func TestGQLDBLoad(t * testing.T) {
err = UseStates(ctx, []Node{gql, t1}, func(nodes NodeMap) error {
ser1, err := gql.Serialize()
ser2, err := t1.Serialize()
fmt.Printf("\n%s\n\n", ser1)
fmt.Printf("\n%s\n\n", ser2)
ctx.Log.Logf("thread", "\n%s\n\n", ser1)
ctx.Log.Logf("thread", "\n%s\n\n", ser2)
return err
})
@ -96,13 +109,13 @@ func TestGQLDBLoad(t * testing.T) {
var update_channel_2 chan GraphSignal
err = UseStates(ctx, []Node{gql_loaded}, func(nodes NodeMap) error {
ser, err := gql_loaded.Serialize()
fmt.Printf("\n%s\n\n", ser)
ctx.Log.Logf("test", "\n%s\n\n", ser)
child := gql_loaded.(Thread).Children()[0].(*SimpleThread)
t1_loaded = child
update_channel_2 = UpdateChannel(t1_loaded, 10, "test")
err = UseMoreStates(ctx, []Node{child}, nodes, func(nodes NodeMap) error {
ser, err := child.Serialize()
fmt.Printf("\n%s\n\n", ser)
ctx.Log.Logf("test", "\n%s\n\n", ser)
return err
})
gql_loaded.Signal(ctx, AbortSignal(nil), nodes)
@ -118,3 +131,65 @@ func TestGQLDBLoad(t * testing.T) {
(*GraphTester)(t).WaitForValue(ctx, update_channel_2, "thread_aborted", t1_loaded, 100*time.Millisecond, "Didn't received thread_aborted on t1_loaded from t1_loaded")
}
func TestGQLAuth(t * testing.T) {
ctx := logTestContext(t, []string{"test", "gql"})
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key)
gql_t := &gql_t_r
done := make(chan error, 1)
var update_channel chan GraphSignal
err = UseStates(ctx, []Node{gql_t}, func(nodes NodeMap) error {
update_channel = UpdateChannel(gql_t, 10, "test")
return nil
})
fatalErr(t, err)
go func(done chan error, thread Thread) {
timeout := time.After(2*time.Second)
select {
case <-timeout:
ctx.Log.Logf("test", "TIMEOUT")
case <-done:
ctx.Log.Logf("test", "DONE")
}
err := UseStates(ctx, []Node{gql_t}, func(nodes NodeMap) error {
return thread.Signal(ctx, CancelSignal(nil), nodes)
})
fatalErr(t, err)
}(done, gql_t)
go func(thread Thread){
(*GraphTester)(t).WaitForValue(ctx, update_channel, "server_started", gql_t, 100*time.Millisecond, "Server didn't start")
port := gql_t.tcp_listener.Addr().(*net.TCPAddr).Port
ctx.Log.Logf("test", "GQL_PORT: %d", port)
client := &http.Client{}
url := fmt.Sprintf("http://localhost:%d/auth", port)
id, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
auth_req, _, err := NewAuthReqJSON(ecdh.P256(), id)
fatalErr(t, err)
str, err := json.Marshal(auth_req)
fatalErr(t, err)
b := bytes.NewBuffer(str)
req, err := http.NewRequest("PUT", url, b)
fatalErr(t, err)
req.Header.Add("Authorization", "TM baddata")
resp, err := client.Do(req)
fatalErr(t, err)
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
fatalErr(t, err)
ctx.Log.Logf("test", "RESP_BODY: %s", body)
done <- nil
}(gql_t)
err = ThreadLoop(ctx, gql_t, "start")
fatalErr(t, err)
}

@ -298,7 +298,7 @@ func TestLockableLockTieredConflict(t * testing.T) {
}
func TestLockableSimpleUpdate(t * testing.T) {
ctx := logTestContext(t, []string{"test", "update", "lockable"})
ctx := logTestContext(t, []string{})
l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
l1 := &l1_r
@ -316,7 +316,7 @@ func TestLockableSimpleUpdate(t * testing.T) {
}
func TestLockableDownUpdate(t * testing.T) {
ctx := logTestContext(t, []string{"test", "update", "lockable"})
ctx := logTestContext(t, []string{})
l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
l1 := &l1_r
@ -345,7 +345,7 @@ func TestLockableDownUpdate(t * testing.T) {
}
func TestLockableUpUpdate(t * testing.T) {
ctx := logTestContext(t, []string{"test", "update", "lockable"})
ctx := logTestContext(t, []string{})
l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
l1 := &l1_r
@ -374,7 +374,7 @@ func TestLockableUpUpdate(t * testing.T) {
}
func TestOwnerNotUpdatedTwice(t * testing.T) {
ctx := logTestContext(t, []string{"test", "signal", "lockable", "listeners"})
ctx := logTestContext(t, []string{})
l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
l1 := &l1_r
@ -461,7 +461,7 @@ func TestLockableDBLoad(t * testing.T){
err = UseStates(ctx, []Node{l3}, func(nodes NodeMap) error {
ser, err := l3.Serialize()
fmt.Printf("\n%s\n\n", ser)
ctx.Log.Logf("test", "\n%s\n\n", ser)
return err
})
fatalErr(t, err)
@ -472,14 +472,14 @@ func TestLockableDBLoad(t * testing.T){
// TODO: add more equivalence checks
err = UseStates(ctx, []Node{l3_loaded}, func(nodes NodeMap) error {
ser, err := l3_loaded.Serialize()
fmt.Printf("\n%s\n\n", ser)
ctx.Log.Logf("test", "\n%s\n\n", ser)
return err
})
fatalErr(t, err)
}
func TestLockableUnlink(t * testing.T){
ctx := logTestContext(t, []string{"lockable"})
ctx := logTestContext(t, []string{})
l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
l1 := &l1_r
l2_r := NewSimpleLockable(RandID(), "Test Lockable 2")

@ -87,7 +87,7 @@ func TestThreadDBLoad(t * testing.T) {
err = UseStates(ctx, []Node{t1}, func(nodes NodeMap) error {
ser, err := t1.Serialize()
fmt.Printf("\n%s\n\n", ser)
ctx.Log.Logf("test", "\n%s\n\n", ser)
return err
})
@ -96,7 +96,7 @@ func TestThreadDBLoad(t * testing.T) {
err = UseStates(ctx, []Node{t1_loaded}, func(nodes NodeMap) error {
ser, err := t1_loaded.Serialize()
fmt.Printf("\n%s\n\n", ser)
ctx.Log.Logf("test", "\n%s\n\n", ser)
return err
})
}