package main import ( "github.com/golang-jwt/jwt" "net/http" "crypto/ed25519" "github.com/jba/muxpatterns" "database/sql" _ "github.com/lib/pq" "encoding/json" "fmt" "strings" "os" "encoding/pem" "crypto/x509" "reflect" ) type ErrorJSON struct { Error string `json:"error"` } type DBHandler struct { Audience string Resource string VerifyKey ed25519.PublicKey DBConnection *sql.DB DBOptions *sql.TxOptions HandleFunc func(*sql.Tx, http.ResponseWriter, *http.Request) ([]byte, error) } type Claims struct { jwt.StandardClaims Scope []string `json:"scope,omitempty"` } func LoadKey(pem_string string) (ed25519.PublicKey, error) { pem_block, rest := pem.Decode([]byte(pem_string)) if pem_block == nil { return nil, fmt.Errorf("Unable to decode pem from %s", pem_string) } pubkey_if, err := x509.ParsePKIXPublicKey(pem_block.Bytes) if err != nil { return nil, err } else if len(rest) > 0 { return nil, fmt.Errorf("%d bytes remaining in key after decode", len(rest)) } pubkey, right_type := pubkey_if.(ed25519.PublicKey) if right_type == false { return nil, fmt.Errorf("%+v is not ed25519.PublicKey", reflect.TypeOf(pubkey_if)) } return pubkey, nil } func writeError(w http.ResponseWriter, err error) { error_string := err.Error() data, _ := json.MarshalIndent(ErrorJSON{error_string}, " ", " ") w.WriteHeader(500) w.Write(data) } func writeErrorRollback(tx *sql.Tx, w http.ResponseWriter, err error) { rollback_err := tx.Rollback() if rollback_err != nil { err = fmt.Errorf("Rollback Error while writing error response %e - %e", rollback_err, err) } writeError(w, err) } func (handler DBHandler) VerifyToken(r *http.Request) (Claims, error) { auth_header := r.Header.Get("Authorization") parts := strings.Split(auth_header, " ") if len(parts) != 2 { return Claims{}, fmt.Errorf("Not enough values in Authorization header") } else if parts[0] != "Bearer" { return Claims{}, fmt.Errorf("Wrong Authorization type(%s)", parts[0]) } var claims Claims _, err := jwt.ParseWithClaims(parts[1], &claims, func(token *jwt.Token) (interface{}, error) { return handler.VerifyKey, nil }) if err != nil { return Claims{}, err } if claims.Audience != handler.Audience { return claims, fmt.Errorf("Server %s not in audience %+v", handler.Audience, claims.Audience) } resource_allowed := false for _, r := range(claims.Scope) { if handler.Resource == r { resource_allowed = true break } } if resource_allowed == false { return claims, fmt.Errorf("Resource %s not in scope %+v", handler.Resource, claims.Scope) } return claims, nil } func (handler DBHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Add("content-type", "application/json") _, err := handler.VerifyToken(r) if err != nil { writeError(w, err) return } tx, err := handler.DBConnection.BeginTx(r.Context(), handler.DBOptions) if err != nil { writeError(w, err) return } if handler.HandleFunc == nil { writeError(w, fmt.Errorf("NOT_IMPLEMENTED")) return } resp, err := handler.HandleFunc(tx, w, r) if err != nil { writeErrorRollback(tx, w, err) return } w.WriteHeader(200) w.Write(resp) } func main() { if len(os.Args) != 3 { fmt.Printf("usage: %s {db_string} {public_key_pem}", os.Args[0]) os.Exit(1) } db_string := os.Args[1] db, err := sql.Open("postgres", db_string) if err != nil { panic(err) } verify_pubkey, err := LoadKey(os.Args[2]) if err != nil { panic(err) } mux := muxpatterns.NewServeMux() mux.Handle("GET /tournaments", DBHandler{ Audience: "score_server", Resource: "tournaments:list", VerifyKey: verify_pubkey, DBConnection: db, DBOptions: &sql.TxOptions{Isolation: sql.LevelSerializable, ReadOnly: false}, HandleFunc: func(tx *sql.Tx, w http.ResponseWriter, r *http.Request) ([]byte, error) { return []byte(`{"response": "success"}`), nil }, }) addr := ":8080" server := &http.Server{ Addr: addr, Handler: mux, } fmt.Printf("Starting server on %s\n", addr) server.ListenAndServe() }