diff --git a/handlers.go b/handlers.go index ea19c4f..b2fbe16 100644 --- a/handlers.go +++ b/handlers.go @@ -112,6 +112,12 @@ func handlePostGitCredentials(tokenVendor vendor.PipelineTokenVendor) http.Handl }) } +func maxRequestSize(limit int64) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.MaxBytesHandler(next, limit) + } +} + func requestError(w http.ResponseWriter, statusCode int) { http.Error(w, http.StatusText(statusCode), statusCode) } diff --git a/handlers_test.go b/handlers_test.go index ec9e36d..03adae6 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "errors" + "io" "net/http" "net/http/httptest" "testing" @@ -199,3 +200,41 @@ func tvFails(err error) vendor.PipelineTokenVendor { return nil, err }) } + +func TestMaxRequestSizeMiddleware(t *testing.T) { + + mw := maxRequestSize(10) + + var readError error + var readBytes int64 + + innerHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + readBytes, readError = io.CopyN(io.Discard, r.Body, 5*1024*1024) + + status := http.StatusOK + if readError != nil { + status = http.StatusBadRequest + } + + w.WriteHeader(status) + }) + + handler := mw(innerHandler) + + body := bytes.NewBufferString("0123456789n123456789") + req, err := http.NewRequest("POST", "/git-credentials", body) + require.NoError(t, err) + + rr := httptest.NewRecorder() + + // act + handler.ServeHTTP(rr, req) + + // assert + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.ErrorContains(t, readError, "http: request body too large") + assert.Equal(t, int64(10), readBytes) + + respBody := rr.Body.String() + assert.Equal(t, "", respBody) +} diff --git a/main.go b/main.go index 50c7fba..90f4f44 100644 --- a/main.go +++ b/main.go @@ -32,7 +32,11 @@ func configureServerRoutes(cfg config.Config) (http.Handler, error) { return nil, fmt.Errorf("authorizer configuration failed: %w", err) } - authorized := alice.New(authorizer) + // The request body size is fairly limited to prevent accidental or + // deliberate abuse. Given the current API shape, this is not configurable. + requestLimitBytes := int64(20 << 10) // 20 KB + + authorized := alice.New(maxRequestSize(requestLimitBytes), authorizer) // setup token handler and dependencies bk := buildkite.New(cfg.Buildkite) @@ -79,8 +83,9 @@ func launchServer() error { } server := &http.Server{ - Addr: fmt.Sprintf(":%d", cfg.Server.Port), - Handler: handler, + Addr: fmt.Sprintf(":%d", cfg.Server.Port), + Handler: handler, + MaxHeaderBytes: 20 << 10, // 20 KB } shutdownTelemetry, err := observe.Configure(ctx, cfg.Observe)