diff --git a/context.go b/context.go index ba0a4c1..d1ac772 100644 --- a/context.go +++ b/context.go @@ -185,6 +185,11 @@ func NewContext(db * badger.DB, log Logger) * Context { panic(err) } + err = ctx.RegisterNodeType(NewNodeDef((*GQLUser)(nil), LoadGQLUser, GQLTypeGQLUser())) + if err != nil { + panic(err) + } + ctx.AddGQLType(GQLTypeSignal()) ctx.GQL.Query.AddFieldConfig("Self", GQLQuerySelf()) diff --git a/gql.go b/gql.go index c446d12..d834f8e 100644 --- a/gql.go +++ b/gql.go @@ -10,6 +10,7 @@ import ( "github.com/graphql-go/graphql/language/ast" "context" "encoding/json" + "encoding/base64" "io" "reflect" "fmt" @@ -23,7 +24,6 @@ import ( "crypto/sha512" "crypto/rand" "crypto/x509" - "bytes" "github.com/google/uuid" ) @@ -67,7 +67,7 @@ type AuthRespJSON struct { Signature []byte `json:"signature"` } -func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.PublicKey, *ecdsa.PrivateKey, error) { +func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.PublicKey, []byte, error) { // Check if req.Time is within +- 1 second of now now := time.Now() earliest := now.Add(-1 * time.Second) @@ -132,21 +132,14 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.P return AuthRespJSON{}, nil, nil, err } - secret_hash := sha512.Sum512(shared_secret) - buf := bytes.NewReader(secret_hash[:]) - shared_key, err := ecdsa.GenerateKey(thread.Key.Curve, buf) - if err != nil { - return AuthRespJSON{}, nil, nil, err - } - return AuthRespJSON{ Granted: granted, ECDHPubkey: ec_key_pub, Signature: resp_sig, - }, remote_key, shared_key, nil + }, remote_key, shared_secret, nil } -func ParseAuthRespJSON(resp AuthRespJSON, ecdsa_curve elliptic.Curve, ecdh_curve ecdh.Curve, ec_key *ecdh.PrivateKey) (*ecdsa.PrivateKey, error) { +func ParseAuthRespJSON(resp AuthRespJSON, ecdsa_curve elliptic.Curve, ecdh_curve ecdh.Curve, ec_key *ecdh.PrivateKey) ([]byte, error) { remote, err := ecdh_curve.NewPublicKey(resp.ECDHPubkey) if err != nil { return nil, err @@ -157,27 +150,19 @@ func ParseAuthRespJSON(resp AuthRespJSON, ecdsa_curve elliptic.Curve, ecdh_curve return nil, err } - secret_hash := sha512.Sum512(shared_secret) - buf := bytes.NewReader(secret_hash[:]) - shared_key, err := ecdsa.GenerateKey(ecdsa_curve, buf) - if err != nil { - return nil, err - } - - return shared_key, nil + return shared_secret, nil } -type AuthData struct { +type GQLUser struct { + SimpleLockable + Granted time.Time Pubkey *ecdsa.PublicKey - Shared *ecdsa.PrivateKey -} - -func (data AuthData) String() string { - return fmt.Sprintf("{Granted: %+v, Pubkey: %s, Shared: %s}", data.Granted, KeyID(data.Pubkey).String(), KeyID(&data.Shared.PublicKey).String()) + Shared []byte } -type AuthDataJSON struct { +type GQLUserJSON struct { + SimpleLockableJSON Granted time.Time `json:"granted"` Pubkey []byte `json:"pubkey"` Shared []byte `json:"shared"` @@ -189,6 +174,66 @@ func KeyID(pub *ecdsa.PublicKey) NodeID { return NodeID(str) } +func (user *GQLUser) Type() NodeType { + return NodeType("gql_user") +} + +func (user *GQLUser) Serialize() ([]byte, error) { + lockable_json := NewSimpleLockableJSON(&user.SimpleLockable) + pubkey, err := x509.MarshalPKIXPublicKey(user.Pubkey) + if err != nil { + return nil, err + } + + return json.MarshalIndent(&GQLUserJSON{ + SimpleLockableJSON: lockable_json, + Granted: user.Granted, + Shared: user.Shared, + Pubkey: pubkey, + }, "", " ") +} + +func LoadGQLUser(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { + var j GQLUserJSON + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } + + pub, err := x509.ParsePKIXPublicKey(j.Pubkey) + if err != nil { + return nil, err + } + + var pubkey *ecdsa.PublicKey + switch pub.(type) { + case *ecdsa.PublicKey: + pubkey = pub.(*ecdsa.PublicKey) + default: + return nil, fmt.Errorf("Invalid key type") + } + + user := NewGQLUser(j.Name, j.Granted, pubkey, j.Shared) + nodes[id] = &user + + err = RestoreSimpleLockable(ctx, &user, j.SimpleLockableJSON, nodes) + if err != nil { + return nil, err + } + + return &user, nil +} + +func NewGQLUser(name string, granted time.Time, pubkey *ecdsa.PublicKey, shared []byte) GQLUser { + id := KeyID(pubkey) + return GQLUser{ + SimpleLockable: NewSimpleLockable(id, name), + Granted: granted, + Pubkey: pubkey, + Shared: shared, + } +} + func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { ctx.Log.Logf("gql", "GQL_AUTH_REQUEST: %s", r.RemoteAddr) @@ -207,7 +252,7 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt return } - resp, remote_id, shared_key, err := NewAuthRespJSON(server, req) + resp, remote_id, shared, err := NewAuthRespJSON(server, req) if err != nil { ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_ERROR: %s", err) return @@ -230,20 +275,23 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt key_id := KeyID(remote_id) - new_auth := AuthData{ - Granted: time.Now(), - Pubkey: remote_id, - Shared: shared_key, - } - - _, exists := server.AuthMap[key_id] + _, exists := server.Users[key_id] if exists { - ctx.Log.Logf("gql", "REFRESHING AUTH FOR %s - %s", key_id, new_auth) + ctx.Log.Logf("gql", "REFRESHING AUTH FOR %s", key_id) } else { - ctx.Log.Logf("gql", "AUTHORIZING NEW USER %s - %s", key_id, new_auth) + ctx.Log.Logf("gql", "AUTHORIZING NEW USER %s - %s", key_id, shared) + + new_user := NewGQLUser(fmt.Sprintf("GQL_USER %s", key_id.String()), time.Now(), remote_id, shared) + err := UpdateStates(ctx, []Node{server}, func(nodes NodeMap) error { + server.Users[key_id] = &new_user + return nil + }) + if err != nil { + ctx.Log.Logf("gql", "GQL_AUTH_UPDATE_ERR: %s", err) + return + } } - server.AuthMap[key_id] = new_auth } } @@ -322,7 +370,7 @@ func GraphiQLHandler() func(http.ResponseWriter, *http.Request) { } -type GQLWSPayload struct { +type GQLPayload struct { OperationName string `json:"operationName,omitempty"` Query string `json:"query,omitempty"` Variables map[string]interface{} `json:"variables,omitempty"` @@ -333,7 +381,7 @@ type GQLWSPayload struct { type GQLWSMsg struct { ID string `json:"id,omitempty"` Type string `json:"type"` - Payload GQLWSPayload `json:"payload,omitempty"` + Payload GQLPayload `json:"payload,omitempty"` } func enableCORS(w *http.ResponseWriter) { @@ -381,6 +429,29 @@ func checkForAuthHeader(header http.Header) (string, bool) { return "", false } +func CheckAuth(server *GQLThread, r *http.Request) (*GQLUser, error) { + username, password, ok := r.BasicAuth() + if ok == false { + return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header") + } + + auth_id, err := ParseID(username) + if err != nil { + return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username) + } + + user, exists := server.Users[auth_id] + if exists == false { + return nil, fmt.Errorf("GQL_REQUEST_ERR: no existing authorization for client %s", auth_id) + } + + if base64.StdEncoding.EncodeToString(user.Shared) != password { + return nil, fmt.Errorf("GQL_AUTH_FAIL") + } + + return user, nil +} + func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *http.Request) { gql_ctx := context.Background() gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx) @@ -394,36 +465,22 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht header_map[header] = value } ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") - username, password, ok := r.BasicAuth() - if ok == false { - ctx.Log.Logf("gql", "GQL_REQUEST_ERR: no auth header included in request header") - json.NewEncoder(w).Encode(GQLUnauthorized("No Auth header provided")) - return - } - auth_id, err := ParseID(username) + user, err := CheckAuth(server, r) if err != nil { - ctx.Log.Logf("gql", "GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username) - json.NewEncoder(w).Encode(GQLUnauthorized("Failed to parse ID from username")) + ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) + json.NewEncoder(w).Encode(GQLUnauthorized(fmt.Sprintf("%s", err))) return } - - auth, exists := server.AuthMap[auth_id] - if exists == false { - ctx.Log.Logf("gql", "GQL_REQUEST_ERR: no existing authorization for client %s", auth_id) - json.NewEncoder(w).Encode(GQLUnauthorized("No matching authorization for client")) - return - } - - req_ctx := context.WithValue(gql_ctx, "auth", auth) - ctx.Log.Logf("gql", "GQL_AUTH: %+v - %s", auth, password) + req_ctx := context.WithValue(gql_ctx, "user", user) str, err := io.ReadAll(r.Body) if err != nil { - ctx.Log.Logf("gql", "GQL_REQUEST_ERR: failed to read request body: %s", err) + ctx.Log.Logf("gql", "GQL_READ_ERR: %s", err) + json.NewEncoder(w).Encode(fmt.Sprintf("%e", err)) return } - query := GQLWSPayload{} + query := GQLPayload{} json.Unmarshal(str, &query) params := graphql.Params{ @@ -509,29 +566,12 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, * } ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") - username, password, ok := r.BasicAuth() - if ok == false { - ctx.Log.Logf("gql", "GQL_REQUEST_ERR: no auth header included in request header") - json.NewEncoder(w).Encode(GQLUnauthorized("No Auth header provided")) - return - } - - auth_id, err := ParseID(username) + user, err := CheckAuth(server, r) if err != nil { - ctx.Log.Logf("gql", "GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username) - json.NewEncoder(w).Encode(GQLUnauthorized("Failed to parse ID from username")) - return - } - - auth, exists := server.AuthMap[auth_id] - if exists == false { - ctx.Log.Logf("gql", "GQL_REQUEST_ERR: no existing authorization for client %s", auth_id) - json.NewEncoder(w).Encode(GQLUnauthorized("No matching authorization for client")) + ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) return } - - req_ctx := context.WithValue(gql_ctx, "auth", auth) - ctx.Log.Logf("gql", "GQL_AUTH: %+v - %s", auth, password) + req_ctx := context.WithValue(gql_ctx, "user", user) u := ws.HTTPUpgrader{ Protocol: func(protocol string) bool { @@ -619,7 +659,7 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, * msg, err := json.Marshal(GQLWSMsg{ ID: msg.ID, Type: "next", - Payload: GQLWSPayload{ + Payload: GQLPayload{ Data: string(data), }, }) @@ -651,7 +691,7 @@ type GQLThread struct { http_server *http.Server http_done *sync.WaitGroup Listen string - AuthMap map[NodeID]AuthData + Users map[NodeID]*GQLUser Key *ecdsa.PrivateKey ECDH ecdh.Curve } @@ -677,7 +717,7 @@ func (thread * GQLThread) DeserializeInfo(ctx *Context, data []byte) (ThreadInfo type GQLThreadJSON struct { SimpleThreadJSON Listen string `json:"listen"` - AuthMap map[string]AuthDataJSON `json:"auth_map"` + Users []string `json:"users"` Key []byte `json:"key"` ECDH uint8 `json:"ecdh_curve"` } @@ -706,23 +746,17 @@ func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON { panic(err) } - auth_map := map[string]AuthDataJSON{} - for id, data := range(thread.AuthMap) { - shared, err := x509.MarshalECPrivateKey(data.Shared) - if err != nil { - panic(err) - } - auth_map[id.String()] = AuthDataJSON{ - Granted: data.Granted, - Pubkey: elliptic.Marshal(data.Pubkey.Curve, data.Pubkey.X, data.Pubkey.Y), - Shared: shared, - } + users := make([]string, len(thread.Users)) + i := 0 + for id, _ := range(thread.Users) { + users[i] = id.String() + i += 1 } return GQLThreadJSON{ SimpleThreadJSON: thread_json, Listen: thread.Listen, - AuthMap: auth_map, + Users: users, Key: ser_key, ECDH: ecdh_curve_ids[thread.ECDH], } @@ -746,29 +780,17 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e } thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key) - thread.AuthMap = map[NodeID]AuthData{} - for id_str, auth_json := range(j.AuthMap) { + thread.Users = map[NodeID]*GQLUser{} + for _, id_str := range(j.Users) { id, err := ParseID(id_str) if err != nil { return nil, err } - x, y := elliptic.Unmarshal(key.Curve, auth_json.Pubkey) - if x == nil { - return nil, fmt.Errorf("Failed to load public key for curve %+v from %+v", key.Curve, auth_json.Pubkey) - } - shared, err := x509.ParseECPrivateKey(auth_json.Shared) + user, err := LoadNodeRecurse(ctx, id, nodes) if err != nil { return nil, err } - thread.AuthMap[id] = AuthData{ - Granted: auth_json.Granted, - Pubkey: &ecdsa.PublicKey{ - Curve: key.Curve, - X: x, - Y: y, - }, - Shared: shared, - } + thread.Users[id] = user.(*GQLUser) } nodes[id] = &thread @@ -784,7 +806,7 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh return GQLThread{ SimpleThread: NewSimpleThread(id, name, state_name, reflect.TypeOf((*ParentThreadInfo)(nil)), gql_actions, gql_handlers), Listen: listen, - AuthMap: map[NodeID]AuthData{}, + Users: map[NodeID]*GQLUser{}, http_done: &sync.WaitGroup{}, Key: key, ECDH: ecdh_curve, @@ -863,7 +885,7 @@ var gql_actions ThreadActions = ThreadActions{ var gql_handlers ThreadHandlers = ThreadHandlers{ "child_added": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) { ctx.Log.Logf("gql", "GQL_THREAD_CHILD_ADDED: %+v", signal) - UpdateStates(ctx, []Node{thread}, func(nodes NodeMap)(error) { + UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { should_run, exists := thread.ChildInfo(signal.Source()).(*ParentThreadInfo) if exists == false { ctx.Log.Logf("gql", "GQL_THREAD_CHILD_ADDED: tried to start %s whis is not a child") diff --git a/gql_graph.go b/gql_graph.go index 538e71e..a86562c 100644 --- a/gql_graph.go +++ b/gql_graph.go @@ -385,6 +385,60 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) { return owner, nil } +var gql_type_gql_user *graphql.Object = nil +func GQLTypeGQLUser() * graphql.Object { + if gql_type_gql_user == nil { + gql_type_gql_user = graphql.NewObject(graphql.ObjectConfig{ + Name: "GQLUser", + Interfaces: []*graphql.Interface{ + GQLInterfaceNode(), + GQLInterfaceLockable(), + }, + IsTypeOf: func(p graphql.IsTypeOfParams) bool { + ctx, ok := p.Context.Value("graph_context").(*Context) + if ok == false { + return false + } + + lockable_type := ctx.GQL.LockableType + value_type := reflect.TypeOf(p.Value) + + if value_type.Implements(lockable_type) { + return true + } + + return false + }, + Fields: graphql.Fields{}, + }) + + gql_type_gql_user.AddFieldConfig("ID", &graphql.Field{ + Type: graphql.String, + Resolve: GQLNodeID, + }) + + gql_type_gql_user.AddFieldConfig("Name", &graphql.Field{ + Type: graphql.String, + Resolve: GQLLockableName, + }) + + gql_type_gql_user.AddFieldConfig("Requirements", &graphql.Field{ + Type: GQLListLockable(), + Resolve: GQLLockableRequirements, + }) + + gql_type_gql_user.AddFieldConfig("Owner", &graphql.Field{ + Type: GQLInterfaceLockable(), + Resolve: GQLLockableOwner, + }) + + gql_type_gql_user.AddFieldConfig("Dependencies", &graphql.Field{ + Type: GQLListLockable(), + Resolve: GQLLockableDependencies, + }) + } + return gql_type_gql_user +} var gql_type_gql_thread *graphql.Object = nil func GQLTypeGQLThread() * graphql.Object { diff --git a/gql_test.go b/gql_test.go index fd1e556..079f64a 100644 --- a/gql_test.go +++ b/gql_test.go @@ -14,6 +14,7 @@ import ( "crypto/ecdh" "crypto/ecdsa" "crypto/elliptic" + "encoding/base64" ) func TestGQLThread(t * testing.T) { @@ -61,13 +62,22 @@ func TestGQLDBLoad(t * testing.T) { t1 := &t1_r update_channel := UpdateChannel(t1, 10, NodeID{}) + u1_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + fatalErr(t, err) + + u1_shared := []byte{0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x23, 0x45, 0x67} + + u1_r := NewGQLUser("Test User", time.Now(), &u1_key.PublicKey, u1_shared) + u1 := &u1_r + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) gql_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key) gql := &gql_r info := NewParentThreadInfo(true, "start", "restore") - err = UpdateStates(ctx, []Node{gql, t1, l1}, func(nodes NodeMap) error { + err = UpdateStates(ctx, []Node{gql, t1, l1, u1}, func(nodes NodeMap) error { + gql.Users[KeyID(&u1_key.PublicKey)] = u1 err := LinkLockables(ctx, gql, []Lockable{l1}, nodes) if err != nil { return err @@ -177,23 +187,57 @@ func TestGQLAuth(t * testing.T) { str, err := json.Marshal(auth_req) fatalErr(t, err) + b := bytes.NewBuffer(str) req, err := http.NewRequest("PUT", url, b) fatalErr(t, err) - req.Header.Add("Authorization", "TM baddata") + resp, err := client.Do(req) fatalErr(t, err) + body, err := io.ReadAll(resp.Body) - resp.Body.Close() fatalErr(t, err) + resp.Body.Close() + var j AuthRespJSON err = json.Unmarshal(body, &j) fatalErr(t, err) - shared_key, err := ParseAuthRespJSON(j, elliptic.P256(), ecdh.P256(), ec_key) + shared, err := ParseAuthRespJSON(j, elliptic.P256(), ecdh.P256(), ec_key) + fatalErr(t, err) + + url = fmt.Sprintf("http://localhost:%d/gql", port) + ser, err := json.MarshalIndent(&GQLPayload{ + Query: "query { Self { ID } }", + }, "", " ") + fatalErr(t, err) + + b = bytes.NewBuffer(ser) + req, err = http.NewRequest("GET", url, b) fatalErr(t, err) - ctx.Log.Logf("test", "TEST_SHARED_SECRET: %s", KeyID(&shared_key.PublicKey).String()) + + req.SetBasicAuth(KeyID(&id.PublicKey).String(), base64.StdEncoding.EncodeToString(shared)) + resp, err = client.Do(req) + fatalErr(t, err) + + body, err = io.ReadAll(resp.Body) + fatalErr(t, err) + + resp.Body.Close() + + ctx.Log.Logf("test", "TEST_RESP: %s", body) + + req.SetBasicAuth(KeyID(&id.PublicKey).String(), "BAD_PASSWORD") + resp, err = client.Do(req) + fatalErr(t, err) + + body, err = io.ReadAll(resp.Body) + fatalErr(t, err) + + resp.Body.Close() + + ctx.Log.Logf("test", "TEST_RESP: %s", body) done <- nil }(gql_t) diff --git a/graph_test.go b/graph_test.go index 6c68a31..5ab8c88 100644 --- a/graph_test.go +++ b/graph_test.go @@ -72,6 +72,7 @@ func testContext(t * testing.T) * Context { func fatalErr(t * testing.T, err error) { if err != nil { + pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) t.Fatal(err) } } diff --git a/thread.go b/thread.go index b2a207b..b7b74ab 100644 --- a/thread.go +++ b/thread.go @@ -355,7 +355,7 @@ func LoadSimpleThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node return &thread, nil } -// SimpleThread as no associated info with children +// SimpleThread has no associated info with children func (thread * SimpleThread) DeserializeInfo(ctx *Context, data []byte) (ThreadInfo, error) { if len(data) > 0 { return nil, fmt.Errorf("SimpleThread expected to deserialize no info but got %d length data: %s", len(data), string(data))