diff --git a/.gitignore b/.gitignore index adf8f72..30dac79 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,6 @@ # Go workspace file go.work +.venv +key +key.pub diff --git a/cmd/genkey/main.go b/cmd/genkey/main.go new file mode 100644 index 0000000..629c020 --- /dev/null +++ b/cmd/genkey/main.go @@ -0,0 +1,71 @@ +package main + +import ( + "os" + "crypto/ed25519" + "crypto/rand" + "crypto/x509" + "encoding/pem" + "fmt" +) + +func usage(code int, err error) { + if err != nil { + fmt.Fprintf(os.Stderr, "error: %s\n", err) + } + fmt.Fprintf(os.Stderr, "usage: %s [output_path]\n", os.Args[0]) + os.Exit(code) +} + +func main() { + if len(os.Args) != 2 { + usage(0, nil) + } + output_path := os.Args[1] + + private_file, err := os.Create(output_path) + if err != nil { + usage(1, err) + } + + public_file, err := os.Create(fmt.Sprintf("%s.pub", output_path)) + if err != nil { + usage(2, err) + } + + public, private, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + usage(3, err) + } + + public_bytes, err := x509.MarshalPKIXPublicKey(public) + if err != nil { + usage(4, err) + } + + private_bytes, err := x509.MarshalPKCS8PrivateKey(private) + if err != nil { + usage(5, err) + } + + err = pem.Encode(public_file, &pem.Block{ + Type: "PUBLIC KEY", + Bytes: public_bytes, + }) + + err = pem.Encode(private_file, &pem.Block{ + Type: "PRIVATE KEY", + Bytes: private_bytes, + }) + err = public_file.Close() + if err != nil { + usage(6, err) + } + + err = private_file.Close() + if err != nil { + usage(7, err) + } + fmt.Printf("Wrote %s and %s.pub\n", output_path, output_path) +} + diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..0a7dbe7 --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module git.metznet.ca/MetzNet/score_server + +go 1.21.4 + +require ( + github.com/golang-jwt/jwt v3.2.2+incompatible // indirect + github.com/jba/muxpatterns v0.3.0 // indirect + github.com/lib/pq v1.10.9 // indirect + golang.org/x/exp v0.0.0-20230519143937-03e91628a987 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6f8717a --- /dev/null +++ b/go.sum @@ -0,0 +1,8 @@ +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/jba/muxpatterns v0.3.0 h1:4mumR8LmA/7vUigN/BmUXtHXbwZOzggVuX9Dgqtw67c= +github.com/jba/muxpatterns v0.3.0/go.mod h1:77+op56At17SXLuQrR46FWJybINlMcj2E/KrDhS5JiY= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +golang.org/x/exp v0.0.0-20230519143937-03e91628a987 h1:3xJIFvzUFbu4ls0BTBYcgbCGhA63eAOEMxIHugyXJqA= +golang.org/x/exp v0.0.0-20230519143937-03e91628a987/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= diff --git a/main.go b/main.go new file mode 100644 index 0000000..53e2dc6 --- /dev/null +++ b/main.go @@ -0,0 +1,181 @@ +package main + +import ( + "github.com/golang-jwt/jwt" + "net/http" + "crypto/ed25519" + "github.com/jba/muxpatterns" + "database/sql" + _ "github.com/lib/pq" + "encoding/json" + "fmt" + "strings" + "os" + "encoding/pem" + "crypto/x509" + "reflect" +) + +type ErrorJSON struct { + Error string `json:"error"` +} + +type DBHandler struct { + Audience string + Resource string + VerifyKey ed25519.PublicKey + DBConnection *sql.DB + DBOptions *sql.TxOptions + HandleFunc func(*sql.Tx, http.ResponseWriter, *http.Request) ([]byte, error) +} + +type Claims struct { + jwt.StandardClaims + Scope []string `json:"scope,omitempty"` +} + +func LoadKey(pem_string string) (ed25519.PublicKey, error) { + pem_block, rest := pem.Decode([]byte(pem_string)) + if pem_block == nil { + return nil, fmt.Errorf("Unable to decode pem from %s", pem_string) + } + + pubkey_if, err := x509.ParsePKIXPublicKey(pem_block.Bytes) + if err != nil { + return nil, err + } else if len(rest) > 0 { + return nil, fmt.Errorf("%d bytes remaining in key after decode", len(rest)) + } + + pubkey, right_type := pubkey_if.(ed25519.PublicKey) + if right_type == false { + return nil, fmt.Errorf("%+v is not ed25519.PublicKey", reflect.TypeOf(pubkey_if)) + } + + return pubkey, nil +} + + +func writeError(w http.ResponseWriter, err error) { + error_string := err.Error() + + data, _ := json.MarshalIndent(ErrorJSON{error_string}, " ", " ") + + w.WriteHeader(500) + w.Write(data) +} + +func writeErrorRollback(tx *sql.Tx, w http.ResponseWriter, err error) { + rollback_err := tx.Rollback() + if rollback_err != nil { + err = fmt.Errorf("Rollback Error while writing error response %e - %e", rollback_err, err) + } + writeError(w, err) +} + +func (handler DBHandler) VerifyToken(r *http.Request) (Claims, error) { + auth_header := r.Header.Get("Authorization") + parts := strings.Split(auth_header, " ") + if len(parts) != 2 { + return Claims{}, fmt.Errorf("Not enough values in Authorization header") + } else if parts[0] != "Bearer" { + return Claims{}, fmt.Errorf("Wrong Authorization type(%s)", parts[0]) + } + + var claims Claims + _, err := jwt.ParseWithClaims(parts[1], &claims, func(token *jwt.Token) (interface{}, error) { + return handler.VerifyKey, nil + }) + + if err != nil { + return Claims{}, err + } + + if claims.Audience != handler.Audience { + return claims, fmt.Errorf("Server %s not in audience %+v", handler.Audience, claims.Audience) + } + + resource_allowed := false + for _, r := range(claims.Scope) { + if handler.Resource == r { + resource_allowed = true + break + } + } + + if resource_allowed == false { + return claims, fmt.Errorf("Resource %s not in scope %+v", handler.Resource, claims.Scope) + } + + return claims, nil +} + +func (handler DBHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Add("content-type", "application/json") + _, err := handler.VerifyToken(r) + if err != nil { + writeError(w, err) + return + } + + tx, err := handler.DBConnection.BeginTx(r.Context(), handler.DBOptions) + if err != nil { + writeError(w, err) + return + } + + if handler.HandleFunc == nil { + writeError(w, fmt.Errorf("NOT_IMPLEMENTED")) + return + } + + resp, err := handler.HandleFunc(tx, w, r) + if err != nil { + writeErrorRollback(tx, w, err) + return + } + + w.WriteHeader(200) + w.Write(resp) +} + +func main() { + if len(os.Args) != 3 { + fmt.Printf("usage: %s {db_string} {public_key_pem}", os.Args[0]) + os.Exit(1) + } + + db_string := os.Args[1] + db, err := sql.Open("postgres", db_string) + if err != nil { + panic(err) + } + + verify_pubkey, err := LoadKey(os.Args[2]) + if err != nil { + panic(err) + } + + + mux := muxpatterns.NewServeMux() + mux.Handle("GET /tournaments", DBHandler{ + Audience: "score_server", + Resource: "tournaments:list", + VerifyKey: verify_pubkey, + DBConnection: db, + DBOptions: &sql.TxOptions{Isolation: sql.LevelSerializable, ReadOnly: false}, + HandleFunc: func(tx *sql.Tx, w http.ResponseWriter, r *http.Request) ([]byte, error) { + return []byte(`{"response": "success"}`), nil + }, + }) + + addr := ":8080" + + server := &http.Server{ + Addr: addr, + Handler: mux, + } + + fmt.Printf("Starting server on %s\n", addr) + server.ListenAndServe() +} diff --git a/scripts/test.py b/scripts/test.py new file mode 100755 index 0000000..d8c426d --- /dev/null +++ b/scripts/test.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 + +import jwt +from sys import argv +from datetime import datetime +from time import mktime +import requests + +def gen_jwt(aud, scope, key): + return jwt.encode({"aud": aud, "scope": scope, "issued": mktime(datetime.now().timetuple())}, pem_string, algorithm="EdDSA") + +def usage(): + print(f"Usage {argv[0]} {{pem_file}} {{command}}") + +if __name__ == "__main__": + if len(argv) < 3: + usage() + exit(1) + + pem_string = None + with open(argv[1]) as f: + pem_string = "".join(f.readlines()) + + if pem_string is None: + print(f"Couldn't read pem file {argv[1]}") + exit(2) + + token = gen_jwt("score_server", ["tournaments:list"], pem_string) + headers = {'Authorization': f"Bearer {token}"} + + response = None + match argv[2]: + case "list-tournaments": + response = requests.get("http://localhost:8080/tournaments", headers=headers) + case _: + usage() + exit(3) + + if response is not None: + print(response) + print(response.text) + else: + print("unhandled_error") +