From d4c7870c0243e8cf70bea3ff565d23c7b22f2f3f Mon Sep 17 00:00:00 2001 From: hmoazzem Date: Thu, 8 Aug 2024 21:43:23 +0200 Subject: [PATCH] init Signed-off-by: hmoazzem --- .gitignore | 30 +++ CONTRIBUTING.md | 39 ++++ LICENSE | 201 +++++++++++++++++ README.md | 144 ++++++++++++ context.go | 100 +++++++++ doc.go | 35 +++ examples/custom-middleware/main.go | 36 +++ examples/handlers/health.go | 18 ++ examples/handlers/oidc.go | 44 ++++ examples/handlers/pgrole.go | 38 ++++ go.mod | 43 ++++ go.sum | 112 +++++++++ middleware/basic_auth.go | 70 ++++++ middleware/basic_auth_test.go | 95 ++++++++ middleware/cache.go | 61 +++++ middleware/cache_test.go | 154 +++++++++++++ middleware/cors.go | 58 +++++ middleware/cors_test.go | 90 ++++++++ middleware/logger.go | 117 ++++++++++ middleware/logger_test.go | 133 +++++++++++ middleware/middleware.go | 24 ++ middleware/oidc.go | 107 +++++++++ middleware/oidc_test.go | 132 +++++++++++ middleware/pg_authz.go | 68 ++++++ middleware/postgres.go | 131 +++++++++++ middleware/proxy.go | 70 ++++++ middleware/request_id.go | 30 +++ middleware/request_id_test.go | 87 +++++++ pkg/util/cert.go | 109 +++++++++ pkg/util/cert_test.go | 131 +++++++++++ pkg/util/jq.go | 45 ++++ pkg/util/rand/name.go | 38 ++++ pkg/x/experiments.go | 3 + pkg/x/logrepl/cdc.go | 48 ++++ pkg/x/logrepl/main.go | 150 +++++++++++++ pkg/x/logrepl/pglogrepl.go | 106 +++++++++ pkg/x/logrepl/postgres_peer.go | 65 ++++++ pkg/x/logrepl/process_v1.go | 157 +++++++++++++ pkg/x/logrepl/process_v2.go | 165 ++++++++++++++ pkg/x/logrepl/util.go | 17 ++ pkg/x/pgcache/pgcache.go | 102 +++++++++ pkg/x/pgproxy/main.go | 51 +++++ pkg/x/pgproxy/server.go | 113 ++++++++++ postgres.go | 350 +++++++++++++++++++++++++++++ router.go | 172 ++++++++++++++ router_test.go | 264 ++++++++++++++++++++++ 46 files changed, 4353 insertions(+) create mode 100644 .gitignore create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE create mode 100644 README.md create mode 100644 context.go create mode 100644 doc.go create mode 100644 examples/custom-middleware/main.go create mode 100644 examples/handlers/health.go create mode 100644 examples/handlers/oidc.go create mode 100644 examples/handlers/pgrole.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 middleware/basic_auth.go create mode 100644 middleware/basic_auth_test.go create mode 100644 middleware/cache.go create mode 100644 middleware/cache_test.go create mode 100644 middleware/cors.go create mode 100644 middleware/cors_test.go create mode 100644 middleware/logger.go create mode 100644 middleware/logger_test.go create mode 100644 middleware/middleware.go create mode 100644 middleware/oidc.go create mode 100644 middleware/oidc_test.go create mode 100644 middleware/pg_authz.go create mode 100644 middleware/postgres.go create mode 100644 middleware/proxy.go create mode 100644 middleware/request_id.go create mode 100644 middleware/request_id_test.go create mode 100644 pkg/util/cert.go create mode 100644 pkg/util/cert_test.go create mode 100644 pkg/util/jq.go create mode 100644 pkg/util/rand/name.go create mode 100644 pkg/x/experiments.go create mode 100644 pkg/x/logrepl/cdc.go create mode 100644 pkg/x/logrepl/main.go create mode 100644 pkg/x/logrepl/pglogrepl.go create mode 100644 pkg/x/logrepl/postgres_peer.go create mode 100644 pkg/x/logrepl/process_v1.go create mode 100644 pkg/x/logrepl/process_v2.go create mode 100644 pkg/x/logrepl/util.go create mode 100644 pkg/x/pgcache/pgcache.go create mode 100644 pkg/x/pgproxy/main.go create mode 100644 pkg/x/pgproxy/server.go create mode 100644 postgres.go create mode 100644 router.go create mode 100644 router_test.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f4d5bbc --- /dev/null +++ b/.gitignore @@ -0,0 +1,30 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib +bin/* +Dockerfile.cross + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Go workspace file +go.work + +# Kubernetes Generated files - skip generated files, except for vendored files +!vendor/**/zz_generated.* + +# editor and IDE paraphernalia +.idea +.vscode +*.swp +*.swo +*~ + +# local stuff +__* diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..3dbc79b --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,39 @@ +# Contributing Guide + +Hello and welcome! We’re excited you want to contribute to pgo. Here’s how you can help: + +### Reporting Issues +Found a bug or have a suggestion? Open an issue and provide as much detail as possible. + +### Code Contributions +1. **Fork the Repo:** Fork the repository to your GitHub account: + ```bash + git clone git@github.com:/pgo.git + ``` +2. **Clone Your Fork:** Clone it to your local machine: + ```bash + git clone git@github.com:/pgo.git + ``` +3. **Create a Branch:** + ```bash + git checkout -b feat/feature-name # Use feat, bug, docs, etc. + ``` +4. **Make Changes:** Make your changes in the new branch. +5. **Commit Changes:** + ```bash + git commit -m "Description of changes" + ``` +6. **Push to Your Repo:** + ```bash + git push origin feat/feature-name + ``` +7. **Open a PR:** Create a pull request from your repository to the `dev` branch of `git@github.com:edgeflare/pgo.git`. Provide a clear description. + +### Code Style +Follow our coding standards. If unsure, feel free to ask. + +### Documentation +Update documentation for any new features. + + +Thank you for your contributions! \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f49a4e1 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..3589627 --- /dev/null +++ b/README.md @@ -0,0 +1,144 @@ +# router and Postgres util for net/http.Handler + +## It can be useful if you: +- expose Postgres via REST API using primarily PostgREST, leverage its [authorization pattern](https://docs.postgrest.org/en/latest/explanations/db_authz.html), and want custom logic for an endpoint or two, beyond just CRUD +- use [custom OIDC token claims](https://zitadel.com/docs/apis/openidoauth/claims#custom-claims) to authorize Postgres queries (possibly with) [Postgres Row Level Security (RLS)](https://www.postgresql.org/docs/13/ddl-rowsecurity.html). pgo passes [net/http](https://pkg.go.dev/net/http) request context *eg* headers to underlying [pgxpool.Conn](https://pkg.go.dev/github.com/jackc/pgx/v5/pgxpool#Conn.Conn) +- wanna experiment with a [router.go](./router.go) written in ~50 lines of code dependent only on standard library. It's a wrapper around [http.ServeMux](https://pkg.go.dev/net/http#ServeMux) with helpers for route groups, middleware and [SQL](github.com/edgeflare/pgxutil) + +## Realworld-ish examples +- [guardian](https://github.com/edgeflare/guardian): manages [WireGuard](https://www.wireguard.com/) networks and peeers +- [fabric-oidc-proxy](https://github.com/edgeflare/fabric-oidc-proxy): allows authenticating to Hyperledger Fabric blockchain using OIDC token. It requests x509 certificate for each user from Fabric CA, and signs transactions using respective user's certificate. + +```go +// minimal error handling for brevity +package main + +import ( + "context" + "flag" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/edgeflare/guardian/wg" + "github.com/edgeflare/pgo" + mw "github.com/edgeflare/pgo/middleware" + + "github.com/jackc/pgx/v5" +) + +func main() { + port := flag.Int("port", 8080, "port to run the server on") + flag.Parse() + + r := pgo.NewRouter() + + // (optional) middleware with default options + r.Use(mw.RequestID) + r.Use(mw.LoggerWithOptions(nil)) + r.Use(mw.CORSWithOptions(nil)) + + // route group for API v1 + apiv1 := r.Group("/api/v1") + + // OIDC middleware for authentication + oidcConfig := mw.OIDCProviderConfig{ + ClientID: os.Getenv("PGO_OIDC_CLIENT_ID"), + ClientSecret: os.Getenv("PGO_OIDC_CLIENT_SECRET"), + Issuer: os.Getenv("PGO_OIDC_ISSUER"), + } + apiv1.Use(mw.VerifyOIDCToken(oidcConfig)) + + // Postgres configuration for authorization + pgConfig := mw.PgConfig{ + ConnString: os.Getenv("PGO_POSTGRES_CONN_STRING"), + } + pgmw := mw.Postgres(pgConfig, + mw.PgOIDCAuthz(oidcConfig, os.Getenv("PGO_POSTGRES_OIDC_ROLE_CLAIM_KEY")), + ) + apiv1.Use(pgmw) + + // Respond with all networks where user_id == authenticated user.Subject + apiv1.Handle("GET /networks", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + query := `SELECT id, name, addr, addr6, dns, user_id, info, domains, created_at, + updated_at, uuid FROM networks WHERE user_id = $1` + + user, _ := pgo.OIDCUser(r) + pgo.SelectAndRespondJSON[wg.Network](w, r, query, []any{user.Subject}, pgx.RowToStructByPos[wg.Network]) + })) + + apiv1.Handle("POST /networks", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var reqNet wg.Network + if err := pgo.BindOrRespondError(r, w, &reqNet); err != nil { + return + } + + user, _ := pgo.OIDCUser(r) + reqNet.UserID = user.Subject + network := wg.NewDefaultNetwork(reqNet) + networkMap := pgo.RowMap(network) + + if _, pgErr := pgo.InsertRow(r, "networks", networkMap); pgErr != nil { + pgo.RespondError(w, pgo.PgErrorCodeToHTTPStatus(pgErr.Error()), pgErr.Error()) + return + } + pgo.RespondJSON(w, http.StatusCreated, network) + })) + + // Respond with all peers created by the authenticated user and network_id == network path parameter + apiv1.Handle("GET /networks/{network}/peers", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + query := `SELECT id, name, addr, cidr, dns, network_id, allowed_ips, endpoint, + enabled, type, pubkey, privkey, presharedkey, info, user_id, listen_port, mtu, created_at, updated_at, + uuid FROM peers WHERE network_id = $1 AND user_id = $2` + + user, _ := pgo.OIDCUser(r) + pgo.SelectAndRespondJSON[wg.Peer](w, r, query, []any{r.PathValue("network"), user.Subject}, pgx.RowToStructByPos[wg.Peer]) + })) + + apiv1.Handle("POST /networks/{network}/peers", http.HandlerFunc(postPeerHandler)) + apiv1.Handle("GET /networks/{network}/peers/{peer}", http.HandlerFunc(getPeerConfigHandler)) + apiv1.Handle("GET /networks/{network}/peers/{peer}/qr", http.HandlerFunc(getPeerConfigQrHandler)) + apiv1.Handle("DELETE /networks/{network}", http.HandlerFunc(deleteNetworkHandler)) + apiv1.Handle("DELETE /networks/{network}/peers/{peer}", http.HandlerFunc(deletePeerHandler)) + + // Run server in a goroutine + go func() { + if err := r.ListenAndServe(fmt.Sprintf(":%d", *port)); err != nil && err != http.ErrServerClosed { + log.Fatalf("Server error: %v", err) + } + }() + + fmt.Printf("Server is running on port %d\n", port) + + // Set up signal handling + stop := make(chan os.Signal, 1) + signal.Notify(stop, os.Interrupt, syscall.SIGTERM) + + // Wait for SIGINT or SIGTERM + <-stop + + fmt.Println("Shutting down server...") + + // Create a deadline for the shutdown + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Attempt graceful shutdown + if err := r.Shutdown(ctx); err != nil { + fmt.Printf("server forced to shutdown: %s", err) + } + fmt.Println("Server gracefully stopped") +} +``` + +## Contributing +Please see [CONTRIBUTING.md](CONTRIBUTING.md). + +## License +Apache License 2.0 + + diff --git a/context.go b/context.go new file mode 100644 index 0000000..b73bbc6 --- /dev/null +++ b/context.go @@ -0,0 +1,100 @@ +package pgo + +import ( + "encoding/json" + "net/http" + + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +type ContextKey string + +const ( + RequestIDCtxKey ContextKey = "RequestID" + LogEntryCtxKey ContextKey = "LogEntry" + OIDCUserCtxKey ContextKey = "OIDCUser" + BasicAuthCtxKey ContextKey = "BasicAuth" + PgConnCtxKey ContextKey = "PgConn" + PgRoleCtxKey ContextKey = "PgRole" +) + +// OIDCUser extracts the OIDC user from the request context. +func OIDCUser(r *http.Request) (*oidc.IntrospectionResponse, bool) { + user, ok := r.Context().Value(OIDCUserCtxKey).(*oidc.IntrospectionResponse) + if !ok || user == nil { + return nil, false + } + return user, true +} + +// BasicAuthUser retrieves the authenticated username from the context. +func BasicAuthUser(r *http.Request) (string, bool) { + user, ok := r.Context().Value(BasicAuthCtxKey).(string) + return user, ok +} + +// BindOrRespondError decodes the JSON body of an HTTP request, r, into the given destination object, dst. +// If decoding fails, it responds with a 400 Bad Request error. +func BindOrRespondError(r *http.Request, w http.ResponseWriter, dst interface{}) error { + if err := json.NewDecoder(r.Body).Decode(dst); err != nil { + RespondError(w, http.StatusBadRequest, err.Error()) + return err + } + return nil +} + +// RespondJSON writes a JSON response with the given status code and data. +func RespondJSON(w http.ResponseWriter, statusCode int, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + if err := json.NewEncoder(w).Encode(data); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + } +} + +// RespondText writes a plain text response with the given status code and text content. +func RespondText(w http.ResponseWriter, statusCode int, text string) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(statusCode) + if _, err := w.Write([]byte(text)); err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + } +} + +// RespondHTML writes an HTML response with the given status code and HTML content. +func RespondHTML(w http.ResponseWriter, statusCode int, html string) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(statusCode) + if _, err := w.Write([]byte(html)); err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + } +} + +// RespondBinary writes a binary response with the given status code and data. +func RespondBinary(w http.ResponseWriter, statusCode int, data []byte, contentType string) { + w.Header().Set("Content-Type", contentType) + w.WriteHeader(statusCode) + if _, err := w.Write(data); err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + } +} + +// ErrorResponse represents a structured error response. +type ErrorResponse struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// RespondError sends a JSON response with an error code and message. +func RespondError(w http.ResponseWriter, statusCode int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + errorResponse := ErrorResponse{ + Code: statusCode, + Message: message, + } + if err := json.NewEncoder(w).Encode(errorResponse); err != nil { + // Fallback if JSON encoding fails + http.Error(w, "Failed to encode error response", http.StatusInternalServerError) + } +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..f493884 --- /dev/null +++ b/doc.go @@ -0,0 +1,35 @@ +// Package pgo simplifies the process of querying PostgreSQL database upon an HTTP request. It uses [jackc/pgx](https://pkg.go.dev/github.com/jackc/pgx/v5) as the PostgreSQL driver. +// +// Key Features: +// +// - HTTP Routing: A lightweight router with support for middleware, grouping, and customizable error handling. +// +// - PostgreSQL Integration: Streamlined interactions with PostgreSQL databases using pgx, a powerful PostgreSQL driver for Go. +// +// - Convenience Functions: A collection of helper functions for common database operations (select, insert, update, etc.), +// designed to work seamlessly within HTTP request contexts. +// +// - Error Handling: Robust error handling mechanisms for both HTTP routing and database operations, ensuring predictable +// and informative error responses. +// +// Example Usage: +// +// router := pgo.NewRouter(pgo.WithTLS("cert.pem", "key.pem")) +// router.Use(middleware.Logger) +// +// // Define routes +// router.Handle("GET /users", usersHandler) +// router.Handle("POST /users", createUserHandler) +// +// api := router.Group("/api") +// api.Handle("GET /products", productsHandler) +// +// // Start the server +// router.ListenAndServe(":8080") +// +// Additional Information: +// +// - For detailed information on HTTP routing, see the documentation for the `Router` type and its associated functions. +// +// - For PostgreSQL-specific helpers and utilities, refer to the documentation for the `pgxutil` package. +package pgo diff --git a/examples/custom-middleware/main.go b/examples/custom-middleware/main.go new file mode 100644 index 0000000..9879de8 --- /dev/null +++ b/examples/custom-middleware/main.go @@ -0,0 +1,36 @@ +package main + +import ( + "fmt" + "log" + "net/http" + + "github.com/edgeflare/pgo" +) + +func main() { + r := pgo.NewRouter() + + // custom logger middleware + r.Use(customLogger) + + // Group with prefix "/api/v1" + v1 := r.Group("/api/v1") + + // Handle routes in the "/api" group + v1.Handle("GET /users/{user}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(fmt.Sprintf("User endpoint: %s", r.PathValue("user")))) + })) + v1.Handle("POST /products/{product}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(fmt.Sprintf("Product endpoint: %s", r.PathValue("product")))) + })) + + log.Fatal(r.ListenAndServe(":8080")) +} + +func customLogger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Println(r.Method, r.URL.Path) + next.ServeHTTP(w, r) + }) +} diff --git a/examples/handlers/health.go b/examples/handlers/health.go new file mode 100644 index 0000000..2b8fe2d --- /dev/null +++ b/examples/handlers/health.go @@ -0,0 +1,18 @@ +package handlers + +import ( + "net/http" + + "github.com/edgeflare/pgo" +) + +// HealthHandler returns the request ID as a plain text response with a 200 OK status code. +func HealthHandler(w http.ResponseWriter, r *http.Request) { + requestID := r.Context().Value(pgo.RequestIDCtxKey).(string) + if requestID == "" { + http.Error(w, "Request ID not found", http.StatusInternalServerError) + return + } + + pgo.RespondText(w, http.StatusOK, requestID) +} diff --git a/examples/handlers/oidc.go b/examples/handlers/oidc.go new file mode 100644 index 0000000..83d3de2 --- /dev/null +++ b/examples/handlers/oidc.go @@ -0,0 +1,44 @@ +package handlers + +import ( + "net/http" + + "github.com/edgeflare/pgo" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +type AuthzResponse struct { + Allowed bool `json:"allowed"` +} + +// GetMyAccountHandler retrieves the user information from the OIDC context and responds with the user details. +// +// If the user is not found in the context or an error occurs, an HTTP 401 Unauthorized error is returned. +func GetMyAccountHandler(w http.ResponseWriter, r *http.Request) { + var user *oidc.IntrospectionResponse + if user, ok := pgo.OIDCUser(r); !ok || user == nil { + http.Error(w, "User not found in context", http.StatusUnauthorized) + return + } + pgo.RespondJSON(w, http.StatusOK, user) +} + +// GetSimpleAuthzHandler performs an authorization check based on a requested /endpoint/{claim}/{value} path +// eg `GET /endpoint/editor/true` will check if the user has the claim `editor` with the value `true` +// The response is an `AuthzResponse` object indicating whether the user is authorized (Allowed: true) or not (Allowed: false) +func GetSimpleAuthzHandler(w http.ResponseWriter, r *http.Request) { + user, ok := pgo.OIDCUser(r) + if !ok { + pgo.RespondError(w, http.StatusUnauthorized, "Unauthorized") + return + } + + requestedClaim := r.PathValue("claim") + requestedValue := r.PathValue("value") + if value, ok := user.Claims[requestedClaim].(string); !ok || value != requestedValue { + pgo.RespondJSON(w, http.StatusOK, AuthzResponse{Allowed: false}) + return + } + + pgo.RespondJSON(w, http.StatusOK, AuthzResponse{Allowed: true}) +} diff --git a/examples/handlers/pgrole.go b/examples/handlers/pgrole.go new file mode 100644 index 0000000..c0c5188 --- /dev/null +++ b/examples/handlers/pgrole.go @@ -0,0 +1,38 @@ +package handlers + +import ( + "net/http" + + "github.com/edgeflare/pgo" + "github.com/jackc/pgx/v5/pgxpool" +) + +func GetMyPgRoleHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // retrive the role from the request context, set by the middleware + ctxRole := r.Context().Value(pgo.PgRoleCtxKey) + + // retrieve the connection from request context + conn, ok := r.Context().Value(pgo.PgConnCtxKey).(*pgxpool.Conn) + if !ok || conn == nil { + http.Error(w, "Failed to get connection from context", http.StatusInternalServerError) + return + } + defer conn.Release() + + // query the current role using the connection + var queryRole string + err := conn.Conn().QueryRow(r.Context(), "SELECT current_role").Scan(&queryRole) + if err != nil { + http.Error(w, "Failed to query current role", http.StatusInternalServerError) + return + } + + // construct both roles as a json response + roles := map[string]string{ + "ctx_role": ctxRole.(string), + "query_role": queryRole, + } + pgo.RespondJSON(w, http.StatusOK, roles) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..9f1f441 --- /dev/null +++ b/go.mod @@ -0,0 +1,43 @@ +module github.com/edgeflare/pgo + +go 1.23rc2 + +require ( + github.com/edgeflare/pgxutil v0.0.0-20240802003737-b6dfe049d40f + github.com/google/uuid v1.6.0 + github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9 + github.com/jackc/pgx/v5 v5.6.0 + github.com/pganalyze/pg_query_go/v5 v5.1.0 + github.com/stretchr/testify v1.9.0 + github.com/zitadel/oidc/v3 v3.26.0 + go.uber.org/zap v1.27.0 +) + +require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/go-jose/go-jose/v4 v4.0.2 // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/gorilla/securecookie v1.1.2 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/muhlemmer/gu v0.3.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/zitadel/logging v0.6.0 // indirect + github.com/zitadel/schema v1.3.0 // indirect + go.opentelemetry.io/otel v1.28.0 // indirect + go.opentelemetry.io/otel/metric v1.28.0 // indirect + go.opentelemetry.io/otel/trace v1.28.0 // indirect + go.uber.org/multierr v1.10.0 // indirect + golang.org/x/crypto v0.24.0 // indirect + golang.org/x/oauth2 v0.21.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/sys v0.21.0 // indirect + golang.org/x/text v0.16.0 // indirect + google.golang.org/protobuf v1.33.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..f38b54c --- /dev/null +++ b/go.sum @@ -0,0 +1,112 @@ +github.com/bmatcuk/doublestar/v4 v4.6.1 h1:FH9SifrbvJhnlQpztAx++wlkk70QBf0iBWDwNy7PA4I= +github.com/bmatcuk/doublestar/v4 v4.6.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/edgeflare/pgxutil v0.0.0-20240802003737-b6dfe049d40f h1:SH7kI8XUwnNe4oJrX+599esPfsZlhK9ZBjdRZziYCwQ= +github.com/edgeflare/pgxutil v0.0.0-20240802003737-b6dfe049d40f/go.mod h1:WThYGm9KD+GvZnWMqGxOHiCE+3lPlUweJmnMYFwKN7E= +github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= +github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk= +github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9 h1:86CQbMauoZdLS0HDLcEHYo6rErjiCBjVvcxGsioIn7s= +github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9/go.mod h1:SO15KF4QqfUM5UhsG9roXre5qeAQLC1rm8a8Gjpgg5k= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jeremija/gosubmit v0.2.7 h1:At0OhGCFGPXyjPYAsCchoBUhE099pcBXmsb4iZqROIc= +github.com/jeremija/gosubmit v0.2.7/go.mod h1:Ui+HS073lCFREXBbdfrJzMB57OI/bdxTiLtrDHHhFPI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/muhlemmer/gu v0.3.1 h1:7EAqmFrW7n3hETvuAdmFmn4hS8W+z3LgKtrnow+YzNM= +github.com/muhlemmer/gu v0.3.1/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM= +github.com/muhlemmer/httpforwarded v0.1.0 h1:x4DLrzXdliq8mprgUMR0olDvHGkou5BJsK/vWUetyzY= +github.com/muhlemmer/httpforwarded v0.1.0/go.mod h1:yo9czKedo2pdZhoXe+yDkGVbU0TJ0q9oQ90BVoDEtw0= +github.com/pganalyze/pg_query_go/v5 v5.1.0 h1:MlxQqHZnvA3cbRQYyIrjxEjzo560P6MyTgtlaf3pmXg= +github.com/pganalyze/pg_query_go/v5 v5.1.0/go.mod h1:FsglvxidZsVN+Ltw3Ai6nTgPVcK2BPukH3jCDEqc1Ug= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rs/cors v1.11.0 h1:0B9GE/r9Bc2UxRMMtymBkHTenPkHDv0CW4Y98GBY+po= +github.com/rs/cors v1.11.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/zitadel/logging v0.6.0 h1:t5Nnt//r+m2ZhhoTmoPX+c96pbMarqJvW1Vq6xFTank= +github.com/zitadel/logging v0.6.0/go.mod h1:Y4CyAXHpl3Mig6JOszcV5Rqqsojj+3n7y2F591Mp/ow= +github.com/zitadel/oidc/v3 v3.26.0 h1:BG3OUK+JpuKz7YHJIyUxL5Sl2JV6ePkG42UP4Xv3J2w= +github.com/zitadel/oidc/v3 v3.26.0/go.mod h1:Cx6AYPTJO5q2mjqF3jaknbKOUjpq1Xui0SYvVhkKuXU= +github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0= +github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc= +go.opentelemetry.io/otel v1.28.0 h1:/SqNcYk+idO0CxKEUOtKQClMK/MimZihKYMruSMViUo= +go.opentelemetry.io/otel v1.28.0/go.mod h1:q68ijF8Fc8CnMHKyzqL6akLO46ePnjkgfIMIjUIX9z4= +go.opentelemetry.io/otel/metric v1.28.0 h1:f0HGvSl1KRAU1DLgLGFjrwVyismPlnuU6JD6bOeuA5Q= +go.opentelemetry.io/otel/metric v1.28.0/go.mod h1:Fb1eVBFZmLVTMb6PPohq3TO9IIhUisDsbJoL/+uQW4s= +go.opentelemetry.io/otel/trace v1.28.0 h1:GhQ9cUuQGmNDd5BTCP2dAvv75RdMxEfTmYejp+lkx9g= +go.opentelemetry.io/otel/trace v1.28.0/go.mod h1:jPyXzNPg6da9+38HEwElrQiHlVMTnVfM3/yv2OlIHaI= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= +golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go new file mode 100644 index 0000000..861c013 --- /dev/null +++ b/middleware/basic_auth.go @@ -0,0 +1,70 @@ +package middleware + +import ( + "context" + "encoding/base64" + "net/http" + "strings" + + "github.com/edgeflare/pgo" +) + +// BasicAuthConfig holds the username-password pairs for basic authentication. +type BasicAuthConfig struct { + Credentials map[string]string +} + +// NewBasicAuthCreds creates a new instance of BasicAuthConfig with multiple username/password pairs. +func BasicAuthCreds(credentials map[string]string) *BasicAuthConfig { + return &BasicAuthConfig{ + Credentials: credentials, + } +} + +// VerifyBasicAuth is a middleware function for basic authentication. +func VerifyBasicAuth(config *BasicAuthConfig) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, "Authorization header missing", http.StatusUnauthorized) + return + } + + // Check if the authorization header is in the correct format + if !strings.HasPrefix(authHeader, "Basic ") { + http.Error(w, "Invalid authorization format", http.StatusUnauthorized) + return + } + + // Decode the base64 encoded credentials + encodedCredentials := strings.TrimPrefix(authHeader, "Basic ") + credentials, err := base64.StdEncoding.DecodeString(encodedCredentials) + if err != nil { + http.Error(w, "Invalid base64 encoding", http.StatusUnauthorized) + return + } + + // Split the credentials into username and password + creds := strings.SplitN(string(credentials), ":", 2) + if len(creds) != 2 { + http.Error(w, "Invalid credentials format", http.StatusUnauthorized) + return + } + + username, password := creds[0], creds[1] + + // Verify the credentials + if validPassword, ok := config.Credentials[username]; !ok || validPassword != password { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, "Invalid credentials", http.StatusUnauthorized) + return + } + + // Store authenticated user in context + ctx := context.WithValue(r.Context(), pgo.BasicAuthCtxKey, username) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go new file mode 100644 index 0000000..b642bf7 --- /dev/null +++ b/middleware/basic_auth_test.go @@ -0,0 +1,95 @@ +package middleware + +import ( + "encoding/base64" + "net/http" + "net/http/httptest" + "testing" + + "github.com/edgeflare/pgo" +) + +func TestVerifyBasicAuth(t *testing.T) { + tests := []struct { + name string + config *BasicAuthConfig + authHeader string + expectedStatus int + expectedBody string + expectedUser string + }{ + { + name: "missing authorization header", + config: BasicAuthCreds(map[string]string{"user": "pass"}), + authHeader: "", + expectedStatus: http.StatusUnauthorized, + expectedBody: "Authorization header missing\n", + }, + { + name: "invalid authorization format", + config: BasicAuthCreds(map[string]string{"user": "pass"}), + authHeader: "Bearer some-token", + expectedStatus: http.StatusUnauthorized, + expectedBody: "Invalid authorization format\n", + }, + { + name: "invalid base64 encoding", + config: BasicAuthCreds(map[string]string{"user": "pass"}), + authHeader: "Basic invalid-base64", + expectedStatus: http.StatusUnauthorized, + expectedBody: "Invalid base64 encoding\n", + }, + { + name: "invalid credentials format", + config: BasicAuthCreds(map[string]string{"user": "pass"}), + authHeader: "Basic " + base64.StdEncoding.EncodeToString([]byte("userpass")), + expectedStatus: http.StatusUnauthorized, + expectedBody: "Invalid credentials format\n", + }, + { + name: "invalid credentials", + config: BasicAuthCreds(map[string]string{"user": "pass"}), + authHeader: "Basic " + base64.StdEncoding.EncodeToString([]byte("user:wrongpass")), + expectedStatus: http.StatusUnauthorized, + expectedBody: "Invalid credentials\n", + }, + { + name: "valid credentials", + config: BasicAuthCreds(map[string]string{"user": "pass"}), + authHeader: "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass")), + expectedStatus: http.StatusOK, + expectedBody: "OK", + expectedUser: "user", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) + } + rr := httptest.NewRecorder() + + handler := VerifyBasicAuth(tt.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, ok := r.Context().Value(pgo.BasicAuthCtxKey).(string) + if !ok || user != tt.expectedUser { + http.Error(w, "User not found in context", http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != tt.expectedStatus { + t.Errorf("status code: expected %v, got %v", tt.expectedStatus, status) + } + + if body := rr.Body.String(); body != tt.expectedBody { + t.Errorf("body: expected %v, got %v", tt.expectedBody, body) + } + }) + } +} diff --git a/middleware/cache.go b/middleware/cache.go new file mode 100644 index 0000000..19cf5b0 --- /dev/null +++ b/middleware/cache.go @@ -0,0 +1,61 @@ +package middleware + +import ( + "sync" + "time" +) + +// Cache is a simple in-memory cache with expiration +type Cache struct { + sync.RWMutex + items map[string]cacheItem +} + +// cacheItem holds cached data along with its expiration +type cacheItem struct { + value interface{} + expiration time.Time +} + +// NewCache creates a new Cache +func NewCache() *Cache { + return &Cache{ + items: make(map[string]cacheItem), + } +} + +// Set adds an item to the cache with a specified expiration duration +func (c *Cache) Set(key string, value interface{}, duration time.Duration) { + c.Lock() + defer c.Unlock() + c.items[key] = cacheItem{ + value: value, + expiration: time.Now().Add(duration), + } +} + +// Get retrieves an item from the cache +func (c *Cache) Get(key string) (interface{}, bool) { + c.RLock() + defer c.RUnlock() + item, found := c.items[key] + if !found { + return nil, false + } + if time.Now().After(item.expiration) { + delete(c.items, key) + return nil, false + } + return item.value, true +} + +// CleanupExpired removes expired items from the cache +func (c *Cache) CleanupExpired() { + c.Lock() + defer c.Unlock() + for key, item := range c.items { + if time.Now().After(item.expiration) { + delete(c.items, key) + } + } +} diff --git a/middleware/cache_test.go b/middleware/cache_test.go new file mode 100644 index 0000000..8c33be8 --- /dev/null +++ b/middleware/cache_test.go @@ -0,0 +1,154 @@ +package middleware + +import ( + "fmt" + "sync" + "testing" + "time" +) + +func TestCache_SetAndGet(t *testing.T) { + cache := NewCache() + + // Test setting and getting a value + cache.Set("key1", "value1", 1*time.Minute) + value, found := cache.Get("key1") + if !found { + t.Error("Expected to find key1") + } + if value != "value1" { + t.Errorf("Expected value1, got %v", value) + } + + // Test getting a non-existent key + _, found = cache.Get("nonexistent") + if found { + t.Error("Expected not to find nonexistent key") + } +} + +func TestCache_Expiration(t *testing.T) { + cache := NewCache() + + // Set a key with a short expiration + cache.Set("short", "value", 10*time.Millisecond) + + // Wait for expiration + time.Sleep(20 * time.Millisecond) + + // Try to get the expired key + _, found := cache.Get("short") + if found { + t.Error("Expected short to be expired") + } +} + +func TestCache_CleanupExpired(t *testing.T) { + cache := NewCache() + + // Set some keys with different expiration times + cache.Set("expired1", "value1", 10*time.Millisecond) + cache.Set("expired2", "value2", 10*time.Millisecond) + cache.Set("valid", "value3", 1*time.Minute) + + // Wait for some keys to expire + time.Sleep(20 * time.Millisecond) + + // Run cleanup + cache.CleanupExpired() + + // Check if expired keys were removed and valid key remains + _, found1 := cache.Get("expired1") + _, found2 := cache.Get("expired2") + _, found3 := cache.Get("valid") + + if found1 || found2 { + t.Error("Expected expired keys to be removed") + } + if !found3 { + t.Error("Expected valid key to remain") + } +} + +func TestCache_Concurrency(t *testing.T) { + cache := NewCache() + var wg sync.WaitGroup + concurrency := 100 + + // Concurrent writes + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + key := fmt.Sprintf("key%d", i) + cache.Set(key, i, 1*time.Minute) + }(i) + } + + // Concurrent reads + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + key := fmt.Sprintf("key%d", i) + _, _ = cache.Get(key) + }(i) + } + + wg.Wait() + + // Verify all keys are present + for i := 0; i < concurrency; i++ { + key := fmt.Sprintf("key%d", i) + _, found := cache.Get(key) + if !found { + t.Errorf("Expected to find key: %s", key) + } + } +} + +func BenchmarkCache_Set(b *testing.B) { + cache := NewCache() + for i := 0; i < b.N; i++ { + cache.Set(fmt.Sprintf("key%d", i), i, 1*time.Minute) + } +} + +func BenchmarkCache_Get(b *testing.B) { + cache := NewCache() + for i := 0; i < 1000; i++ { + cache.Set(fmt.Sprintf("key%d", i), i, 1*time.Minute) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + cache.Get(fmt.Sprintf("key%d", i%1000)) + } +} + +func BenchmarkCache_SetParallel(b *testing.B) { + cache := NewCache() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + cache.Set(fmt.Sprintf("key%d", i), i, 1*time.Minute) + i++ + } + }) +} + +func BenchmarkCache_GetParallel(b *testing.B) { + cache := NewCache() + for i := 0; i < 1000; i++ { + cache.Set(fmt.Sprintf("key%d", i), i, 1*time.Minute) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + cache.Get(fmt.Sprintf("key%d", i%1000)) + i++ + } + }) +} diff --git a/middleware/cors.go b/middleware/cors.go new file mode 100644 index 0000000..210864d --- /dev/null +++ b/middleware/cors.go @@ -0,0 +1,58 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// CORSOptions defines configuration for CORS. +type CORSOptions struct { + AllowedOrigins []string + AllowedMethods []string + AllowedHeaders []string + AllowCredentials bool +} + +// defaultCORSOptions returns the default CORS options. +func defaultCORSOptions() *CORSOptions { + return &CORSOptions{ + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + AllowedHeaders: []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "accept", "origin", "Cache-Control", "X-Requested-With"}, + AllowCredentials: true, + } +} + +// CORSWithOptions creates a CORS middleware with the provided configuration. +// If options is nil, it will use the default CORS settings. +// If options is an empty struct (CORSOptions{}), it will create a middleware with no CORS headers. +func CORSWithOptions(options *CORSOptions) func(http.Handler) http.Handler { + if options == nil { + options = defaultCORSOptions() + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(options.AllowedOrigins) > 0 { + w.Header().Set("Access-Control-Allow-Origin", strings.Join(options.AllowedOrigins, ",")) + } + if len(options.AllowedMethods) > 0 { + w.Header().Set("Access-Control-Allow-Methods", strings.Join(options.AllowedMethods, ",")) + } + if len(options.AllowedHeaders) > 0 { + w.Header().Set("Access-Control-Allow-Headers", strings.Join(options.AllowedHeaders, ",")) + } + if options.AllowCredentials { + w.Header().Set("Access-Control-Allow-Credentials", "true") + } + + // Handle preflight request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/middleware/cors_test.go b/middleware/cors_test.go new file mode 100644 index 0000000..30942bc --- /dev/null +++ b/middleware/cors_test.go @@ -0,0 +1,90 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestCORSWithOptions(t *testing.T) { + tests := []struct { + name string + options *CORSOptions + method string + expectedHeaders map[string]string + expectedStatus int + }{ + { + name: "default options", + method: http.MethodGet, + options: defaultCORSOptions(), + expectedHeaders: map[string]string{ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET,POST,PUT,DELETE,OPTIONS", + "Access-Control-Allow-Headers": "Content-Type,Content-Length,Accept-Encoding,X-CSRF-Token,Authorization,accept,origin,Cache-Control,X-Requested-With", + "Access-Control-Allow-Credentials": "true", + }, + expectedStatus: http.StatusOK, + }, + { + name: "custom options", + method: http.MethodGet, + options: &CORSOptions{ + AllowedOrigins: []string{"http://example.com"}, + AllowedMethods: []string{"GET", "POST"}, + AllowedHeaders: []string{"Content-Type"}, + AllowCredentials: false, + }, + expectedHeaders: map[string]string{ + "Access-Control-Allow-Origin": "http://example.com", + "Access-Control-Allow-Methods": "GET,POST", + "Access-Control-Allow-Headers": "Content-Type", + }, + expectedStatus: http.StatusOK, + }, + { + name: "empty options", + method: http.MethodGet, + options: &CORSOptions{}, + expectedHeaders: map[string]string{ + // No CORS headers should be set + }, + expectedStatus: http.StatusOK, + }, + { + name: "preflight request", + method: http.MethodOptions, + options: defaultCORSOptions(), + expectedHeaders: map[string]string{ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET,POST,PUT,DELETE,OPTIONS", + "Access-Control-Allow-Headers": "Content-Type,Content-Length,Accept-Encoding,X-CSRF-Token,Authorization,accept,origin,Cache-Control,X-Requested-With", + "Access-Control-Allow-Credentials": "true", + }, + expectedStatus: http.StatusNoContent, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest(tt.method, "http://example.com", nil) + rr := httptest.NewRecorder() + + handler := CORSWithOptions(tt.options)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + handler.ServeHTTP(rr, req) + + for header, expectedValue := range tt.expectedHeaders { + if value := rr.Header().Get(header); value != expectedValue { + t.Errorf("header %s: expected %v, got %v", header, expectedValue, value) + } + } + + if status := rr.Code; status != tt.expectedStatus { + t.Errorf("status code: expected %v, got %v", tt.expectedStatus, status) + } + }) + } +} diff --git a/middleware/logger.go b/middleware/logger.go new file mode 100644 index 0000000..13a554c --- /dev/null +++ b/middleware/logger.go @@ -0,0 +1,117 @@ +package middleware + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/edgeflare/pgo" + "github.com/google/uuid" + "go.uber.org/zap" +) + +// ResponseRecorder is a wrapper for http.ResponseWriter to capture status codes and durations. +type ResponseRecorder struct { + http.ResponseWriter + StatusCode int + start time.Time +} + +func NewResponseRecorder(w http.ResponseWriter) *ResponseRecorder { + return &ResponseRecorder{ + ResponseWriter: w, + StatusCode: http.StatusOK, + start: time.Now(), + } +} + +func (rr *ResponseRecorder) WriteHeader(statusCode int) { + rr.StatusCode = statusCode + rr.ResponseWriter.WriteHeader(statusCode) +} + +func (rr *ResponseRecorder) Write(b []byte) (int, error) { + return rr.ResponseWriter.Write(b) +} + +// Retrieve log metadata from context +func GetLogEntryMetadata(ctx context.Context) map[string]interface{} { + if metadata, ok := ctx.Value(pgo.LogEntryCtxKey).(map[string]interface{}); ok { + return metadata + } + return nil +} + +// LoggerOptions defines configuration for the logger middleware. +type LoggerOptions struct { + Logger *zap.Logger + Format func(reqID string, rec *ResponseRecorder, r *http.Request, latency time.Duration) []zap.Field +} + +var defaultLogger *zap.Logger + +func init() { + var err error + defaultLogger, err = zap.NewProduction() + if err != nil { + panic(err) + } + defer defaultLogger.Sync() +} + +func LoggerWithOptions(options *LoggerOptions) func(http.Handler) http.Handler { + if options == nil { + options = &LoggerOptions{Logger: defaultLogger} + } + + if options.Format == nil { + options.Format = func(reqID string, rec *ResponseRecorder, r *http.Request, latency time.Duration) []zap.Field { + return []zap.Field{ + zap.String("req_id", reqID), + zap.Int("status", rec.StatusCode), + zap.String("method", r.Method), + zap.String("host", r.Host), + zap.String("url", r.URL.String()), + zap.String("remote_addr", r.RemoteAddr), + zap.String("user_agent", r.UserAgent()), + zap.Duration("latency", latency), + } + } + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + if _, ok := r.Context().Value(pgo.LogEntryCtxKey).(*zap.Logger); !ok { + reqID, ok := r.Context().Value(pgo.RequestIDCtxKey).(string) + if !ok { + reqID = uuid.Nil.String() + } + + rec := NewResponseRecorder(w) + // try to minimize the data passed via context + ctx := context.WithValue(r.Context(), pgo.LogEntryCtxKey, options.Logger) + r = r.WithContext(ctx) + + next.ServeHTTP(rec, r) + + latency := time.Since(start) + + pgRole, ok := r.Context().Value(pgo.PgRoleCtxKey).(string) + if !ok { + fmt.Println("PG_ROLE: ", pgRole) + pgRole = "unknown" + } + + fmt.Println("NOT WORKING... PG_ROLE: ", pgRole) + + fields := options.Format(reqID, rec, r, latency) + fields = append(fields, zap.String("pg_role", pgRole)) + options.Logger.Info("response", fields...) + } else { + next.ServeHTTP(w, r) + } + }) + } +} diff --git a/middleware/logger_test.go b/middleware/logger_test.go new file mode 100644 index 0000000..42e939a --- /dev/null +++ b/middleware/logger_test.go @@ -0,0 +1,133 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/edgeflare/pgo" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" +) + +// newTestLogger creates a logger for testing purposes and captures logs. +func newTestLogger() (*zap.Logger, *observer.ObservedLogs) { + core, logs := observer.New(zap.InfoLevel) + logger := zap.New(core) + return logger, logs +} + +// TestGetLogEntryMetadata tests the GetLogEntryMetadata function. +func TestGetLogEntryMetadata(t *testing.T) { + ctx := context.WithValue(context.Background(), pgo.LogEntryCtxKey, map[string]interface{}{"foo": "bar"}) + metadata := GetLogEntryMetadata(ctx) + require.NotNil(t, metadata) + assert.Equal(t, "bar", metadata["foo"]) + + ctx = context.Background() + metadata = GetLogEntryMetadata(ctx) + assert.Nil(t, metadata) +} + +// TestLoggerWithOptions tests the LoggerWithOptions middleware. +func TestLoggerWithOptions(t *testing.T) { + logger, logs := newTestLogger() + options := &LoggerOptions{ + Logger: logger, + Format: func(reqID string, rec *ResponseRecorder, r *http.Request, latency time.Duration) []zap.Field { + return []zap.Field{ + zap.String("test", "log"), + } + }, + } + middleware := LoggerWithOptions(options) + + handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, 1, logs.Len()) + assert.Equal(t, "response", logs.All()[0].Message) + assert.Equal(t, "log", logs.All()[0].ContextMap()["test"]) +} + +// TestLoggerWithDefaultOptions tests the LoggerWithOptions middleware with default options. +func TestLoggerWithDefaultOptions(t *testing.T) { + logger, logs := newTestLogger() + defaultLogger = logger + middleware := LoggerWithOptions(nil) + + handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, 1, logs.Len()) + assert.Equal(t, "response", logs.All()[0].Message) + assert.Equal(t, "GET", logs.All()[0].ContextMap()["method"]) +} + +// TestLoggerWithoutRequestID tests the LoggerWithOptions middleware without request ID in context. +func TestLoggerWithoutRequestID(t *testing.T) { + logger, logs := newTestLogger() + options := &LoggerOptions{ + Logger: logger, + } + middleware := LoggerWithOptions(options) + + handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, 1, logs.Len()) + assert.Equal(t, "response", logs.All()[0].Message) + assert.Equal(t, uuid.Nil.String(), logs.All()[0].ContextMap()["req_id"]) +} + +// TestLoggerWithRequestID tests the LoggerWithOptions middleware with request ID in context. +func TestLoggerWithRequestID(t *testing.T) { + logger, logs := newTestLogger() + options := &LoggerOptions{ + Logger: logger, + } + middleware := LoggerWithOptions(options) + + handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil) + reqID := uuid.New().String() + ctx := context.WithValue(req.Context(), pgo.RequestIDCtxKey, reqID) + req = req.WithContext(ctx) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, 1, logs.Len()) + assert.Equal(t, "response", logs.All()[0].Message) + assert.Equal(t, reqID, logs.All()[0].ContextMap()["req_id"]) +} diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000..c94d943 --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,24 @@ +package middleware + +import ( + "net/http" +) + +// Middleware is a function that wraps an HTTP handler. +type Middleware func(http.Handler) http.Handler + +// middlewareRegistry manages middleware functions. +var middlewareRegistry []Middleware + +// Register adds a new middleware function to the registry. +func Register(m Middleware) { + middlewareRegistry = append(middlewareRegistry, m) +} + +// Apply applies all registered middleware functions to the given handler. +func Apply(h http.Handler) http.Handler { + for i := len(middlewareRegistry) - 1; i >= 0; i-- { + h = middlewareRegistry[i](h) + } + return h +} diff --git a/middleware/oidc.go b/middleware/oidc.go new file mode 100644 index 0000000..7478fbb --- /dev/null +++ b/middleware/oidc.go @@ -0,0 +1,107 @@ +package middleware + +import ( + "context" + "log" + "net/http" + "strings" + "sync" + + "github.com/edgeflare/pgo" + "github.com/zitadel/oidc/v3/pkg/client/rs" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +// OIDCProvider is the main OIDC provider +type OIDCProvider struct { + config OIDCProviderConfig + provider rs.ResourceServer + cache *Cache + mu sync.RWMutex +} + +// OIDCProviderConfig holds the configuration for the OIDC provider +type OIDCProviderConfig struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + Issuer string `json:"issuer"` +} + +var ( + oidcProvider *OIDCProvider + oidcInitOnce sync.Once +) + +// VerifyOIDCToken is middleware that verifies OIDC tokens in Authorization headers. +// By default, it sends a 401 Unauthorized response if the token is missing or invalid. +// If send401Unauthorized is false, it allows requests with other authorization schemes +// (e.g., Basic Auth) to continue without interference. +func VerifyOIDCToken(oidcCfg OIDCProviderConfig, send401Unauthorized ...bool) func(http.Handler) http.Handler { + send401 := true // Default behavior: Send 401 on failure + if len(send401Unauthorized) > 0 { + send401 = send401Unauthorized[0] + } + + return func(next http.Handler) http.Handler { + oidcInitOnce.Do(func() { + if oidcProvider == nil { + oidcProvider = InitOIDCProvider(oidcCfg) + } + }) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + if send401 { + http.Error(w, "Authorization header missing", http.StatusUnauthorized) + return + } else { + // No Authorization header and send401Unauthorized is false, + // so let other middleware/handlers handle it + next.ServeHTTP(w, r) + return + } + } + + // Check for "Bearer" token (case-insensitive) + if !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { + if send401 { + http.Error(w, "Invalid token format", http.StatusUnauthorized) + return + } else { + // Other authorization scheme present and send401Unauthorized is false + next.ServeHTTP(w, r) + return + } + } + + tokenString := strings.TrimPrefix(authHeader, "Bearer ") + + user, err := rs.Introspect[*oidc.IntrospectionResponse](r.Context(), oidcProvider.provider, tokenString) + if err != nil || user == nil { + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + ctx := context.WithValue(r.Context(), pgo.OIDCUserCtxKey, user) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +func InitOIDCProvider(cfg OIDCProviderConfig) *OIDCProvider { + if cfg.ClientID == "" || cfg.ClientSecret == "" || cfg.Issuer == "" { + panic("missing required OIDC configuration") + } + + provider, err := rs.NewResourceServerClientCredentials(context.Background(), cfg.Issuer, cfg.ClientID, cfg.ClientSecret) + if err != nil { + log.Fatalf("Failed to create OIDC provider: %v", err) + } + + return &OIDCProvider{ + config: cfg, + provider: provider, + cache: NewCache(), + } +} diff --git a/middleware/oidc_test.go b/middleware/oidc_test.go new file mode 100644 index 0000000..352dc33 --- /dev/null +++ b/middleware/oidc_test.go @@ -0,0 +1,132 @@ +package middleware + +import ( + "testing" + + "github.com/edgeflare/pgo/pkg/util" +) + +func TestExtractRoleFromClaims(t *testing.T) { + tests := []struct { + name string + claims map[string]interface{} + path string + expected string + expectErr bool + }{ + { + name: "Simple path", + claims: map[string]interface{}{ + "role": "admin", + }, + path: "role", + expected: "admin", + expectErr: false, + }, + { + name: "Nested path", + claims: map[string]interface{}{ + "user": map[string]interface{}{ + "role": "user", + }, + }, + path: "user.role", + expected: "user", + expectErr: false, + }, + { + name: "Array index", + claims: map[string]interface{}{ + "user": map[string]interface{}{ + "roles": []interface{}{"admin", "user"}, + }, + }, + path: "user.roles[0]", + expected: "admin", + expectErr: false, + }, + { + name: "Invalid array index", + claims: map[string]interface{}{ + "user": map[string]interface{}{ + "roles": []interface{}{"admin", "user"}, + }, + }, + path: "user.roles[2]", + expected: "", + expectErr: true, + }, + { + name: "Initial dot in path", + claims: map[string]interface{}{ + "user": map[string]interface{}{ + "role": "admin", + }, + }, + path: ".user.role", + expected: "admin", + expectErr: false, + }, + { + name: "Mixed array and nested path", + claims: map[string]interface{}{ + "user": map[string]interface{}{ + "roles": []interface{}{ + map[string]interface{}{ + "type": "admin", + }, + }, + }, + }, + path: "user.roles[0].type", + expected: "admin", + expectErr: false, + }, + { + name: "Path with non-string final value", + claims: map[string]interface{}{ + "user": map[string]interface{}{ + "role": 123, + }, + }, + path: "user.role", + expected: "", + expectErr: true, + }, + { + name: "Non-existent path", + claims: map[string]interface{}{ + "user": map[string]interface{}{ + "role": "user", + }, + }, + path: "user.nonexistent", + expected: "", + expectErr: true, + }, + { + name: "Invalid JSON path syntax", + claims: map[string]interface{}{ + "user": map[string]interface{}{ + "roles": []interface{}{"admin", "user"}, + }, + }, + path: "user.roles[abc]", + expected: "", + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := util.Jq(tt.claims, tt.path) + if (err != nil) != tt.expectErr { + t.Errorf("extractRoleFromClaims() error = %v, expectErr %v", err, tt.expectErr) + return + } + if result != tt.expected { + t.Errorf("extractRoleFromClaims() = %v, expected %v", result, tt.expected) + } + }) + } +} diff --git a/middleware/pg_authz.go b/middleware/pg_authz.go new file mode 100644 index 0000000..dced229 --- /dev/null +++ b/middleware/pg_authz.go @@ -0,0 +1,68 @@ +package middleware + +import ( + "context" + "os" + + "github.com/edgeflare/pgo" + "github.com/edgeflare/pgo/pkg/util" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +// AuthzResponse represents the result of an authorization check +type AuthzResponse struct { + Role string `json:"role"` + Allowed bool `json:"allowed"` +} + +// AuthzFunc defines the function signature for authorization checks +type AuthzFunc func(ctx context.Context) (AuthzResponse, error) + +// PgOIDCAuthz is the main authorization function +func PgOIDCAuthz(oidcCfg OIDCProviderConfig, pgRoleClaimKey string) AuthzFunc { + oidcInitOnce.Do(func() { + if oidcProvider == nil { + oidcProvider = InitOIDCProvider(oidcCfg) + } + }) + + return func(ctx context.Context) (AuthzResponse, error) { + user, ok := ctx.Value(pgo.OIDCUserCtxKey).(*oidc.IntrospectionResponse) + if !ok { + return AuthzResponse{Allowed: false}, nil + } + + pgrole, err := util.Jq(user.Claims, pgRoleClaimKey) + if err != nil { + return AuthzResponse{Allowed: false}, nil + } + + ctx = context.WithValue(ctx, pgo.PgRoleCtxKey, pgrole) + return AuthzResponse{Role: pgrole, Allowed: true}, nil + } +} + +// WithBasicAuthz returns an authorization function for Basic Auth +func PgBasicAuthz() AuthzFunc { + return func(ctx context.Context) (AuthzResponse, error) { + user, ok := ctx.Value(pgo.BasicAuthCtxKey).(string) + if !ok { + return AuthzResponse{Allowed: false}, nil + } + ctx = context.WithValue(ctx, pgo.PgRoleCtxKey, user) + return AuthzResponse{Role: user, Allowed: true}, nil + } +} + +// WithAnonAuthz returns an authorization function for anonymous users +func PgAnonAuthz() AuthzFunc { + return func(ctx context.Context) (AuthzResponse, error) { + pgrole := os.Getenv("PGO_POSTGRES_ANON_ROLE") + if pgrole == "" { + return AuthzResponse{Allowed: false}, nil + } + + // ctx = context.WithValue(ctx, pgo.PgRoleCtxKey, pgrole) + return AuthzResponse{Role: pgrole, Allowed: true}, nil + } +} diff --git a/middleware/postgres.go b/middleware/postgres.go new file mode 100644 index 0000000..7f3c317 --- /dev/null +++ b/middleware/postgres.go @@ -0,0 +1,131 @@ +package middleware + +import ( + "context" + "log" + "net/http" + "os" + "time" + + "github.com/edgeflare/pgo" + "github.com/jackc/pgx/v5/pgxpool" +) + +var ( + defaultPool *pgxpool.Pool +) + +// Postgres middleware attaches a connection from pool to the request context if the http request user is authorized. +func Postgres(config PgConfig, authorizers ...AuthzFunc) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + // lazy initialization of the default pool + if defaultPool == nil { + InitPgPool(&config) + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + for _, authorize := range authorizers { + authzResponse, err := authorize(ctx) + if err != nil { + http.Error(w, "Authorization error", http.StatusInternalServerError) + return + } + if authzResponse.Allowed { + ctx = context.WithValue(ctx, pgo.PgRoleCtxKey, authzResponse.Role) + break + } + } + + if pgRole, ok := ctx.Value(pgo.PgRoleCtxKey).(string); ok { + // Acquire a connection from the default pool + conn, err := defaultPool.Acquire(r.Context()) + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + // caller should + // defer conn.Release() + + // set the connection in the context + ctx = context.WithValue(ctx, pgo.PgConnCtxKey, conn) + ctx = context.WithValue(ctx, pgo.PgRoleCtxKey, pgRole) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + } else { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + } + }) + } +} + +// PostgresConfig holds configuration for the Postgres connection pool +type PgConfig struct { + ConnString string `json:"conn_string"` + PoolConfig PgPoolConfig `json:"pool_config,omitempty"` +} + +// PgPoolConfig holds the configuration for the PostgreSQL connection pool. +type PgPoolConfig struct { + MaxConns int32 `json:"max_conns,omitempty"` + MinConns int32 `json:"min_conns,omitempty"` + MaxConnLifetime time.Duration `json:"max_conn_lifetime,omitempty"` + MaxConnIdleTime time.Duration `json:"max_conn_idle_time,omitempty"` + HealthCheckPeriod time.Duration `json:"health_check_period,omitempty"` +} + +// InitPgPool initializes the default PostgreSQL connection pool. +func InitPgPool(config *PgConfig) { + if config == nil { + config = defaultPgConfig() + } + + poolConfig, err := pgxpool.ParseConfig(config.ConnString) + if err != nil { + log.Fatal("Failed to parse connection string", err) + } + + if config.PoolConfig.MaxConns == 0 { + poolConfig.MaxConns = 10 + } + if config.PoolConfig.MinConns == 0 { + poolConfig.MinConns = 1 + } + if config.PoolConfig.MaxConnLifetime == 0 { + poolConfig.MaxConnLifetime = 30 * time.Minute + } + if config.PoolConfig.MaxConnIdleTime == 0 { + poolConfig.MaxConnIdleTime = 5 * time.Minute + } + if config.PoolConfig.HealthCheckPeriod == 0 { + poolConfig.HealthCheckPeriod = 1 * time.Minute + } + + defaultPool, err = pgxpool.NewWithConfig(context.Background(), poolConfig) + if err != nil { + log.Fatal("Unable to connect to database", err) + } + + // background goroutine to periodically check the health of the connections. + go func() { + for { + time.Sleep(poolConfig.HealthCheckPeriod) + err := defaultPool.Ping(context.Background()) + if err != nil { + log.Printf("Connection pool health check failed: %v", err) + } + } + }() +} + +// DefaultPgPool returns the default PostgreSQL connection pool. +func DefaultPool() *pgxpool.Pool { + return defaultPool +} + +func defaultPgConfig() *PgConfig { + return &PgConfig{ + ConnString: os.Getenv("PGO_POSTGRES_CONN_STRING"), + PoolConfig: PgPoolConfig{}, + } +} diff --git a/middleware/proxy.go b/middleware/proxy.go new file mode 100644 index 0000000..61b4086 --- /dev/null +++ b/middleware/proxy.go @@ -0,0 +1,70 @@ +package middleware + +import ( + "crypto/tls" + "net/http" + "net/http/httputil" + "net/url" + "strings" +) + +// Options holds the options for the proxy server +type Options struct { + TrimPrefix string + ForwardedHost string + TLSConfig *tls.Config +} + +// Serve creates a reverse proxy handler based on the given target and options +func Proxy(target string, opts Options) http.HandlerFunc { + // Parse the target URL + targetURL, err := url.Parse(target) + if err != nil { + return func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "invalid target URL", http.StatusInternalServerError) + } + } + + // Set default options + if opts.ForwardedHost == "" { + opts.ForwardedHost = targetURL.Host + } + if opts.TLSConfig == nil { + opts.TLSConfig = &tls.Config{ + InsecureSkipVerify: true, + ServerName: targetURL.Hostname(), + } + } + + proxy := httputil.NewSingleHostReverseProxy(targetURL) + + // Configure the Director function + proxy.Director = func(req *http.Request) { + // Preserve the original request context + req.URL.Scheme = targetURL.Scheme + req.URL.Host = targetURL.Host + + // Update the Host header to the target host + req.Host = targetURL.Host + + // Trim prefix if provided + if opts.TrimPrefix != "" { + req.URL.Path = strings.TrimPrefix(req.URL.Path, opts.TrimPrefix) + } + + // Set the X-Forwarded-Host header if provided + if opts.ForwardedHost != "" { + req.Header.Set("X-Forwarded-Host", opts.ForwardedHost) + } + + } + + // Configure the proxy transport + proxy.Transport = &http.Transport{ + TLSClientConfig: opts.TLSConfig, + } + + return func(w http.ResponseWriter, r *http.Request) { + proxy.ServeHTTP(w, r) + } +} diff --git a/middleware/request_id.go b/middleware/request_id.go new file mode 100644 index 0000000..cf6c87d --- /dev/null +++ b/middleware/request_id.go @@ -0,0 +1,30 @@ +package middleware + +import ( + "context" + "net/http" + + "github.com/edgeflare/pgo" + "github.com/google/uuid" +) + +const RequestIDHeader = "X-Request-Id" + +// RequestID middleware generates a unique request ID and tracks request duration. +func RequestID(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check if request ID is already set in the context + reqID, ok := r.Context().Value(pgo.RequestIDCtxKey).(string) + if !ok || reqID == "" { + reqID = uuid.New().String() + } + + ctx := r.Context() + // Not sure whether storing the request ID in the context is useful + // currently used by the logger middleware, but it can read from the request header set by this middleware + ctx = context.WithValue(ctx, pgo.RequestIDCtxKey, reqID) + w.Header().Set(RequestIDHeader, reqID) + + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go new file mode 100644 index 0000000..4df4c02 --- /dev/null +++ b/middleware/request_id_test.go @@ -0,0 +1,87 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/edgeflare/pgo" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestRequestID(t *testing.T) { + t.Run("should generate a new request ID if none exists", func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqID := r.Context().Value(pgo.RequestIDCtxKey).(string) + _, err := uuid.Parse(reqID) + assert.NoError(t, err, "Request ID should be a valid UUID") + }) + + reqIDMiddleware := RequestID(handler) + + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + w := httptest.NewRecorder() + + reqIDMiddleware.ServeHTTP(w, req) + + resp := w.Result() + reqID := resp.Header.Get(RequestIDHeader) + _, err := uuid.Parse(reqID) + assert.NoError(t, err, "Response header X-Request-Id should be a valid UUID") + }) + + t.Run("should preserve existing request ID", func(t *testing.T) { + existingReqID := uuid.New().String() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqID := r.Context().Value(pgo.RequestIDCtxKey).(string) + assert.Equal(t, existingReqID, reqID, "Request ID should match the existing ID") + }) + + reqIDMiddleware := RequestID(handler) + + // Create a request with a pre-set context containing a request ID + ctx := context.WithValue(context.Background(), pgo.RequestIDCtxKey, existingReqID) + req := httptest.NewRequest("GET", "http://example.com/foo", nil).WithContext(ctx) + w := httptest.NewRecorder() + + reqIDMiddleware.ServeHTTP(w, req) + + resp := w.Result() + reqID := resp.Header.Get(RequestIDHeader) + assert.Equal(t, existingReqID, reqID, "Response header X-Request-Id should match the existing ID") + }) + + t.Run("should handle multiple requests independently", func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqID := r.Context().Value(pgo.RequestIDCtxKey).(string) + w.Write([]byte(reqID)) + }) + + reqIDMiddleware := RequestID(handler) + + // Test first request + req1 := httptest.NewRequest("GET", "http://example.com/foo1", nil) + w1 := httptest.NewRecorder() + reqIDMiddleware.ServeHTTP(w1, req1) + _ = w1.Result() + body1 := w1.Body.String() + + // Test second request + req2 := httptest.NewRequest("GET", "http://example.com/foo2", nil) + w2 := httptest.NewRecorder() + reqIDMiddleware.ServeHTTP(w2, req2) + _ = w2.Result() + body2 := w2.Body.String() + + assert.NotEqual(t, body1, body2, "Request IDs should be different for different requests") + + // Additional validation for request IDs in responses + _, err1 := uuid.Parse(body1) + _, err2 := uuid.Parse(body2) + assert.NoError(t, err1, "Response body for first request should be a valid UUID") + assert.NoError(t, err2, "Response body for second request should be a valid UUID") + }) +} diff --git a/pkg/util/cert.go b/pkg/util/cert.go new file mode 100644 index 0000000..3f67968 --- /dev/null +++ b/pkg/util/cert.go @@ -0,0 +1,109 @@ +package util + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "os" + "path/filepath" + "time" +) + +// LoadOrGenerateCert generates a self-signed certificate and private key if they do not exist at the specified paths. +// If the files already exist, they are loaded and returned. +func LoadOrGenerateCert(certPath, keyPath string) (tls.Certificate, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, fmt.Errorf("failed to generate private key: %v", err) + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Self-Signed"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return tls.Certificate{}, fmt.Errorf("failed to create certificate: %v", err) + } + + cert := tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: priv, + } + + // Ensure the directory exists + err = os.MkdirAll(filepath.Dir(certPath), os.ModePerm) + if err != nil { + return tls.Certificate{}, fmt.Errorf("failed to create tls directory: %v", err) + } + + // Write the cert and key to files + err = writeCert(certPath, derBytes) + if err != nil { + return tls.Certificate{}, err + } + + err = writeKey(keyPath, priv) + if err != nil { + return tls.Certificate{}, err + } + + return cert, nil +} + +// LoadCertFromFiles loads a TLS certificate from the specified paths. +func loadCertFromFiles(certPath, keyPath string) (tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return tls.Certificate{}, fmt.Errorf("failed to load TLS certificate: %v", err) + } + return cert, nil +} + +func writeCert(certPath string, derBytes []byte) error { + certOut, err := os.Create(certPath) + if err != nil { + return fmt.Errorf("failed to create cert file: %v", err) + } + defer certOut.Close() + + err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + if err != nil { + return fmt.Errorf("failed to write certificate to file: %v", err) + } + return nil +} + +func writeKey(keyPath string, priv *ecdsa.PrivateKey) error { + keyOut, err := os.Create(keyPath) + if err != nil { + return fmt.Errorf("failed to create key file: %v", err) + } + defer keyOut.Close() + + privBytes, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return fmt.Errorf("failed to marshal private key: %v", err) + } + + err = pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes}) + if err != nil { + return fmt.Errorf("failed to write private key to file: %v", err) + } + return nil +} diff --git a/pkg/util/cert_test.go b/pkg/util/cert_test.go new file mode 100644 index 0000000..908f529 --- /dev/null +++ b/pkg/util/cert_test.go @@ -0,0 +1,131 @@ +package util + +import ( + "os" + "path/filepath" + "testing" +) + +// TestGenerateSelfSignedCert tests the LoadOrGenerateCert function for correctness. +func TestGenerateSelfSignedCert(t *testing.T) { + certPath := "./test_tls/tls.crt" + keyPath := "./test_tls/tls.key" + + // Clean up the test files after the test + defer func() { + os.Remove(certPath) + os.Remove(keyPath) + os.Remove(filepath.Dir(certPath)) + }() + + cert, err := LoadOrGenerateCert(certPath, keyPath) + if err != nil { + t.Fatalf("LoadOrGenerateCert() failed: %v", err) + } + + // Check if the certificate and key files were created + if _, err := os.Stat(certPath); os.IsNotExist(err) { + t.Errorf("Expected certificate file %s does not exist", certPath) + } + + if _, err := os.Stat(keyPath); os.IsNotExist(err) { + t.Errorf("Expected key file %s does not exist", keyPath) + } + + // Validate the certificate structure + if len(cert.Certificate) == 0 { + t.Errorf("Generated certificate has no data") + } + + // Validate the private key structure + if cert.PrivateKey == nil { + t.Errorf("Generated certificate has no private key") + } +} + +// TestLoadCertFromFiles tests the LoadCertFromFiles function for correctness. +func TestLoadCertFromFiles(t *testing.T) { + certPath := "./test_tls/tls.crt" + keyPath := "./test_tls/tls.key" + + // Generate a self-signed certificate to be loaded later + _, err := LoadOrGenerateCert(certPath, keyPath) + if err != nil { + t.Fatalf("LoadOrGenerateCert() failed: %v", err) + } + + // Clean up the test files after the test + defer func() { + os.Remove(certPath) + os.Remove(keyPath) + os.Remove(filepath.Dir(certPath)) + }() + + cert, err := loadCertFromFiles(certPath, keyPath) + if err != nil { + t.Fatalf("LoadCertFromFiles() failed: %v", err) + } + + // Validate the loaded certificate + if len(cert.Certificate) == 0 { + t.Errorf("Loaded certificate has no data") + } + + // Validate the private key structure + if cert.PrivateKey == nil { + t.Errorf("Loaded certificate has no private key") + } +} + +// TestGenerateSelfSignedCert_DirectoryCreation tests the directory creation for the certificate and key files. +func TestGenerateSelfSignedCert_DirectoryCreation(t *testing.T) { + certPath := "./test_tls_subdir/tls.crt" + keyPath := "./test_tls_subdir/tls.key" + + // Clean up the test files after the test + defer func() { + os.Remove(certPath) + os.Remove(keyPath) + os.Remove(filepath.Dir(certPath)) + }() + + cert, err := LoadOrGenerateCert(certPath, keyPath) + if err != nil { + t.Fatalf("LoadOrGenerateCert() failed: %v", err) + } + + // Check if the directory was created + if _, err := os.Stat(filepath.Dir(certPath)); os.IsNotExist(err) { + t.Errorf("Expected directory %s does not exist", filepath.Dir(certPath)) + } + + // Check if the certificate and key files were created + if _, err := os.Stat(certPath); os.IsNotExist(err) { + t.Errorf("Expected certificate file %s does not exist", certPath) + } + + if _, err := os.Stat(keyPath); os.IsNotExist(err) { + t.Errorf("Expected key file %s does not exist", keyPath) + } + + // Validate the certificate structure + if len(cert.Certificate) == 0 { + t.Errorf("Generated certificate has no data") + } + + // Validate the private key structure + if cert.PrivateKey == nil { + t.Errorf("Generated certificate has no private key") + } +} + +// TestLoadCertFromFiles_InvalidPath tests loading certificates from invalid file paths. +func TestLoadCertFromFiles_InvalidPath(t *testing.T) { + invalidCertPath := "./test_tls/invalid.crt" + invalidKeyPath := "./test_tls/invalid.key" + + _, err := loadCertFromFiles(invalidCertPath, invalidKeyPath) + if err == nil { + t.Error("Expected error loading certificate from invalid paths, got nil") + } +} diff --git a/pkg/util/jq.go b/pkg/util/jq.go new file mode 100644 index 0000000..b92295f --- /dev/null +++ b/pkg/util/jq.go @@ -0,0 +1,45 @@ +package util + +import ( + "errors" + "fmt" + "strconv" + "strings" +) + +// Jq is a helper function to extract a value from a JSON-like map using a path +func Jq(input map[string]interface{}, path string) (string, error) { + path = strings.TrimPrefix(path, ".") + keys := strings.Split(path, ".") + var current interface{} = input + + for _, key := range keys { + if currentMap, ok := current.(map[string]interface{}); ok { + if strings.Contains(key, "[") && strings.Contains(key, "]") { + arrayKey := key[:strings.Index(key, "[")] + indexStr := key[strings.Index(key, "[")+1 : strings.Index(key, "]")] + index, err := strconv.Atoi(indexStr) + if err != nil { + return "", fmt.Errorf("invalid array index in path: %s", key) + } + if array, ok := currentMap[arrayKey].([]interface{}); ok { + if index < 0 || index >= len(array) { + return "", fmt.Errorf("index out of range in path: %s", key) + } + current = array[index] + } else { + return "", fmt.Errorf("expected array at path: %s", key) + } + } else { + current = currentMap[key] + } + } else { + return "", fmt.Errorf("expected map at path: %s", key) + } + } + + if role, ok := current.(string); ok { + return role, nil + } + return "", errors.New("role not found or not a string") +} diff --git a/pkg/util/rand/name.go b/pkg/util/rand/name.go new file mode 100644 index 0000000..10b838a --- /dev/null +++ b/pkg/util/rand/name.go @@ -0,0 +1,38 @@ +package rand + +import ( + mrand "math/rand" +) + +var adjectives = []string{ + "agile", "brave", "calm", "daring", "eager", + "fancy", "gentle", "happy", "intelligent", "jolly", + "kind", "lively", "mighty", "noble", "optimistic", + "playful", "quick", "radiant", "spirited", "trusty", + "upbeat", "vibrant", "wise", "youthful", "zealous", + "ambitious", "bright", "cheerful", "dynamic", "elegant", + "fearless", "graceful", "hopeful", "inspired", "jovial", + "keen", "loyal", "motivated", "nimble", "passionate", + "resourceful", "sturdy", "tenacious", "uplifted", "vigorous", + "warm", "xenial", "zesty", +} + +var birds = []string{ + "albatross", "bluebird", "canary", "dove", "eagle", + "falcon", "goldfinch", "hawk", "ibis", "jay", + "kingfisher", "lark", "magpie", "nightingale", "oriole", + "parrot", "quail", "robin", "sparrow", "toucan", + "umbrella bird", "vulture", "woodpecker", "xerus bird", "yellowhammer", + "zebra finch", "avocet", "bunting", "crane", "duck", + "egret", "flamingo", "goose", "heron", "indigo bunting", + "junco", "kestrel", "loon", "mockingbird", "nuthatch", + "owl", "pelican", "quaker parrot", "raven", "starling", + "tern", "vireo", "wren", "xantus's hummingbird", "yellowthroat", + "zebra dove", +} + +func NewName() string { + adj := adjectives[mrand.Intn(len(adjectives))] + bird := birds[mrand.Intn(len(birds))] + return adj + "-" + bird +} diff --git a/pkg/x/experiments.go b/pkg/x/experiments.go new file mode 100644 index 0000000..992fb9a --- /dev/null +++ b/pkg/x/experiments.go @@ -0,0 +1,3 @@ +package x + +// Experimental stuff diff --git a/pkg/x/logrepl/cdc.go b/pkg/x/logrepl/cdc.go new file mode 100644 index 0000000..b63fa67 --- /dev/null +++ b/pkg/x/logrepl/cdc.go @@ -0,0 +1,48 @@ +package logrepl + +import ( + "context" + "encoding/json" +) + +type Peer[T Event] interface { + Name() string + Connect(ctx context.Context, config interface{}) error + Disconnect(ctx context.Context) error + Write(ctx context.Context, data T) error + BatchWrite(ctx context.Context, data []T) error + Read(ctx context.Context) (<-chan T, error) + BatchRead(ctx context.Context, batchSize int) (<-chan []T, error) + Errors() <-chan error + Metadata() map[string]interface{} + PauseStreaming(ctx context.Context) error + ResumeStreaming(ctx context.Context) error +} + +type Event interface { + Type() string + Encode() ([]byte, error) + Decode([]byte) error +} + +type PostgresEvent struct { + Operation string `json:"operation"` + Schema string `json:"schema"` + Table string `json:"table"` + Data map[string]interface{} `json:"data,omitempty"` + OldData map[string]interface{} `json:"old_data,omitempty"` + Timestamp int64 `json:"timestamp"` + XID uint32 `json:"xid"` +} + +func (e PostgresEvent) Type() string { + return "PostgresEvent" +} + +func (e PostgresEvent) Encode() ([]byte, error) { + return json.Marshal(e) +} + +func (e *PostgresEvent) Decode(data []byte) error { + return json.Unmarshal(data, e) +} diff --git a/pkg/x/logrepl/main.go b/pkg/x/logrepl/main.go new file mode 100644 index 0000000..4f273c0 --- /dev/null +++ b/pkg/x/logrepl/main.go @@ -0,0 +1,150 @@ +package logrepl + +import ( + "context" + "os" + "time" + + "github.com/jackc/pglogrepl" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/jackc/pgx/v5/pgtype" + "go.uber.org/zap" +) + +func Run(ctx context.Context, p Peer[Event]) { + log, _ := zap.NewProduction() + conn, sysident, err := SetupReplication() + if err != nil { + log.Fatal("SetupReplication failed", zap.Error(err)) + } + defer conn.Close(context.Background()) + + pluginArguments := []string{ + "proto_version '2'", + "publication_names 'pglogrepl_demo'", + "messages 'true'", + "streaming 'true'", + } + + err = StartReplication(conn, sysident, pluginArguments) + if err != nil { + log.Fatal("StartReplication failed", zap.Error(err)) + } + + clientXLogPos := sysident.XLogPos + standbyMessageTimeout := standbyMessageTimer + nextStandbyMessageDeadline := time.Now().Add(standbyMessageTimeout) + relations := map[uint32]*pglogrepl.RelationMessage{} + relationsV2 := map[uint32]*pglogrepl.RelationMessageV2{} + typeMap := pgtype.NewMap() + inStream := false + + go func() { + for { + select { + case err := <-p.Errors(): + log.Error("Error", zap.Error(err)) + case <-ctx.Done(): + log.Info("Shutting down error listener") + return + } + } + }() + + for { + select { + case <-ctx.Done(): + log.Info("Received shutdown signal, exiting replication loop") + return + default: + if time.Now().After(nextStandbyMessageDeadline) { + err = pglogrepl.SendStandbyStatusUpdate(context.Background(), conn, pglogrepl.StandbyStatusUpdate{WALWritePosition: clientXLogPos}) + if err != nil { + log.Fatal("SendStandbyStatusUpdate failed", zap.Error(err)) + } + log.Info("Sent Standby status message", zap.String("position", clientXLogPos.String())) + nextStandbyMessageDeadline = time.Now().Add(standbyMessageTimeout) + } + + ctx, cancel := context.WithDeadline(context.Background(), nextStandbyMessageDeadline) + rawMsg, err := conn.ReceiveMessage(ctx) + cancel() + if err != nil { + if pgconn.Timeout(err) { + continue + } + log.Fatal("ReceiveMessage failed", zap.Error(err)) + } + + if errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok { + log.Fatal("received Postgres WAL error", zap.Any("error", errMsg)) + } + + msg, ok := rawMsg.(*pgproto3.CopyData) + if !ok { + log.Info("Received unexpected message", zap.Any("message", rawMsg)) + continue + } + + switch msg.Data[0] { + case pglogrepl.PrimaryKeepaliveMessageByteID: + handlePrimaryKeepaliveMessage(msg.Data[1:], &clientXLogPos, &nextStandbyMessageDeadline, log) + + case pglogrepl.XLogDataByteID: + xld, err := pglogrepl.ParseXLogData(msg.Data[1:]) + if err != nil { + log.Fatal("ParseXLogData failed", zap.Error(err)) + } + + if os.Getenv("PGO_POSTGRES_LOGREPL_OUTPUT_PLUGIN") == "wal2json" { + log.Info("wal2json data", zap.String("data", string(xld.WALData))) + } else { + log.Info("XLogData", + zap.String("WALStart", xld.WALStart.String()), + zap.String("ServerWALEnd", xld.ServerWALEnd.String()), + zap.Time("ServerTime", xld.ServerTime)) + + var events []Event + if len(relationsV2) > 0 { + events = processV2(xld.WALData, relationsV2, typeMap, &inStream) + } else { + events = processV1(xld.WALData, relations, typeMap) + } + + for _, event := range events { + if err := p.Write(context.Background(), event); err != nil { + log.Error("Error writing event", zap.Error(err)) + } + } + } + + if xld.WALStart > clientXLogPos { + clientXLogPos = xld.WALStart + } + } + } + } +} + +func handlePrimaryKeepaliveMessage(data []byte, clientXLogPos *pglogrepl.LSN, nextStandbyMessageDeadline *time.Time, log *zap.Logger) { + pkm, err := pglogrepl.ParsePrimaryKeepaliveMessage(data) + if err != nil { + log.Fatal("ParsePrimaryKeepaliveMessage failed", zap.Error(err)) + } + log.Info("Primary Keepalive Message", + zap.String("ServerWALEnd", pkm.ServerWALEnd.String()), + zap.Time("ServerTime", pkm.ServerTime), + zap.Bool("ReplyRequested", pkm.ReplyRequested)) + if pkm.ServerWALEnd > *clientXLogPos { + *clientXLogPos = pkm.ServerWALEnd + } + if pkm.ReplyRequested { + *nextStandbyMessageDeadline = time.Time{} + } +} + +func StartReplication(conn *pgconn.PgConn, sysident pglogrepl.IdentifySystemResult, pluginArguments []string) error { + err := pglogrepl.StartReplication(context.Background(), conn, slotName, sysident.XLogPos, pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments}) + return err +} diff --git a/pkg/x/logrepl/pglogrepl.go b/pkg/x/logrepl/pglogrepl.go new file mode 100644 index 0000000..710efc3 --- /dev/null +++ b/pkg/x/logrepl/pglogrepl.go @@ -0,0 +1,106 @@ +package logrepl + +import ( + "context" + "fmt" + "log" + "os" + "time" + + "github.com/jackc/pglogrepl" + "github.com/jackc/pgx/v5/pgconn" +) + +const ( + outputPlugin = "pgoutput" + publicationName = "pglogrepl_demo" + slotName = "pglogrepl_demo" + standbyMessageTimer = 10 * time.Second +) + +func SetupReplication() (*pgconn.PgConn, pglogrepl.IdentifySystemResult, error) { + connString := os.Getenv("PGO_PGLOGREPL_CONN_STRING") + conn, err := pgconn.Connect(context.Background(), connString) + if err != nil { + return nil, pglogrepl.IdentifySystemResult{}, err + } + + exists, err := checkPublicationExists(conn, publicationName) + if err != nil { + conn.Close(context.Background()) + return nil, pglogrepl.IdentifySystemResult{}, err + } + if !exists { + err = createPublication(conn, publicationName) + if err != nil { + conn.Close(context.Background()) + return nil, pglogrepl.IdentifySystemResult{}, err + } + log.Println("Created publication", publicationName) + } else { + log.Println("Publication", publicationName, "already exists") + } + + sysident, err := pglogrepl.IdentifySystem(context.Background(), conn) + if err != nil { + conn.Close(context.Background()) + return nil, pglogrepl.IdentifySystemResult{}, err + } + log.Println("SystemID:", sysident.SystemID, "Timeline:", sysident.Timeline, "XLogPos:", sysident.XLogPos, "DBName:", sysident.DBName) + + slotExists, err := checkSlotExists(conn, slotName) + if err != nil { + conn.Close(context.Background()) + return nil, pglogrepl.IdentifySystemResult{}, err + } + if !slotExists { + err = createReplicationSlot(conn, slotName, outputPlugin) + if err != nil { + conn.Close(context.Background()) + return nil, pglogrepl.IdentifySystemResult{}, err + } + log.Println("Created replication slot", slotName) + } else { + log.Println("Replication slot", slotName, "already exists") + } + + return conn, sysident, nil +} + +func checkPublicationExists(conn *pgconn.PgConn, publicationName string) (bool, error) { + query := fmt.Sprintf("SELECT EXISTS (SELECT 1 FROM pg_publication WHERE pubname = '%s');", publicationName) + result := conn.Exec(context.Background(), query) + rows, err := result.ReadAll() + if err != nil { + return false, err + } + if len(rows) > 0 && len(rows[0].Rows) > 0 { + return string(rows[0].Rows[0][0]) == "t", nil + } + return false, nil +} + +func createPublication(conn *pgconn.PgConn, publicationName string) error { + query := fmt.Sprintf("CREATE PUBLICATION %s FOR ALL TABLES;", publicationName) + result := conn.Exec(context.Background(), query) + _, err := result.ReadAll() + return err +} + +func checkSlotExists(conn *pgconn.PgConn, slotName string) (bool, error) { + query := fmt.Sprintf("SELECT EXISTS (SELECT 1 FROM pg_replication_slots WHERE slot_name = '%s');", slotName) + result := conn.Exec(context.Background(), query) + rows, err := result.ReadAll() + if err != nil { + return false, err + } + if len(rows) > 0 && len(rows[0].Rows) > 0 { + return string(rows[0].Rows[0][0]) == "t", nil + } + return false, nil +} + +func createReplicationSlot(conn *pgconn.PgConn, slotName string, outputPlugin string) error { + _, err := pglogrepl.CreateReplicationSlot(context.Background(), conn, slotName, outputPlugin, pglogrepl.CreateReplicationSlotOptions{Temporary: false}) + return err +} diff --git a/pkg/x/logrepl/postgres_peer.go b/pkg/x/logrepl/postgres_peer.go new file mode 100644 index 0000000..246dcaa --- /dev/null +++ b/pkg/x/logrepl/postgres_peer.go @@ -0,0 +1,65 @@ +package logrepl + +import ( + "context" +) + +type PostgresPeer struct { + EventChan chan Event + errorChan chan error +} + +func NewPostgresPeer() *PostgresPeer { + return &PostgresPeer{ + EventChan: make(chan Event), + errorChan: make(chan error), + } +} + +func (p *PostgresPeer) Name() string { + return "PostgresPeer" +} + +func (p *PostgresPeer) Connect(ctx context.Context, config interface{}) error { + return nil +} + +func (p *PostgresPeer) Disconnect(ctx context.Context) error { + return nil +} + +func (p *PostgresPeer) Write(ctx context.Context, data Event) error { + p.EventChan <- data + return nil +} + +func (p *PostgresPeer) BatchWrite(ctx context.Context, data []Event) error { + for _, event := range data { + p.EventChan <- event + } + return nil +} + +func (p *PostgresPeer) Read(ctx context.Context) (<-chan Event, error) { + return p.EventChan, nil +} + +func (p *PostgresPeer) BatchRead(ctx context.Context, batchSize int) (<-chan []Event, error) { + return nil, nil +} + +func (p *PostgresPeer) Errors() <-chan error { + return p.errorChan +} + +func (p *PostgresPeer) Metadata() map[string]interface{} { + return nil +} + +func (p *PostgresPeer) PauseStreaming(ctx context.Context) error { + return nil +} + +func (p *PostgresPeer) ResumeStreaming(ctx context.Context) error { + return nil +} diff --git a/pkg/x/logrepl/process_v1.go b/pkg/x/logrepl/process_v1.go new file mode 100644 index 0000000..e427341 --- /dev/null +++ b/pkg/x/logrepl/process_v1.go @@ -0,0 +1,157 @@ +package logrepl + +import ( + "time" + + "github.com/jackc/pglogrepl" + "github.com/jackc/pgx/v5/pgtype" +) + +func processV1(walData []byte, relations map[uint32]*pglogrepl.RelationMessage, typeMap *pgtype.Map) []Event { + logicalMsg, err := pglogrepl.Parse(walData) + if err != nil { + panic("Parse logical replication message: " + err.Error()) + } + var events []Event + switch logicalMsg := logicalMsg.(type) { + case *pglogrepl.RelationMessage: + relations[logicalMsg.RelationID] = logicalMsg + + case *pglogrepl.BeginMessage: + // Handle begin message + + case *pglogrepl.CommitMessage: + // Handle commit message + + case *pglogrepl.InsertMessage: + events = append(events, handleInsertMessageV1(logicalMsg, relations, typeMap)) + + case *pglogrepl.UpdateMessage: + events = append(events, handleUpdateMessageV1(logicalMsg, relations, typeMap)) + + case *pglogrepl.DeleteMessage: + events = append(events, handleDeleteMessageV1(logicalMsg, relations, typeMap)) + + case *pglogrepl.TruncateMessage: + events = append(events, handleTruncateMessageV1(logicalMsg, relations)) + + case *pglogrepl.TypeMessage: + case *pglogrepl.OriginMessage: + + case *pglogrepl.LogicalDecodingMessage: + // Handle logical decoding message + } + + return events +} + +func handleInsertMessageV1(msg *pglogrepl.InsertMessage, relations map[uint32]*pglogrepl.RelationMessage, typeMap *pgtype.Map) Event { + rel, ok := relations[msg.RelationID] + if !ok { + panic("unknown relation ID " + string(msg.RelationID)) + } + values := map[string]interface{}{} + for idx, col := range msg.Tuple.Columns { + colName := rel.Columns[idx].Name + switch col.DataType { + case 'n': + values[colName] = nil + case 'u': + case 't': + val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType) + if err != nil { + panic("error decoding column data: " + err.Error()) + } + values[colName] = val + } + } + return &PostgresEvent{ + Operation: "INSERT", + Schema: rel.Namespace, + Table: rel.RelationName, + Data: values, + Timestamp: time.Now().Unix(), + // XID: uint32(msg.Type()), + } +} + +func handleUpdateMessageV1(msg *pglogrepl.UpdateMessage, relations map[uint32]*pglogrepl.RelationMessage, typeMap *pgtype.Map) Event { + rel, ok := relations[msg.RelationID] + if !ok { + panic("unknown relation ID " + string(msg.RelationID)) + } + newValues := map[string]interface{}{} + for idx, col := range msg.NewTuple.Columns { + colName := rel.Columns[idx].Name + switch col.DataType { + case 'n': + newValues[colName] = nil + case 'u': + case 't': + val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType) + if err != nil { + panic("error decoding column data: " + err.Error()) + } + newValues[colName] = val + } + } + oldValues := map[string]interface{}{} + for idx, col := range msg.OldTuple.Columns { + colName := rel.Columns[idx].Name + switch col.DataType { + case 'n': + oldValues[colName] = nil + case 'u': + case 't': + val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType) + if err != nil { + panic("error decoding column data: " + err.Error()) + } + oldValues[colName] = val + } + } + return &PostgresEvent{ + Operation: "UPDATE", + Schema: rel.Namespace, + Table: rel.RelationName, + Data: newValues, + OldData: oldValues, + Timestamp: time.Now().Unix(), + // XID: msg.Xid, + } +} + +func handleDeleteMessageV1(msg *pglogrepl.DeleteMessage, relations map[uint32]*pglogrepl.RelationMessage, typeMap *pgtype.Map) Event { + rel, ok := relations[msg.RelationID] + if !ok { + panic("unknown relation ID " + string(msg.RelationID)) + } + oldValues := map[string]interface{}{} + for idx, col := range msg.OldTuple.Columns { + colName := rel.Columns[idx].Name + switch col.DataType { + case 'n': + oldValues[colName] = nil + case 'u': + case 't': + val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType) + if err != nil { + panic("error decoding column data: " + err.Error()) + } + oldValues[colName] = val + } + } + return &PostgresEvent{ + Operation: "DELETE", + Schema: rel.Namespace, + Table: rel.RelationName, + OldData: oldValues, + Timestamp: time.Now().Unix(), + // XID: msg.Xid, + } +} + +func handleTruncateMessageV1(msg *pglogrepl.TruncateMessage, relations map[uint32]*pglogrepl.RelationMessage) Event { + // Implement truncate message handling + return &PostgresEvent{} +} diff --git a/pkg/x/logrepl/process_v2.go b/pkg/x/logrepl/process_v2.go new file mode 100644 index 0000000..53640e7 --- /dev/null +++ b/pkg/x/logrepl/process_v2.go @@ -0,0 +1,165 @@ +package logrepl + +import ( + "time" + + "github.com/jackc/pglogrepl" + "github.com/jackc/pgx/v5/pgtype" + "go.uber.org/zap" +) + +func processV2(walData []byte, relations map[uint32]*pglogrepl.RelationMessageV2, typeMap *pgtype.Map, inStream *bool) []Event { + logicalMsg, err := pglogrepl.ParseV2(walData, *inStream) + if err != nil { + zap.L().Fatal("ParseV2 failed", zap.Error(err)) + } + var events []Event + switch logicalMsg := logicalMsg.(type) { + case *pglogrepl.RelationMessageV2: + relations[logicalMsg.RelationID] = logicalMsg + + case *pglogrepl.BeginMessage: + // Handle begin message + + case *pglogrepl.CommitMessage: + // Handle commit message + + case *pglogrepl.InsertMessageV2: + events = append(events, handleInsertMessageV2(logicalMsg, relations, typeMap)) + + case *pglogrepl.UpdateMessageV2: + events = append(events, handleUpdateMessageV2(logicalMsg, relations, typeMap)) + + case *pglogrepl.DeleteMessageV2: + events = append(events, handleDeleteMessageV2(logicalMsg, relations, typeMap)) + + case *pglogrepl.TruncateMessageV2: + events = append(events, handleTruncateMessageV2(logicalMsg, relations)) + + case *pglogrepl.TypeMessageV2: + case *pglogrepl.OriginMessage: + case *pglogrepl.LogicalDecodingMessageV2: + // Handle logical decoding message + case *pglogrepl.StreamStartMessageV2: + *inStream = true + case *pglogrepl.StreamStopMessageV2: + *inStream = false + case *pglogrepl.StreamCommitMessageV2: + case *pglogrepl.StreamAbortMessageV2: + default: + zap.L().Warn("Unknown message type in pgoutput stream", zap.Any("message", logicalMsg)) + } + + return events +} + +func handleInsertMessageV2(msg *pglogrepl.InsertMessageV2, relations map[uint32]*pglogrepl.RelationMessageV2, typeMap *pgtype.Map) Event { + rel, ok := relations[msg.RelationID] + if !ok { + panic("unknown relation ID " + string(msg.RelationID)) + } + values := map[string]interface{}{} + for idx, col := range msg.Tuple.Columns { + colName := rel.Columns[idx].Name + switch col.DataType { + case 'n': + values[colName] = nil + case 'u': + case 't': + val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType) + if err != nil { + panic("error decoding column data: " + err.Error()) + } + values[colName] = val + } + } + return &PostgresEvent{ + Operation: "INSERT", + Schema: rel.Namespace, + Table: rel.RelationName, + Data: values, + Timestamp: time.Now().Unix(), + XID: msg.Xid, + } +} + +func handleUpdateMessageV2(msg *pglogrepl.UpdateMessageV2, relations map[uint32]*pglogrepl.RelationMessageV2, typeMap *pgtype.Map) Event { + rel, ok := relations[msg.RelationID] + if !ok { + panic("unknown relation ID " + string(msg.RelationID)) + } + newValues := map[string]interface{}{} + for idx, col := range msg.NewTuple.Columns { + colName := rel.Columns[idx].Name + switch col.DataType { + case 'n': + newValues[colName] = nil + case 'u': + case 't': + val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType) + if err != nil { + panic("error decoding column data: " + err.Error()) + } + newValues[colName] = val + } + } + oldValues := map[string]interface{}{} + for idx, col := range msg.OldTuple.Columns { + colName := rel.Columns[idx].Name + switch col.DataType { + case 'n': + oldValues[colName] = nil + case 'u': + case 't': + val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType) + if err != nil { + panic("error decoding column data: " + err.Error()) + } + oldValues[colName] = val + } + } + return &PostgresEvent{ + Operation: "UPDATE", + Schema: rel.Namespace, + Table: rel.RelationName, + Data: newValues, + OldData: oldValues, + Timestamp: time.Now().Unix(), + XID: msg.Xid, + } +} + +func handleDeleteMessageV2(msg *pglogrepl.DeleteMessageV2, relations map[uint32]*pglogrepl.RelationMessageV2, typeMap *pgtype.Map) Event { + rel, ok := relations[msg.RelationID] + if !ok { + panic("unknown relation ID " + string(msg.RelationID)) + } + oldValues := map[string]interface{}{} + for idx, col := range msg.OldTuple.Columns { + colName := rel.Columns[idx].Name + switch col.DataType { + case 'n': + oldValues[colName] = nil + case 'u': + case 't': + val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType) + if err != nil { + panic("error decoding column data: " + err.Error()) + } + oldValues[colName] = val + } + } + return &PostgresEvent{ + Operation: "DELETE", + Schema: rel.Namespace, + Table: rel.RelationName, + OldData: oldValues, + Timestamp: time.Now().Unix(), + XID: msg.Xid, + } +} + +func handleTruncateMessageV2(msg *pglogrepl.TruncateMessageV2, relations map[uint32]*pglogrepl.RelationMessageV2) Event { + // Implement truncate message handling + return &PostgresEvent{} +} diff --git a/pkg/x/logrepl/util.go b/pkg/x/logrepl/util.go new file mode 100644 index 0000000..7956e69 --- /dev/null +++ b/pkg/x/logrepl/util.go @@ -0,0 +1,17 @@ +package logrepl + +import ( + "github.com/jackc/pgx/v5/pgtype" +) + +// decodeTextColumnData decodes the text format of a column's data into a suitable Go type based on its data type. +func decodeTextColumnData(typeMap *pgtype.Map, data []byte, dataType uint32) (interface{}, error) { + if dt, ok := typeMap.TypeForOID(dataType); ok { + value, err := dt.Codec.DecodeValue(typeMap, dataType, pgtype.TextFormatCode, data) + if err != nil { + return nil, err + } + return value, nil + } + return string(data), nil +} diff --git a/pkg/x/pgcache/pgcache.go b/pkg/x/pgcache/pgcache.go new file mode 100644 index 0000000..765eea1 --- /dev/null +++ b/pkg/x/pgcache/pgcache.go @@ -0,0 +1,102 @@ +package pgcache + +import ( + "fmt" + "io" + "log" + "net" + + pg_query "github.com/pganalyze/pg_query_go/v5" +) + +func main() { + listener, err := net.Listen("tcp", "localhost:5430") + if err != nil { + log.Fatalf("Error starting TCP server: %v", err) + } + defer listener.Close() + log.Println("Server is listening on localhost:5430") + + for { + conn, err := listener.Accept() + if err != nil { + log.Printf("Error accepting connection: %v", err) + continue + } + go handleClient(conn) + } +} + +func handleClient(clientConn net.Conn) { + defer clientConn.Close() + + serverConn, err := net.Dial("tcp", "localhost:5443") + if err != nil { + log.Printf("Error connecting to PostgreSQL server: %v", err) + return + } + defer serverConn.Close() + + go func() { + // Forward responses from PostgreSQL server to the client + if _, err := io.Copy(clientConn, serverConn); err != nil { + log.Printf("Error forwarding response to client: %v", err) + } + }() + + // Read and parse client queries + buf := make([]byte, 4096) + for { + n, err := clientConn.Read(buf) + if err != nil { + if err != io.EOF { + log.Printf("Error reading from client: %v", err) + } + break + } + + query := string(buf[:n]) + log.Printf("Received data: %s", query) + + // Skip non-SQL messages (like startup messages) + // if !isSQLQuery(query) { + // log.Printf("Skipping non-SQL message: %s", query) + // // Forward the message to the PostgreSQL server + // if _, err := serverConn.Write(buf[:n]); err != nil { + // log.Printf("Error forwarding message to PostgreSQL server: %v", err) + // break + // } + // continue + // } + + // Parse and log metrics + parseQuery(query) + + // Forward the query to the PostgreSQL server + if _, err := serverConn.Write(buf[:n]); err != nil { + log.Printf("Error forwarding query to PostgreSQL server: %v", err) + break + } + } +} + +// func isSQLQuery(query string) bool { +// // Simple heuristic to check if the message is a SQL query +// query = strings.TrimSpace(query) +// if len(query) == 0 { +// return false +// } + +// firstChar := query[0] +// // Check if the first character is likely the start of a SQL command +// return firstChar == 'S' || firstChar == 'I' || firstChar == 'U' || firstChar == 'D' || firstChar == 's' || firstChar == 'i' || firstChar == 'u' || firstChar == 'd' +// } + +func parseQuery(query string) { + tree, err := pg_query.ParseToJSON(query) + if err != nil { + fmt.Println(err) + return + } + fmt.Printf("\n%+v\n", tree) +} diff --git a/pkg/x/pgproxy/main.go b/pkg/x/pgproxy/main.go new file mode 100644 index 0000000..8659103 --- /dev/null +++ b/pkg/x/pgproxy/main.go @@ -0,0 +1,51 @@ +package pgproxy + +import ( + "flag" + "fmt" + "log" + "net" + "os" + "os/exec" +) + +var options struct { + listenAddress string + responseCommand string +} + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "usage: %s [options]\n", os.Args[0]) + flag.PrintDefaults() + } + + flag.StringVar(&options.listenAddress, "listen", "127.0.0.1:15432", "Listen address") + flag.StringVar(&options.responseCommand, "response-command", "echo 'fortune | cowsay -f elephant'", "Command to execute to generate query response") + flag.Parse() + + ln, err := net.Listen("tcp", options.listenAddress) + if err != nil { + log.Fatal(err) + } + log.Println("Listening on", ln.Addr()) + + for { + conn, err := ln.Accept() + if err != nil { + log.Fatal(err) + } + log.Println("Accepted connection from", conn.RemoteAddr()) + + b := NewPgFortuneBackend(conn, func() ([]byte, error) { + return exec.Command("sh", "-c", options.responseCommand).CombinedOutput() + }) + go func() { + err := b.Run() + if err != nil { + log.Println(err) + } + log.Println("Closed connection from", conn.RemoteAddr()) + }() + } +} diff --git a/pkg/x/pgproxy/server.go b/pkg/x/pgproxy/server.go new file mode 100644 index 0000000..e1d1a81 --- /dev/null +++ b/pkg/x/pgproxy/server.go @@ -0,0 +1,113 @@ +package pgproxy + +import ( + "fmt" + "net" + + "github.com/jackc/pgx/v5/pgproto3" +) + +type PgFortuneBackend struct { + backend *pgproto3.Backend + conn net.Conn + responder func() ([]byte, error) +} + +func NewPgFortuneBackend(conn net.Conn, responder func() ([]byte, error)) *PgFortuneBackend { + backend := pgproto3.NewBackend(conn, conn) + + connHandler := &PgFortuneBackend{ + backend: backend, + conn: conn, + responder: responder, + } + + return connHandler +} + +func (p *PgFortuneBackend) Run() error { + defer p.Close() + + err := p.handleStartup() + if err != nil { + return err + } + + for { + msg, err := p.backend.Receive() + if err != nil { + return fmt.Errorf("error receiving message: %w", err) + } + + switch msg.(type) { + case *pgproto3.Query: + response, err := p.responder() + if err != nil { + return fmt.Errorf("error generating query response: %w", err) + } + + buf := mustEncode((&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + { + Name: []byte("fortune"), + TableOID: 0, + TableAttributeNumber: 0, + DataTypeOID: 25, + DataTypeSize: -1, + TypeModifier: -1, + Format: 0, + }, + }}).Encode(nil)) + buf = mustEncode((&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf)) + buf = mustEncode((&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)) + buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)) + _, err = p.conn.Write(buf) + if err != nil { + return fmt.Errorf("error writing query response: %w", err) + } + case *pgproto3.Terminate: + return nil + default: + return fmt.Errorf("received message other than Query from client: %#v", msg) + } + } +} + +func (p *PgFortuneBackend) handleStartup() error { + startupMessage, err := p.backend.ReceiveStartupMessage() + if err != nil { + return fmt.Errorf("error receiving startup message: %w", err) + } + + fmt.Printf("Received startup message: %#v\n", startupMessage) + + switch startupMessage.(type) { + case *pgproto3.StartupMessage: + buf := mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)) + buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)) + _, err = p.conn.Write(buf) + if err != nil { + return fmt.Errorf("error sending ready for query: %w", err) + } + case *pgproto3.SSLRequest: + _, err = p.conn.Write([]byte("N")) + if err != nil { + return fmt.Errorf("error sending deny SSL request: %w", err) + } + return p.handleStartup() + default: + return fmt.Errorf("unknown startup message: %#v", startupMessage) + } + + return nil +} + +func (p *PgFortuneBackend) Close() error { + return p.conn.Close() +} + +func mustEncode(buf []byte, err error) []byte { + if err != nil { + panic(err) + } + return buf +} diff --git a/postgres.go b/postgres.go new file mode 100644 index 0000000..8475ee8 --- /dev/null +++ b/postgres.go @@ -0,0 +1,350 @@ +package pgo + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "reflect" + "strings" + "time" + + "github.com/edgeflare/pgxutil" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +var ErrTooManyRows = fmt.Errorf("too many rows") + +// Conn retrieves the OIDC user and a pgxpool.Conn from the request context. +// It returns an error if the user or connection is not found in the context. +// Currently it only supports OIDC users. But the authZ middleware chain works, and error occurs here. +func Conn(r *http.Request) (*oidc.IntrospectionResponse, *pgxpool.Conn, *pgconn.PgError) { + // TODO: Add support for Basic Auth + // basicAuthUser := r.Context().Value(pgo.BasicAuthCtxKey).(string) + user, ok := OIDCUser(r) + if !ok { + return nil, nil, &pgconn.PgError{ + Code: "28000", // SQLSTATE for invalid authorization specification + Message: "User not found in context", + } + } + + conn, ok := r.Context().Value(PgConnCtxKey).(*pgxpool.Conn) + if !ok || conn == nil { + return nil, nil, &pgconn.PgError{ + Code: "08003", // SQLSTATE for connection does not exist + Message: "Failed to get connection from context", + } + } + + return user, conn, nil +} + +// ConnWithRole returns the OIDC user and a pgxpool.Conn from the request context. +func ConnWithRole(r *http.Request) (*oidc.IntrospectionResponse, *pgxpool.Conn, *pgconn.PgError) { + user, conn, pgErr := Conn(r) + if pgErr != nil { + return nil, nil, pgErr + } + + role, ok := r.Context().Value(PgRoleCtxKey).(string) + if !ok { + return nil, nil, &pgconn.PgError{ + Code: "28000", + Message: "Role not found in context", + } + } + + claimsJSON, err := json.Marshal(user.Claims) + if err != nil { + return nil, nil, &pgconn.PgError{ + Code: "28000", + Message: fmt.Sprintf("Failed to marshal claims: %v", err), + } + } + escapedClaimsJSON := strings.ReplaceAll(string(claimsJSON), "'", "''") + + setRoleQuery := fmt.Sprintf("SET ROLE %s;", role) + setReqClaimsQuery := fmt.Sprintf("SET request.oidc.claims TO '%s';", escapedClaimsJSON) + combinedQuery := setRoleQuery + setReqClaimsQuery + + _, execErr := conn.Exec(context.Background(), combinedQuery) + if execErr != nil { + conn.Release() + if pgErr, ok := execErr.(*pgconn.PgError); ok { + return nil, nil, pgErr + } + return nil, nil, &pgconn.PgError{ + Code: "P0000", // Generic SQLSTATE code + Message: "Failed to set role and claims", + } + } + + return user, conn, nil +} + +// Select executes a SELECT query and returns the results as a slice of T +func Select[T any](r *http.Request, query string, args []any, scanFn pgx.RowToFunc[T], keepConn ...bool) ([]T, *pgconn.PgError) { + _, conn, connErr := ConnWithRole(r) + if connErr != nil { + return nil, connErr + } + + if len(keepConn) == 0 || !keepConn[0] { + defer conn.Release() + } + + rows, err := pgxutil.Select(r.Context(), conn, query, args, scanFn) + if err != nil { + if pgErr, ok := err.(*pgconn.PgError); ok { + return nil, pgErr + } + return nil, &pgconn.PgError{ + Code: "P0000", + Message: "Failed to collect rows", + } + } + + return rows, nil +} + +// SelectAndRespondJSON executes a SELECT query and responds with the results as JSON +func SelectAndRespondJSON[T any](w http.ResponseWriter, r *http.Request, query string, args []any, scanFn pgx.RowToFunc[T]) { + results, pgErr := Select[T](r, query, args, scanFn) + if pgErr != nil { + RespondError(w, PgErrorCodeToHTTPStatus(pgErr.Code), pgErr.Message) + return + } + RespondJSON(w, http.StatusOK, results) +} + +// SelectRow executes sql with args on db and returns the T produced by scanFn. The query should return one row. If no +// rows are found returns an error where errors.Is(pgx.ErrNoRows) is true. Returns an error if more than one row is returned. +func SelectRow[T any](r *http.Request, query string, args []any, scanFn pgx.RowToFunc[T], keepConn ...bool) (*T, *pgconn.PgError) { + _, conn, err := ConnWithRole(r) + if err != nil { + return nil, err + } + + if len(keepConn) == 0 || !keepConn[0] { + defer conn.Release() + } + + row, pgErr := pgxutil.SelectRow(r.Context(), conn, query, args, scanFn) + if pgErr != nil { + if pgErr, ok := pgErr.(*pgconn.PgError); ok { + return nil, pgErr + } + return nil, &pgconn.PgError{ + Code: "P0000", // Generic SQLSTATE code + Message: "Failed to collect rows", + } + } + return &row, nil +} + +// SelectRowAndRespondJSON executes a SELECT query and responds with the results as JSON +// It responds with an error if the query or parsing the results fails. +func SelectRowAndRespondJSON[T any](w http.ResponseWriter, r *http.Request, query string, args []any, scanFn pgx.RowToFunc[T]) { + result, err := SelectRow(r, query, args, scanFn) + if err != nil { + if err == ErrTooManyRows { + RespondError(w, http.StatusInternalServerError, "too many rows returned") + return + } + RespondError(w, http.StatusInternalServerError, err.Error()) + return + } + RespondJSON(w, http.StatusOK, result) +} + +// InsertRow inserts a row into the specified table +func InsertRow(r *http.Request, tableName any, row map[string]any, keepConn ...bool) (pgconn.CommandTag, *pgconn.PgError) { + _, conn, pgErr := ConnWithRole(r) + if pgErr != nil { + return pgconn.CommandTag{}, pgErr + } + + if len(keepConn) == 0 || !keepConn[0] { + defer conn.Release() + } + + err := pgxutil.InsertRow(r.Context(), conn, tableName, row) + if err != nil { + if pgErr, ok := err.(*pgconn.PgError); ok { + return pgconn.CommandTag{}, pgErr + } + return pgconn.CommandTag{}, &pgconn.PgError{ + Code: "P0000", // Generic SQLSTATE code + // Message: "Failed to execute query", + Message: err.Error(), + } + } + + return pgconn.CommandTag{}, nil +} + +// InsertRowAndRespondJSON inserts a row into the specified table and responds with the results as JSON +func InsertRowAndRespondJSON(w http.ResponseWriter, r *http.Request, tableName any, row map[string]any) { + cmdTag, pgErr := InsertRow(r, tableName, row) + if pgErr != nil { + RespondError(w, PgErrorCodeToHTTPStatus(pgErr.Code), pgErr.Message) + return + } + RespondJSON(w, http.StatusCreated, cmdTag) +} + +// UpdateRow updates a row in the specified table +func Update(r *http.Request, tableName any, row map[string]any, setValues, whereValues map[string]any, keepConn ...bool) (pgconn.CommandTag, *pgconn.PgError) { + _, conn, pgErr := ConnWithRole(r) + if pgErr != nil { + return pgconn.CommandTag{}, pgErr + } + + if len(keepConn) == 0 || !keepConn[0] { + defer conn.Release() + } + + cmdTag, err := pgxutil.Update(r.Context(), conn, tableName, setValues, whereValues) + if err != nil { + if pgErr, ok := err.(*pgconn.PgError); ok { + return pgconn.CommandTag{}, pgErr + } + return pgconn.CommandTag{}, &pgconn.PgError{ + Code: "P0000", // Generic SQLSTATE code + // Message: "Failed to execute query", + Message: err.Error(), + } + } + return cmdTag, nil +} + +// UpdateAndRespondJSON updates a row in the specified table and responds with the results as JSON +func UpdateAndRespondJSON(w http.ResponseWriter, r *http.Request, tableName any, row map[string]any, setValues, whereValues map[string]any) { + ct, pgErr := Update(r, tableName, row, setValues, whereValues) + if pgErr != nil { + RespondError(w, PgErrorCodeToHTTPStatus(pgErr.Code), pgErr.Message) + return + } + RespondJSON(w, http.StatusOK, ct) +} + +// ExecRow executes SQL with args. It returns an error unless exactly one row is affected. +func ExecRow(r *http.Request, sql string, args []any, keepConn ...bool) (pgconn.CommandTag, error) { + _, conn, pgErr := ConnWithRole(r) + if pgErr != nil { + return pgconn.CommandTag{}, pgErr + } + if len(keepConn) == 0 || !keepConn[0] { + defer conn.Release() + } + + ct, err := pgxutil.ExecRow(r.Context(), conn, sql, args...) + if err != nil { + fmt.Println(err) + return ct, err + } + return ct, nil +} + +// ExecRowAndRespondJSON executes SQL with args and responds with the results as JSON +func ExecRowAndRespondJSON(w http.ResponseWriter, r *http.Request, sql string, args ...any) { + ct, err := ExecRow(r, sql, args) + if err != nil { + RespondError(w, http.StatusInternalServerError, err.Error()) + return + } + RespondJSON(w, http.StatusOK, ct) +} + +// RowMap takes a struct and returns a map with non-zero fields +func RowMap(v interface{}) map[string]interface{} { + result := make(map[string]interface{}) + rv := reflect.ValueOf(v) + rt := reflect.TypeOf(v) + + // Ensure the input is a struct + if rv.Kind() != reflect.Struct { + panic("RowMap: expected a struct") + } + + for i := 0; i < rt.NumField(); i++ { + field := rt.Field(i) + value := rv.Field(i) + + // Skip unexported fields + if field.PkgPath != "" { + continue + } + + // Check if the field has a JSON tag + jsonTag := field.Tag.Get("json") + if jsonTag == "" { + jsonTag = field.Name + } else { + jsonTag = strings.Split(jsonTag, ",")[0] + } + + // Check if the value is zero + if isZero(value) { + continue + } + + // Add to result map + result[jsonTag] = value.Interface() + } + + return result +} + +// Checks if a reflect.Value is zero +func isZero(v reflect.Value) bool { + switch v.Kind() { + case reflect.String: + return v.String() == "" + case reflect.Slice, reflect.Map, reflect.Chan: + return v.Len() == 0 + case reflect.Struct: + // Handle time.Time zero value + if v.Type() == reflect.TypeOf(time.Time{}) { + return v.Interface().(time.Time).IsZero() + } + // Handle uuid.UUID zero value + if v.Type() == reflect.TypeOf(uuid.UUID{}) { + return v.Interface().(uuid.UUID) == uuid.Nil + } + // Handle pgtype.CIDR zero value + if v.Type() == reflect.TypeOf(pgtype.InetCodec{}) { + cidr := v.Interface().(pgtype.Type) + return cidr.OID != pgtype.InetOID + } + case reflect.Interface, reflect.Ptr: + return v.IsNil() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Bool: + return !v.Bool() + } + return false +} + +// MapPgErrorCodeToHTTPStatus maps PostgreSQL error codes to HTTP status codes +func PgErrorCodeToHTTPStatus(code string) int { + switch code { + case "28000": // Invalid authorization specification + return http.StatusUnauthorized + case "08003": // Connection does not exist + return http.StatusInternalServerError + default: + return http.StatusInternalServerError + } +} diff --git a/router.go b/router.go new file mode 100644 index 0000000..6cd4479 --- /dev/null +++ b/router.go @@ -0,0 +1,172 @@ +package pgo + +import ( + "context" + "crypto/tls" + "fmt" + + "log" + "net/http" + "strings" + "sync" + + "github.com/edgeflare/pgo/pkg/util" +) + +// Middleware defines a function type that represents a middleware. Middleware functions wrap an +// http.Handler to modify or enhance its behavior. +type Middleware func(http.Handler) http.Handler + +// RouterOptions is a function type that represents options to configure a Router. +type RouterOptions func(*Router) + +// Router is the main structure for handling HTTP routing and middleware. +type Router struct { + mux *http.ServeMux + middleware []Middleware + server *http.Server + prefix string + mu sync.RWMutex // Mutex for concurrency safety + +} + +// NewRouter creates a new instance of Router with the given options. +func NewRouter(opts ...RouterOptions) *Router { + r := &Router{ + mux: http.NewServeMux(), + server: &http.Server{}, // Initialize with default server + } + for _, opt := range opts { + opt(r) + } + return r +} + +// WithServerOptions returns a RouterOptions function that sets custom http.Server options. +func WithServerOptions(opts ...func(*http.Server)) RouterOptions { + return func(r *Router) { + for _, opt := range opts { + opt(r.server) + } + } +} + +// WithTLS provides a simplified way to enable HTTPS in your router. +func WithTLS(certFile, keyFile string) RouterOptions { + return func(r *Router) { + r.server.TLSConfig = &tls.Config{} // Initialize TLS config + r.server.TLSConfig.MinVersion = tls.VersionTLS12 // Enforce secure TLS version (optional) + + var cert tls.Certificate + var err error + + if certFile == "" || keyFile == "" { + // Generate a self-signed certificate if paths are not provided + cert, err = util.LoadOrGenerateCert("./tls/tls.crt", "./tls/tls.key") + if err != nil { + log.Fatalf("failed to generate self-signed certificate: %v", err) + } + } else { + // Load certificate from provided paths + cert, err = util.LoadOrGenerateCert(certFile, keyFile) + if err != nil { + log.Fatalf("error loading TLS certificates: %v", err) + } + } + + r.server.TLSConfig.Certificates = []tls.Certificate{cert} + } +} + +// Use adds middleware to the router. Middleware functions are applied in the order they are added. +func (r *Router) Use(mw Middleware) { + r.mu.Lock() + defer r.mu.Unlock() + r.middleware = append(r.middleware, mw) +} + +// Group creates a new sub-router with a specified prefix. The sub-router inherits the middleware +// from its parent router. +func (r *Router) Group(prefix string) *Router { + r.mu.RLock() + defer r.mu.RUnlock() + return &Router{ + mux: r.mux, + middleware: append([]Middleware{}, r.middleware...), + server: r.server, + prefix: r.prefix + prefix, + } +} + +// Handle registers an HTTP handler function for a given method and pattern as introduced in +// [Routing Enhancements for Go 1.22](https://go.dev/blog/routing-enhancements) +// The handler `METHOD /pattern` on a route group with a /prefix resolves to `METHOD /prefix/pattern` +func (r *Router) Handle(methodPattern string, handler http.Handler) { + parts := strings.SplitN(methodPattern, " ", 2) + if len(parts) != 2 { + log.Fatalf("invalid method pattern: %s", methodPattern) + } + method, pattern := parts[0], parts[1] + + r.mu.RLock() + defer r.mu.RUnlock() + + // Create the final handler with all middleware applied + finalHandler := handler + for i := len(r.middleware) - 1; i >= 0; i-- { + finalHandler = r.middleware[i](finalHandler) + } + // fullPattern := r.prefix + pattern + fullPattern := fmt.Sprintf("%s %s%s", method, r.prefix, pattern) + + r.mux.Handle(fullPattern, finalHandler) +} + +// ListenAndServe starts the server, automatically choosing between HTTP and HTTPS based on TLS config. +func (r *Router) ListenAndServe(addr string) error { + fmt.Print(colorGreen + pgoAsciiArt + colorReset) + fmt.Printf("starting server on %s\n", addr) + + r.server.Addr = addr + r.server.Handler = r.applyMiddleware() + + if r.server.TLSConfig != nil { + // HTTPS + return r.server.ListenAndServeTLS("", "") // Use empty strings to auto-detect cert/key in TLSConfig + } + // HTTP + return r.server.ListenAndServe() +} + +// Shutdown gracefully shuts down the HTTP server. +func (r *Router) Shutdown(ctx context.Context) error { + log.Println("shutting down server") + return r.server.Shutdown(ctx) +} + +// applyMiddleware applies middleware to the http.Handler and returns a new http.Handler. +func (r *Router) applyMiddleware() http.Handler { + r.mu.RLock() + defer r.mu.RUnlock() + + var handler http.Handler = r.mux + for i := len(r.middleware) - 1; i >= 0; i-- { + handler = r.middleware[i](handler) + } + return handler +} + +// Constants for ASCII art and console colors +const ( + colorRed = "\033[31m" + colorGreen = "\033[32m" + colorReset = "\033[0m" + pgoAsciiArt = ` + _ __ __ _ ___ +| '_ \ / _' |/ _ \ +| |_) | (_| | (_) | +| .__/ \__, |\___/ +|_| |___/ + +` +) diff --git a/router_test.go b/router_test.go new file mode 100644 index 0000000..f96addf --- /dev/null +++ b/router_test.go @@ -0,0 +1,264 @@ +package pgo + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" +) + +// TestNewRouter tests the creation of a new Router +func TestNewRouter(t *testing.T) { + r := NewRouter() + if r == nil { + t.Fatal("expected router to be non-nil") + } +} + +// TestRouterHandle tests route registration and handling +func TestRouterHandle(t *testing.T) { + r := NewRouter() + r.Handle("GET /test", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + r.mux.ServeHTTP(w, req) + + if w.Result().StatusCode != http.StatusOK { + t.Errorf("expected status OK, got %v", w.Result().StatusCode) + } +} + +// TestRouterMiddleware tests adding and using middleware +func TestRouterMiddleware(t *testing.T) { + r := NewRouter() + + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("X-Test", "true") + next.ServeHTTP(w, req) + }) + }) + + r.Handle("GET /test", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + r.mux.ServeHTTP(w, req) + + if w.Header().Get("X-Test") != "true" { + t.Errorf("expected X-Test header to be set") + } +} + +// TestRouterGroup tests sub-router grouping +func TestRouterGroup(t *testing.T) { + r := NewRouter() + api := r.Group("/api") + api.Handle("GET /v1/test", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/api/v1/test", nil) + w := httptest.NewRecorder() + r.mux.ServeHTTP(w, req) + + if w.Result().StatusCode != http.StatusOK { + t.Errorf("expected status OK, got %v", w.Result().StatusCode) + } +} + +// TestRouterListenAndServe tests server start and shutdown +func TestRouterListenAndServe(t *testing.T) { + r := NewRouter() + r.Handle("GET /test", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + serverAddr := ":8081" + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if err := r.ListenAndServe(serverAddr); err != http.ErrServerClosed { + t.Logf("expected server to close, got %v", err) + } + }() + + time.Sleep(100 * time.Millisecond) // Give the server a moment to start + + req, err := http.NewRequest("GET", "http://localhost:8081/test", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("failed to send request: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status OK, got %v", resp.StatusCode) + } + + if err := r.Shutdown(ctx); err != nil { + t.Fatalf("failed to shutdown server: %v", err) + } + + wg.Wait() +} + +// TestRouterTLS tests the TLS configuration of the router. +// func TestRouterTLS(t *testing.T) { +// certPath := "./test_tls/tls.crt" +// keyPath := "./test_tls/tls.key" +// _, err := util.LoadOrGenerateCert(certPath, keyPath) +// require.NoError(t, err) +// defer func() { +// os.Remove(certPath) +// os.Remove(keyPath) +// os.RemoveAll(filepath.Dir(certPath)) +// }() + +// r := NewRouter(WithTLS(certPath, keyPath)) + +// handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// w.WriteHeader(http.StatusOK) +// }) + +// r.Handle("/test", handler) + +// // Spin up a real HTTP server for TLS testing +// go func() { +// err := r.ListenAndServe(":8443") +// require.NoError(t, err) +// }() + +// // Test HTTPS request +// req := httptest.NewRequest("GET", "https://localhost:8443/test", nil) +// req.URL.Scheme = "https" +// rec := httptest.NewRecorder() +// http.DefaultTransport.RoundTrip(req) // Use actual transport to avoid the `httptest` limitations + +// res := rec.Result() +// require.Equal(t, http.StatusOK, res.StatusCode) +// } + +// func TestListenAndServe_HTTPS(t *testing.T) { +// // Generate self-signed certificate for testing +// tempDir := t.TempDir() +// certFile := filepath.Join(tempDir, "tls.crt") +// keyFile := filepath.Join(tempDir, "tls.key") +// _, err := util.LoadOrGenerateCert(certFile, keyFile) +// require.NoError(t, err) +// defer func() { +// os.Remove(certFile) +// os.Remove(keyFile) +// }() + +// r := NewRouter(WithTLS(certFile, keyFile)) +// r.Handle("GET /test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// fmt.Fprint(w, "Hello, secure world!") +// })) + +// // Start HTTPS server +// go func() { +// if err := r.ListenAndServe(":0"); err != nil && err != http.ErrServerClosed { +// t.Errorf("ListenAndServe() error = %v", err) +// } +// }() + +// // Wait for the server to start +// time.Sleep(100 * time.Millisecond) + +// // Make HTTPS request (ignoring certificate errors for self-signed cert) +// tr := &http.Transport{ +// TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // Insecure for testing only +// } +// client := &http.Client{Transport: tr} +// resp, err := client.Get("https://" + r.server.Addr + "/test") +// require.NoError(t, err, "Failed to make HTTPS request") +// defer resp.Body.Close() + +// body, err := io.ReadAll(resp.Body) +// require.NoError(t, err, "Failed to read response body") + +// assert.Equal(t, http.StatusOK, resp.StatusCode, "Expected status code 200") +// assert.Equal(t, "Hello, secure world!", string(body), "Expected response body 'Hello, secure world!'") + +// // Test Shutdown +// ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +// defer cancel() +// err = r.Shutdown(ctx) +// assert.NoError(t, err, "Expected graceful shutdown") +// } + +// BenchmarkRouterHandle benchmarks route registration +func BenchmarkRouterHandle(b *testing.B) { + r := NewRouter() + for i := 0; i < b.N; i++ { + r.Handle("GET /test"+fmt.Sprintf("%d", i), http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + })) + } +} + +// BenchmarkRouterServeHTTP benchmarks serving HTTP requests +func BenchmarkRouterServeHTTP(b *testing.B) { + r := NewRouter() + r.Handle("GET /test", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.mux.ServeHTTP(w, req) + } +} + +// BenchmarkRouterHandleWithMiddleware benchmarks route registration with middleware +func BenchmarkRouterHandleWithMiddleware(b *testing.B) { + r := NewRouter() + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // Example middleware + next.ServeHTTP(w, req) + }) + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.Handle("GET /test"+fmt.Sprintf("%d", i), http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + })) + } +} + +// BenchmarkRouterServeHTTPConcurrent benchmarks serving HTTP requests concurrently +func BenchmarkRouterServeHTTPConcurrent(b *testing.B) { + r := NewRouter() + r.Handle("GET /test", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + r.mux.ServeHTTP(w, req) + } + }) +}