From bad12fa97f9119f9a8c0cd1992538c5570bd062c Mon Sep 17 00:00:00 2001 From: potterbm-cb <135353141+potterbm-cb@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:31:30 -0400 Subject: [PATCH] feat(header forwarding): add HeaderForwarder and plumbing (#507) * add HeaderForwarder and plumbing * PR feedback * make add-license * make lint * update templates * codegen from templates * dedupe template function * fix linting * fix tests --- Makefile | 4 +- asserter/asserter.go | 2 +- constructor/worker/worker_test.go | 2 +- examples/server/main.go | 2 + fetcher/block.go | 2 +- go.mod | 1 + go.sum | 2 + headerforwarder/context_headers.go | 43 +++++++ headerforwarder/forwarder.go | 142 +++++++++++++++++++++++ headerforwarder/response_writer.go | 68 +++++++++++ server/api_account.go | 26 ++++- server/api_block.go | 29 ++++- server/api_call.go | 24 +++- server/api_construction.go | 59 ++++++++-- server/api_events.go | 24 +++- server/api_mempool.go | 29 ++++- server/api_network.go | 28 +++-- server/api_search.go | 27 ++++- server/routers.go | 2 + templates/server/controller-api.mustache | 22 +++- templates/server/routers.mustache | 2 + 21 files changed, 479 insertions(+), 61 deletions(-) create mode 100644 headerforwarder/context_headers.go create mode 100644 headerforwarder/forwarder.go create mode 100644 headerforwarder/response_writer.go diff --git a/Makefile b/Makefile index 5d4c626c2..9b79c0f2d 100644 --- a/Makefile +++ b/Makefile @@ -24,11 +24,11 @@ GO_MOD_PACKAGES=./types/... GO_FOLDERS=$(shell echo ${GO_PACKAGES} | sed -e "s/\.\///g" | sed -e "s/\/\.\.\.//g") GO_MOD_FOLDERS=$(shell echo ${GO_MOD_PACKAGES} | sed -e "s/\.\///g" | sed -e "s/\/\.\.\.//g") TEST_SCRIPT=go test ${GO_PACKAGES} -LINT_SETTINGS=golint,misspell,gocyclo,gocritic,whitespace,goconst,gocognit,bodyclose,unconvert,lll,unparam +LINT_SETTINGS=misspell,gocyclo,gocritic,whitespace,goconst,gocognit,bodyclose,unconvert,lll,unparam build: go build ./... - + deps: go get ./... diff --git a/asserter/asserter.go b/asserter/asserter.go index 41c5d18c5..fb5d0c285 100644 --- a/asserter/asserter.go +++ b/asserter/asserter.go @@ -374,7 +374,7 @@ func NewGenericRosettaClient( ignoreRosettaSpecValidation: true, } - //init default operation statuses for generic rosetta client + // init default operation statuses for generic rosetta client InitOperationStatus(asserter) return asserter, nil diff --git a/constructor/worker/worker_test.go b/constructor/worker/worker_test.go index e4d0c98bf..12952b03e 100644 --- a/constructor/worker/worker_test.go +++ b/constructor/worker/worker_test.go @@ -1848,7 +1848,7 @@ func TestHTTPRequestWorker(t *testing.T) { w.Header().Set("Content-Type", test.contentType) w.WriteHeader(test.statusCode) - fmt.Fprintf(w, test.response) + fmt.Fprint(w, test.response) })) defer ts.Close() diff --git a/examples/server/main.go b/examples/server/main.go index 6c41b033d..ba2cdaa23 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -40,12 +40,14 @@ func NewBlockchainRouter( networkAPIController := server.NewNetworkAPIController( networkAPIService, asserter, + nil, ) blockAPIService := services.NewBlockAPIService(network) blockAPIController := server.NewBlockAPIController( blockAPIService, asserter, + nil, ) return server.NewRouter(networkAPIController, blockAPIController) diff --git a/fetcher/block.go b/fetcher/block.go index fadce34ab..7b028ac00 100644 --- a/fetcher/block.go +++ b/fetcher/block.go @@ -254,7 +254,7 @@ func (f *Fetcher) UnsafeBlock( } // Exit early if no need to fetch txs - if blockResponse.OtherTransactions == nil || len(blockResponse.OtherTransactions) == 0 { + if len(blockResponse.OtherTransactions) == 0 { return blockResponse.Block, nil } diff --git a/go.mod b/go.mod index cb73cdfe4..3c4a6bec0 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/dgraph-io/badger/v2 v2.2007.4 github.com/ethereum/go-ethereum v1.10.21 github.com/fatih/color v1.13.0 + github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.0 github.com/lucasjones/reggen v0.0.0-20180717132126-cdb49ff09d77 github.com/neilotoole/errgroup v0.1.6 diff --git a/go.sum b/go.sum index e6ae8c979..947d5e32b 100644 --- a/go.sum +++ b/go.sum @@ -71,6 +71,8 @@ github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/headerforwarder/context_headers.go b/headerforwarder/context_headers.go new file mode 100644 index 000000000..a4c1b90d3 --- /dev/null +++ b/headerforwarder/context_headers.go @@ -0,0 +1,43 @@ +// Copyright 2024 Coinbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package headerforwarder + +import ( + "context" + "net/http" + + "github.com/google/uuid" +) + +type contextKey string + +const requestIDKey = contextKey("request_id") + +func ContextWithRosettaID(ctx context.Context) context.Context { + return context.WithValue(ctx, requestIDKey, uuid.NewString()) +} + +func RosettaIDFromContext(ctx context.Context) string { + return ctx.Value(requestIDKey).(string) +} + +func RosettaIDFromRequest(r *http.Request) string { + switch value := r.Context().Value(requestIDKey).(type) { + case string: + return value + default: + return "" + } +} diff --git a/headerforwarder/forwarder.go b/headerforwarder/forwarder.go new file mode 100644 index 000000000..85cb78146 --- /dev/null +++ b/headerforwarder/forwarder.go @@ -0,0 +1,142 @@ +// Copyright 2024 Coinbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package headerforwarder + +import ( + "net/http" +) + +// HeaderExtractingTransport is a utility to help a rosetta server forward headers to and from +// native node requests. It implements several interfaces to achieve that: +// - http.RoundTripper: this can be used to create an http Client that will automatically save headers +// if necessary +// - func(http.Handler) http.Handler: this can be used to wrap an http.Handler to set headers +// on the response +// +// the headers can be requested later. +// +// TODO: this should expire entries after a certain amount of time +type HeaderForwarder struct { + requestHeaders map[string]http.Header + interestingHeaders []string + actualTransport http.RoundTripper +} + +func NewHeaderForwarder(interestingHeaders []string, transport http.RoundTripper) *HeaderForwarder { + return &HeaderForwarder{ + requestHeaders: make(map[string]http.Header), + interestingHeaders: interestingHeaders, + actualTransport: transport, + } +} + +// RoundTrip implements http.RoundTripper and will be used to construct an http Client which +// saves the native node response headers if necessary. +func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := hf.actualTransport.RoundTrip(req) + + if err == nil && hf.shouldRememberHeaders(req, resp) { + hf.rememberHeaders(req, resp) + } + + return resp, err +} + +// shouldRememberHeaders is called to determine if response headers should be remembered for a +// given request. Response headers will only be remembered if the request does not contain all of +// the interesting headers and the response contains at least one of the interesting headers. +// +// It should be noted that the request and response here are for a request to the native node, +// not a request to the Rosetta server. +func (hf *HeaderForwarder) shouldRememberHeaders(req *http.Request, resp *http.Response) bool { + requestHasAllHeaders := true + responseHasSomeHeaders := false + + for _, interestingHeader := range hf.interestingHeaders { + _, requestHasHeader := req.Header[http.CanonicalHeaderKey(interestingHeader)] + _, responseHasHeader := resp.Header[http.CanonicalHeaderKey(interestingHeader)] + + if !requestHasHeader { + requestHasAllHeaders = false + } + + if responseHasHeader { + responseHasSomeHeaders = true + } + } + + // only remember headers if the request does not contain all of the interesting headers and the + // response contains at least one + return !requestHasAllHeaders && responseHasSomeHeaders +} + +// rememberHeaders is called to save the native node response headers. The request object +// here is a native node request (constructed by go-ethereum for geth-based rosetta implementations). +// The response object is a native node response. +func (hf *HeaderForwarder) rememberHeaders(req *http.Request, resp *http.Response) { + ctx := req.Context() + // rosettaRequestID := services.osettaIdFromContext(ctx) + rosettaRequestID := RosettaIDFromContext(ctx) + + // Only remember interesting headers + headersToRemember := make(http.Header) + for _, interestingHeader := range hf.interestingHeaders { + headersToRemember.Set(interestingHeader, resp.Header.Get(interestingHeader)) + } + + hf.requestHeaders[rosettaRequestID] = headersToRemember +} + +// GetResponseHeaders returns any native node response headers that were recorded for a request ID. +func (hf *HeaderForwarder) getResponseHeaders(rosettaRequestID string) (http.Header, bool) { + headers, ok := hf.requestHeaders[rosettaRequestID] + + // Delete the headers from the map after they are retrieved + // This is safe to call even if the key doesn't exist + delete(hf.requestHeaders, rosettaRequestID) + + return headers, ok +} + +// HeaderForwarderHandler will allow the next handler to serve the request, and then checks +// if there are any native node response headers recorded for the request. If there are, it will set +// those headers on the response +func (hf *HeaderForwarder) HeaderForwarderHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // add a unique ID to the request context, and make a new request for it + requestWithID := hf.WithRequestID(r) + + // Serve the request + // NOTE: ResponseWriter::WriteHeader() WILL be called here, so we can't set headers after this happens + // We include a wrapper around the response writer that allows us to set headers just before + // WriteHeader is called + wrappedResponseWriter := NewResponseWriter( + w, + RosettaIDFromRequest(requestWithID), + hf.getResponseHeaders, + ) + next.ServeHTTP(wrappedResponseWriter, requestWithID) + }) +} + +// WithRequestID adds a unique ID to the request context. A new request is returned that contains the +// new context +func (hf *HeaderForwarder) WithRequestID(req *http.Request) *http.Request { + ctx := req.Context() + ctxWithID := ContextWithRosettaID(ctx) + requestWithID := req.WithContext(ctxWithID) + + return requestWithID +} diff --git a/headerforwarder/response_writer.go b/headerforwarder/response_writer.go new file mode 100644 index 000000000..c8ce1f394 --- /dev/null +++ b/headerforwarder/response_writer.go @@ -0,0 +1,68 @@ +// Copyright 2024 Coinbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package headerforwarder + +import ( + "net/http" +) + +// ResponseWriter is a wrapper around a http.ResponseWriter that allows us to set headers +// just before the WriteHeader function is called. These headers will be extracted from native node +// responses, and set on the rosetta response. +type ResponseWriter struct { + writer http.ResponseWriter + RosettaRequestID string + GetAdditionalHeaders func(string) (http.Header, bool) +} + +func NewResponseWriter( + writer http.ResponseWriter, + rosettaRequestID string, + getAdditionalHeaders func(string) (http.Header, bool), +) *ResponseWriter { + return &ResponseWriter{ + writer: writer, + RosettaRequestID: rosettaRequestID, + GetAdditionalHeaders: getAdditionalHeaders, + } +} + +// Header passes through to the underlying ResponseWriter instance +func (hfrw *ResponseWriter) Header() http.Header { + return hfrw.writer.Header() +} + +// Write passes through to the underlying ResponseWriter instance +func (hfrw *ResponseWriter) Write(b []byte) (int, error) { + return hfrw.writer.Write(b) +} + +// WriteHeader will add any final extracted headers, and then pass through to the underlying ResponseWriter instance +func (hfrw *ResponseWriter) WriteHeader(statusCode int) { + hfrw.AddExtractedHeaders() + hfrw.writer.WriteHeader(statusCode) +} + +func (hfrw *ResponseWriter) AddExtractedHeaders() { + headers, hasAdditionalHeaders := hfrw.GetAdditionalHeaders(hfrw.RosettaRequestID) + + if hasAdditionalHeaders { + for key, values := range headers { + for _, value := range values { + hfrw.writer.Header().Add(key, value) + } + } + } +} diff --git a/server/api_account.go b/server/api_account.go index af1e881aa..a98d45658 100644 --- a/server/api_account.go +++ b/server/api_account.go @@ -17,6 +17,7 @@ package server import ( + "context" "encoding/json" "net/http" "strings" @@ -28,18 +29,21 @@ import ( // A AccountAPIController binds http requests to an api service and writes the service results to // the http response type AccountAPIController struct { - service AccountAPIServicer - asserter *asserter.Asserter + service AccountAPIServicer + asserter *asserter.Asserter + contextFromRequest func(*http.Request) context.Context } // NewAccountAPIController creates a default api controller func NewAccountAPIController( s AccountAPIServicer, asserter *asserter.Asserter, + contextFromRequest func(*http.Request) context.Context, ) Router { return &AccountAPIController{ - service: s, - asserter: asserter, + service: s, + asserter: asserter, + contextFromRequest: contextFromRequest, } } @@ -61,6 +65,16 @@ func (c *AccountAPIController) Routes() Routes { } } +func (c *AccountAPIController) ContextFromRequest(r *http.Request) context.Context { + ctx := r.Context() + + if c.contextFromRequest != nil { + ctx = c.contextFromRequest(r) + } + + return ctx +} + // AccountBalance - Get an Account's Balance func (c *AccountAPIController) AccountBalance(w http.ResponseWriter, r *http.Request) { accountBalanceRequest := &types.AccountBalanceRequest{} @@ -81,7 +95,7 @@ func (c *AccountAPIController) AccountBalance(w http.ResponseWriter, r *http.Req return } - result, serviceErr := c.service.AccountBalance(r.Context(), accountBalanceRequest) + result, serviceErr := c.service.AccountBalance(c.ContextFromRequest(r), accountBalanceRequest) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) @@ -111,7 +125,7 @@ func (c *AccountAPIController) AccountCoins(w http.ResponseWriter, r *http.Reque return } - result, serviceErr := c.service.AccountCoins(r.Context(), accountCoinsRequest) + result, serviceErr := c.service.AccountCoins(c.ContextFromRequest(r), accountCoinsRequest) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) diff --git a/server/api_block.go b/server/api_block.go index a3fe4cf63..a3d22732d 100644 --- a/server/api_block.go +++ b/server/api_block.go @@ -17,6 +17,7 @@ package server import ( + "context" "encoding/json" "net/http" "strings" @@ -28,18 +29,21 @@ import ( // A BlockAPIController binds http requests to an api service and writes the service results to the // http response type BlockAPIController struct { - service BlockAPIServicer - asserter *asserter.Asserter + service BlockAPIServicer + asserter *asserter.Asserter + contextFromRequest func(*http.Request) context.Context } // NewBlockAPIController creates a default api controller func NewBlockAPIController( s BlockAPIServicer, asserter *asserter.Asserter, + contextFromRequest func(*http.Request) context.Context, ) Router { return &BlockAPIController{ - service: s, - asserter: asserter, + service: s, + asserter: asserter, + contextFromRequest: contextFromRequest, } } @@ -61,6 +65,16 @@ func (c *BlockAPIController) Routes() Routes { } } +func (c *BlockAPIController) ContextFromRequest(r *http.Request) context.Context { + ctx := r.Context() + + if c.contextFromRequest != nil { + ctx = c.contextFromRequest(r) + } + + return ctx +} + // Block - Get a Block func (c *BlockAPIController) Block(w http.ResponseWriter, r *http.Request) { blockRequest := &types.BlockRequest{} @@ -81,7 +95,7 @@ func (c *BlockAPIController) Block(w http.ResponseWriter, r *http.Request) { return } - result, serviceErr := c.service.Block(r.Context(), blockRequest) + result, serviceErr := c.service.Block(c.ContextFromRequest(r), blockRequest) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) @@ -111,7 +125,10 @@ func (c *BlockAPIController) BlockTransaction(w http.ResponseWriter, r *http.Req return } - result, serviceErr := c.service.BlockTransaction(r.Context(), blockTransactionRequest) + result, serviceErr := c.service.BlockTransaction( + c.ContextFromRequest(r), + blockTransactionRequest, + ) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) diff --git a/server/api_call.go b/server/api_call.go index 20abff660..7e9656360 100644 --- a/server/api_call.go +++ b/server/api_call.go @@ -17,6 +17,7 @@ package server import ( + "context" "encoding/json" "net/http" "strings" @@ -28,18 +29,21 @@ import ( // A CallAPIController binds http requests to an api service and writes the service results to the // http response type CallAPIController struct { - service CallAPIServicer - asserter *asserter.Asserter + service CallAPIServicer + asserter *asserter.Asserter + contextFromRequest func(*http.Request) context.Context } // NewCallAPIController creates a default api controller func NewCallAPIController( s CallAPIServicer, asserter *asserter.Asserter, + contextFromRequest func(*http.Request) context.Context, ) Router { return &CallAPIController{ - service: s, - asserter: asserter, + service: s, + asserter: asserter, + contextFromRequest: contextFromRequest, } } @@ -55,6 +59,16 @@ func (c *CallAPIController) Routes() Routes { } } +func (c *CallAPIController) ContextFromRequest(r *http.Request) context.Context { + ctx := r.Context() + + if c.contextFromRequest != nil { + ctx = c.contextFromRequest(r) + } + + return ctx +} + // Call - Make a Network-Specific Procedure Call func (c *CallAPIController) Call(w http.ResponseWriter, r *http.Request) { callRequest := &types.CallRequest{} @@ -75,7 +89,7 @@ func (c *CallAPIController) Call(w http.ResponseWriter, r *http.Request) { return } - result, serviceErr := c.service.Call(r.Context(), callRequest) + result, serviceErr := c.service.Call(c.ContextFromRequest(r), callRequest) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) diff --git a/server/api_construction.go b/server/api_construction.go index d47d45a0a..f1d91c2a4 100644 --- a/server/api_construction.go +++ b/server/api_construction.go @@ -17,6 +17,7 @@ package server import ( + "context" "encoding/json" "net/http" "strings" @@ -28,18 +29,21 @@ import ( // A ConstructionAPIController binds http requests to an api service and writes the service results // to the http response type ConstructionAPIController struct { - service ConstructionAPIServicer - asserter *asserter.Asserter + service ConstructionAPIServicer + asserter *asserter.Asserter + contextFromRequest func(*http.Request) context.Context } // NewConstructionAPIController creates a default api controller func NewConstructionAPIController( s ConstructionAPIServicer, asserter *asserter.Asserter, + contextFromRequest func(*http.Request) context.Context, ) Router { return &ConstructionAPIController{ - service: s, - asserter: asserter, + service: s, + asserter: asserter, + contextFromRequest: contextFromRequest, } } @@ -97,6 +101,16 @@ func (c *ConstructionAPIController) Routes() Routes { } } +func (c *ConstructionAPIController) ContextFromRequest(r *http.Request) context.Context { + ctx := r.Context() + + if c.contextFromRequest != nil { + ctx = c.contextFromRequest(r) + } + + return ctx +} + // ConstructionCombine - Create Network Transaction from Signatures func (c *ConstructionAPIController) ConstructionCombine(w http.ResponseWriter, r *http.Request) { constructionCombineRequest := &types.ConstructionCombineRequest{} @@ -117,7 +131,10 @@ func (c *ConstructionAPIController) ConstructionCombine(w http.ResponseWriter, r return } - result, serviceErr := c.service.ConstructionCombine(r.Context(), constructionCombineRequest) + result, serviceErr := c.service.ConstructionCombine( + c.ContextFromRequest(r), + constructionCombineRequest, + ) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) @@ -147,7 +164,10 @@ func (c *ConstructionAPIController) ConstructionDerive(w http.ResponseWriter, r return } - result, serviceErr := c.service.ConstructionDerive(r.Context(), constructionDeriveRequest) + result, serviceErr := c.service.ConstructionDerive( + c.ContextFromRequest(r), + constructionDeriveRequest, + ) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) @@ -177,7 +197,10 @@ func (c *ConstructionAPIController) ConstructionHash(w http.ResponseWriter, r *h return } - result, serviceErr := c.service.ConstructionHash(r.Context(), constructionHashRequest) + result, serviceErr := c.service.ConstructionHash( + c.ContextFromRequest(r), + constructionHashRequest, + ) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) @@ -207,7 +230,10 @@ func (c *ConstructionAPIController) ConstructionMetadata(w http.ResponseWriter, return } - result, serviceErr := c.service.ConstructionMetadata(r.Context(), constructionMetadataRequest) + result, serviceErr := c.service.ConstructionMetadata( + c.ContextFromRequest(r), + constructionMetadataRequest, + ) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) @@ -237,7 +263,10 @@ func (c *ConstructionAPIController) ConstructionParse(w http.ResponseWriter, r * return } - result, serviceErr := c.service.ConstructionParse(r.Context(), constructionParseRequest) + result, serviceErr := c.service.ConstructionParse( + c.ContextFromRequest(r), + constructionParseRequest, + ) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) @@ -267,7 +296,10 @@ func (c *ConstructionAPIController) ConstructionPayloads(w http.ResponseWriter, return } - result, serviceErr := c.service.ConstructionPayloads(r.Context(), constructionPayloadsRequest) + result, serviceErr := c.service.ConstructionPayloads( + c.ContextFromRequest(r), + constructionPayloadsRequest, + ) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) @@ -298,7 +330,7 @@ func (c *ConstructionAPIController) ConstructionPreprocess(w http.ResponseWriter } result, serviceErr := c.service.ConstructionPreprocess( - r.Context(), + c.ContextFromRequest(r), constructionPreprocessRequest, ) if serviceErr != nil { @@ -330,7 +362,10 @@ func (c *ConstructionAPIController) ConstructionSubmit(w http.ResponseWriter, r return } - result, serviceErr := c.service.ConstructionSubmit(r.Context(), constructionSubmitRequest) + result, serviceErr := c.service.ConstructionSubmit( + c.ContextFromRequest(r), + constructionSubmitRequest, + ) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) diff --git a/server/api_events.go b/server/api_events.go index 73a1d5d5a..68dbeb1b0 100644 --- a/server/api_events.go +++ b/server/api_events.go @@ -17,6 +17,7 @@ package server import ( + "context" "encoding/json" "net/http" "strings" @@ -28,18 +29,21 @@ import ( // A EventsAPIController binds http requests to an api service and writes the service results to the // http response type EventsAPIController struct { - service EventsAPIServicer - asserter *asserter.Asserter + service EventsAPIServicer + asserter *asserter.Asserter + contextFromRequest func(*http.Request) context.Context } // NewEventsAPIController creates a default api controller func NewEventsAPIController( s EventsAPIServicer, asserter *asserter.Asserter, + contextFromRequest func(*http.Request) context.Context, ) Router { return &EventsAPIController{ - service: s, - asserter: asserter, + service: s, + asserter: asserter, + contextFromRequest: contextFromRequest, } } @@ -55,6 +59,16 @@ func (c *EventsAPIController) Routes() Routes { } } +func (c *EventsAPIController) ContextFromRequest(r *http.Request) context.Context { + ctx := r.Context() + + if c.contextFromRequest != nil { + ctx = c.contextFromRequest(r) + } + + return ctx +} + // EventsBlocks - [INDEXER] Get a range of BlockEvents func (c *EventsAPIController) EventsBlocks(w http.ResponseWriter, r *http.Request) { eventsBlocksRequest := &types.EventsBlocksRequest{} @@ -75,7 +89,7 @@ func (c *EventsAPIController) EventsBlocks(w http.ResponseWriter, r *http.Reques return } - result, serviceErr := c.service.EventsBlocks(r.Context(), eventsBlocksRequest) + result, serviceErr := c.service.EventsBlocks(c.ContextFromRequest(r), eventsBlocksRequest) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) diff --git a/server/api_mempool.go b/server/api_mempool.go index 774fb1766..789f2b9ff 100644 --- a/server/api_mempool.go +++ b/server/api_mempool.go @@ -17,6 +17,7 @@ package server import ( + "context" "encoding/json" "net/http" "strings" @@ -28,18 +29,21 @@ import ( // A MempoolAPIController binds http requests to an api service and writes the service results to // the http response type MempoolAPIController struct { - service MempoolAPIServicer - asserter *asserter.Asserter + service MempoolAPIServicer + asserter *asserter.Asserter + contextFromRequest func(*http.Request) context.Context } // NewMempoolAPIController creates a default api controller func NewMempoolAPIController( s MempoolAPIServicer, asserter *asserter.Asserter, + contextFromRequest func(*http.Request) context.Context, ) Router { return &MempoolAPIController{ - service: s, - asserter: asserter, + service: s, + asserter: asserter, + contextFromRequest: contextFromRequest, } } @@ -61,6 +65,16 @@ func (c *MempoolAPIController) Routes() Routes { } } +func (c *MempoolAPIController) ContextFromRequest(r *http.Request) context.Context { + ctx := r.Context() + + if c.contextFromRequest != nil { + ctx = c.contextFromRequest(r) + } + + return ctx +} + // Mempool - Get All Mempool Transactions func (c *MempoolAPIController) Mempool(w http.ResponseWriter, r *http.Request) { networkRequest := &types.NetworkRequest{} @@ -81,7 +95,7 @@ func (c *MempoolAPIController) Mempool(w http.ResponseWriter, r *http.Request) { return } - result, serviceErr := c.service.Mempool(r.Context(), networkRequest) + result, serviceErr := c.service.Mempool(c.ContextFromRequest(r), networkRequest) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) @@ -111,7 +125,10 @@ func (c *MempoolAPIController) MempoolTransaction(w http.ResponseWriter, r *http return } - result, serviceErr := c.service.MempoolTransaction(r.Context(), mempoolTransactionRequest) + result, serviceErr := c.service.MempoolTransaction( + c.ContextFromRequest(r), + mempoolTransactionRequest, + ) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) diff --git a/server/api_network.go b/server/api_network.go index 75a8f0c19..323c8aa9d 100644 --- a/server/api_network.go +++ b/server/api_network.go @@ -17,6 +17,7 @@ package server import ( + "context" "encoding/json" "net/http" "strings" @@ -28,18 +29,21 @@ import ( // A NetworkAPIController binds http requests to an api service and writes the service results to // the http response type NetworkAPIController struct { - service NetworkAPIServicer - asserter *asserter.Asserter + service NetworkAPIServicer + asserter *asserter.Asserter + contextFromRequest func(*http.Request) context.Context } // NewNetworkAPIController creates a default api controller func NewNetworkAPIController( s NetworkAPIServicer, asserter *asserter.Asserter, + contextFromRequest func(*http.Request) context.Context, ) Router { return &NetworkAPIController{ - service: s, - asserter: asserter, + service: s, + asserter: asserter, + contextFromRequest: contextFromRequest, } } @@ -67,6 +71,16 @@ func (c *NetworkAPIController) Routes() Routes { } } +func (c *NetworkAPIController) ContextFromRequest(r *http.Request) context.Context { + ctx := r.Context() + + if c.contextFromRequest != nil { + ctx = c.contextFromRequest(r) + } + + return ctx +} + // NetworkList - Get List of Available Networks func (c *NetworkAPIController) NetworkList(w http.ResponseWriter, r *http.Request) { metadataRequest := &types.MetadataRequest{} @@ -87,7 +101,7 @@ func (c *NetworkAPIController) NetworkList(w http.ResponseWriter, r *http.Reques return } - result, serviceErr := c.service.NetworkList(r.Context(), metadataRequest) + result, serviceErr := c.service.NetworkList(c.ContextFromRequest(r), metadataRequest) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) @@ -117,7 +131,7 @@ func (c *NetworkAPIController) NetworkOptions(w http.ResponseWriter, r *http.Req return } - result, serviceErr := c.service.NetworkOptions(r.Context(), networkRequest) + result, serviceErr := c.service.NetworkOptions(c.ContextFromRequest(r), networkRequest) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) @@ -147,7 +161,7 @@ func (c *NetworkAPIController) NetworkStatus(w http.ResponseWriter, r *http.Requ return } - result, serviceErr := c.service.NetworkStatus(r.Context(), networkRequest) + result, serviceErr := c.service.NetworkStatus(c.ContextFromRequest(r), networkRequest) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) diff --git a/server/api_search.go b/server/api_search.go index a9e025491..70ea7e54c 100644 --- a/server/api_search.go +++ b/server/api_search.go @@ -17,6 +17,7 @@ package server import ( + "context" "encoding/json" "net/http" "strings" @@ -28,18 +29,21 @@ import ( // A SearchAPIController binds http requests to an api service and writes the service results to the // http response type SearchAPIController struct { - service SearchAPIServicer - asserter *asserter.Asserter + service SearchAPIServicer + asserter *asserter.Asserter + contextFromRequest func(*http.Request) context.Context } // NewSearchAPIController creates a default api controller func NewSearchAPIController( s SearchAPIServicer, asserter *asserter.Asserter, + contextFromRequest func(*http.Request) context.Context, ) Router { return &SearchAPIController{ - service: s, - asserter: asserter, + service: s, + asserter: asserter, + contextFromRequest: contextFromRequest, } } @@ -55,6 +59,16 @@ func (c *SearchAPIController) Routes() Routes { } } +func (c *SearchAPIController) ContextFromRequest(r *http.Request) context.Context { + ctx := r.Context() + + if c.contextFromRequest != nil { + ctx = c.contextFromRequest(r) + } + + return ctx +} + // SearchTransactions - [INDEXER] Search for Transactions func (c *SearchAPIController) SearchTransactions(w http.ResponseWriter, r *http.Request) { searchTransactionsRequest := &types.SearchTransactionsRequest{} @@ -75,7 +89,10 @@ func (c *SearchAPIController) SearchTransactions(w http.ResponseWriter, r *http. return } - result, serviceErr := c.service.SearchTransactions(r.Context(), searchTransactionsRequest) + result, serviceErr := c.service.SearchTransactions( + c.ContextFromRequest(r), + searchTransactionsRequest, + ) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) diff --git a/server/routers.go b/server/routers.go index 27e8b95c9..a762117d9 100644 --- a/server/routers.go +++ b/server/routers.go @@ -17,6 +17,7 @@ package server import ( + "context" "encoding/json" "net/http" @@ -37,6 +38,7 @@ type Routes []Route // Router defines the required methods for retrieving api routes type Router interface { Routes() Routes + ContextFromRequest(*http.Request) context.Context } // CorsMiddleware handles CORS and ensures OPTIONS requests are diff --git a/templates/server/controller-api.mustache b/templates/server/controller-api.mustache index c11d6a071..5d0f422a1 100644 --- a/templates/server/controller-api.mustache +++ b/templates/server/controller-api.mustache @@ -2,6 +2,7 @@ package {{packageName}} import ( + "context" "encoding/json" "net/http" "strings" @@ -12,18 +13,21 @@ import ( // A {{classname}}Controller binds http requests to an api service and writes the service results to the http response type {{classname}}Controller struct { - service {{classname}}Servicer - asserter *asserter.Asserter + service {{classname}}Servicer + asserter *asserter.Asserter + contextFromRequest func(*http.Request) context.Context } // New{{classname}}Controller creates a default api controller func New{{classname}}Controller( s {{classname}}Servicer, asserter *asserter.Asserter, + contextFromRequest func(*http.Request) context.Context, ) Router { return &{{classname}}Controller{ service: s, asserter: asserter, + contextFromRequest: contextFromRequest, } } @@ -37,6 +41,16 @@ func (c *{{classname}}Controller) Routes() Routes { c.{{operationId}}, },{{/operation}}{{/operations}} } +} + +func (c *{{classname}}Controller) ContextFromRequest(r *http.Request) context.Context { + ctx := r.Context() + + if c.contextFromRequest != nil { + ctx = c.contextFromRequest(r) + } + + return ctx }{{#operations}}{{#operation}} // {{nickname}} - {{{summary}}} @@ -61,12 +75,12 @@ func (c *{{classname}}Controller) {{nickname}}(w http.ResponseWriter, r *http.Re } {{/isBodyParam}}{{/allParams}} - result, serviceErr := c.service.{{nickname}}(r.Context(), {{#allParams}}{{paramName}}{{#hasMore}}, {{/hasMore}}{{/allParams}}) + result, serviceErr := c.service.{{nickname}}(c.ContextFromRequest(r), {{#allParams}}{{paramName}}{{#hasMore}}, {{/hasMore}}{{/allParams}}) if serviceErr != nil { EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w) return } - + EncodeJSONResponse(result, http.StatusOK, w) }{{/operation}}{{/operations}} diff --git a/templates/server/routers.mustache b/templates/server/routers.mustache index 8394a55e1..c0046c088 100644 --- a/templates/server/routers.mustache +++ b/templates/server/routers.mustache @@ -2,6 +2,7 @@ package {{packageName}} import ( + "context" "encoding/json" "net/http" @@ -22,6 +23,7 @@ type Routes []Route // Router defines the required methods for retrieving api routes type Router interface { Routes() Routes + ContextFromRequest(*http.Request) context.Context } // CorsMiddleware handles CORS and ensures OPTIONS requests are