diff --git a/api/openapi.yml b/api/openapi.yml index 107baa7..7e6a1bb 100644 --- a/api/openapi.yml +++ b/api/openapi.yml @@ -29,6 +29,8 @@ paths: responses: 302: description: PURL resolved, use the `Location` header to navigate to the target. + 400: + $ref: "#/components/responses/BadRequest" 404: description: PURL not resolved, it probably does not exist. 410: @@ -56,6 +58,8 @@ paths: responses: 204: description: PURL created or updated. + 400: + $ref: "#/components/responses/BadRequest" 404: description: PURL not resolved, it probably does not exist. @@ -76,7 +80,18 @@ components: schema: $ref: '#/components/schemas/Named' + responses: + BadRequest: + description: Bad request, see response body for details. + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + schemas: + Error: + type: string + Named: type: string description: A name safe for URL usage diff --git a/api/server.go b/api/server.go index a061d22..804b6bc 100644 --- a/api/server.go +++ b/api/server.go @@ -29,7 +29,7 @@ func (s *Server) Resolve(ctx *gin.Context) { } } -func (s *Server) Save(ctx *gin.Context) { +func (s *Server) SavePURL(ctx *gin.Context) { domain := ctx.Param("domain") name := ctx.Param("name") type body struct { diff --git a/api/server_routes.go b/api/server_routes.go index 596c1da..b2957e9 100644 --- a/api/server_routes.go +++ b/api/server_routes.go @@ -1,11 +1,16 @@ package api -import "github.com/gin-gonic/gin" +import ( + "github.com/gin-gonic/gin" +) func SetupRouting(r gin.IRouter, s *Server) { + r.Use(validPathVar("domain", regexNamed)) + r.Use(validPathVar("name", regexNamed)) + // Resolve endpoints r.GET("/r/:domain/:name", s.Resolve) // Admin endpoints - r.PUT("/a/domains/:domain/purls/:name", s.Save) + r.PUT("/a/domains/:domain/purls/:name", s.SavePURL) } diff --git a/api/validate.go b/api/validate.go new file mode 100644 index 0000000..12a4fa9 --- /dev/null +++ b/api/validate.go @@ -0,0 +1,30 @@ +package api + +import ( + "fmt" + "regexp" + + "github.com/gin-gonic/gin" +) + +var regexNamed *regexp.Regexp + +func init() { + // regexNamed is used to validate everything that has a name. See OpenAPI + // for more information. + regexNamed = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) +} + +// validPathVar is a middleware that validates a path variable against a +// regular expression. If the path variable does not match the regular +// expression, the middleware aborts the request with status 400. +func validPathVar(key string, regex *regexp.Regexp) gin.HandlerFunc { + return func(context *gin.Context) { + if !regex.MatchString(context.Param(key)) { + err := fmt.Sprintf("path variable %q does not match regex %s", key, regex.String()) + context.AbortWithStatusJSON(400, err) + return + } + context.Next() + } +} diff --git a/tests/driver/http.go b/tests/driver/http.go index c76c114..1168383 100644 --- a/tests/driver/http.go +++ b/tests/driver/http.go @@ -79,6 +79,8 @@ func (driver *HTTPDriver) CreatePurl(purl *dsl.PURL) error { switch res.StatusCode { case http.StatusNoContent: return nil + case http.StatusBadRequest: + return fmt.Errorf("%w: status %d returned", dsl.ErrBadRequest, res.StatusCode) default: return fmt.Errorf("unexpected status code: %d", res.StatusCode) } diff --git a/tests/dsl/errs.go b/tests/dsl/errs.go index 2af4a7d..0c13766 100644 --- a/tests/dsl/errs.go +++ b/tests/dsl/errs.go @@ -3,5 +3,6 @@ package dsl import "errors" var ( - ErrNotFound = errors.New("not found") + ErrNotFound = errors.New("not found") + ErrBadRequest = errors.New("bad request") ) diff --git a/tests/http_test.go b/tests/http_test.go index a4e327f..7443236 100644 --- a/tests/http_test.go +++ b/tests/http_test.go @@ -19,5 +19,9 @@ func TestWithHTTPDriver(t *testing.T) { api.SetupRouting(handler, server) testServer := httptest.NewServer(handler) - specs.TestResolver(t, driver.NewHTTPDriver(testServer.URL, http.DefaultTransport)) + + dr := driver.NewHTTPDriver(testServer.URL, http.DefaultTransport) + + specs.TestResolver(t, dr) + specs.TestAdministration(t, dr) } diff --git a/tests/mock_test.go b/tests/mock_test.go index ce7afab..cb4fe7e 100644 --- a/tests/mock_test.go +++ b/tests/mock_test.go @@ -1,9 +1,10 @@ package tests import ( + "testing" + "github.com/fabiante/persurl/tests/driver" "github.com/fabiante/persurl/tests/specs" - "testing" ) func TestWithMockDriver(t *testing.T) { diff --git a/tests/specs/admin.go b/tests/specs/admin.go new file mode 100644 index 0000000..3fe6452 --- /dev/null +++ b/tests/specs/admin.go @@ -0,0 +1,41 @@ +package specs + +import ( + "fmt" + "testing" + + "github.com/fabiante/persurl/tests/dsl" + "github.com/stretchr/testify/require" +) + +func TestAdministration(t *testing.T, admin dsl.AdminAPI) { + t.Run("administration", func(t *testing.T) { + t.Run("can't create invalid PURL", func(t *testing.T) { + invalid := []*dsl.PURL{ + // empty + dsl.NewPURL("", "valid", mustParseURL("example.com")), + dsl.NewPURL("valid", "", mustParseURL("example.com")), + // whitespace + dsl.NewPURL("a b", "valid", mustParseURL("example.com")), + dsl.NewPURL("valid", "a b", mustParseURL("example.com")), + // url encoded whitespace + dsl.NewPURL("a%20b", "valid", mustParseURL("example.com")), + dsl.NewPURL("valid", "a%20b", mustParseURL("example.com")), + // random characters + dsl.NewPURL("^", "valid", mustParseURL("example.com")), + dsl.NewPURL("~", "valid", mustParseURL("example.com")), + dsl.NewPURL(":", "valid", mustParseURL("example.com")), + dsl.NewPURL(",", "valid", mustParseURL("example.com")), + dsl.NewPURL("`", "valid", mustParseURL("example.com")), + } + + for i, purl := range invalid { + t.Run(fmt.Sprintf("invalid[%d]", i), func(t *testing.T) { + err := admin.CreatePurl(purl) + require.Error(t, err) + require.ErrorIs(t, err, dsl.ErrBadRequest) + }) + } + }) + }) +} diff --git a/tests/specs/resolve.go b/tests/specs/resolve.go index edae78b..c49f4d3 100644 --- a/tests/specs/resolve.go +++ b/tests/specs/resolve.go @@ -17,21 +17,23 @@ type ResolveAPI interface { } func TestResolver(t *testing.T, resolver ResolveAPI) { - t.Run("does not resolve non-existant PURL", func(t *testing.T) { - purl, err := resolver.ResolvePURL("something-very-stupid", "should-not-exist") - require.Error(t, err) - require.ErrorIs(t, err, dsl.ErrNotFound) - require.Nil(t, purl) - }) - - t.Run("resolves existing PURL", func(t *testing.T) { - domain := "my-domain" - name := "my-name" - - dsl.GivenExistingPURL(t, resolver, dsl.NewPURL(domain, name, mustParseURL("https://google.com"))) - - purl, err := resolver.ResolvePURL(domain, name) - require.NoError(t, err) - require.NotNil(t, purl) + t.Run("resolver", func(t *testing.T) { + t.Run("does not resolve non-existant PURL", func(t *testing.T) { + purl, err := resolver.ResolvePURL("something-very-stupid", "should-not-exist") + require.Error(t, err) + require.ErrorIs(t, err, dsl.ErrNotFound) + require.Nil(t, purl) + }) + + t.Run("resolves existing PURL", func(t *testing.T) { + domain := "my-domain" + name := "my-name" + + dsl.GivenExistingPURL(t, resolver, dsl.NewPURL(domain, name, mustParseURL("https://google.com"))) + + purl, err := resolver.ResolvePURL(domain, name) + require.NoError(t, err) + require.NotNil(t, purl) + }) }) }