diff --git a/pkg/backends/backends_test.go b/pkg/backends/backends_test.go index 387e67eb4..2b69ed2a9 100644 --- a/pkg/backends/backends_test.go +++ b/pkg/backends/backends_test.go @@ -21,12 +21,12 @@ import ( ho "github.com/trickstercache/trickster/v2/pkg/backends/healthcheck/options" bo "github.com/trickstercache/trickster/v2/pkg/backends/options" - "github.com/trickstercache/trickster/v2/pkg/router" + "github.com/trickstercache/trickster/v2/pkg/router/lm" ) func TestBackends(t *testing.T) { - cl, _ := New("test1", bo.New(), nil, router.NewRouter(), nil) + cl, _ := New("test1", bo.New(), nil, lm.NewRouter(), nil) o := Backends{"test1": cl} c := o.Get("test1") @@ -77,11 +77,11 @@ func TestStartHealthChecks(t *testing.T) { // 1: rule / Virtual provider o1 := bo.New() o1.Provider = "rule" - c1, _ := New("test1", o1, nil, router.NewRouter(), nil) + c1, _ := New("test1", o1, nil, lm.NewRouter(), nil) // 2: non-virtual provider with no health check options o2 := bo.New() - c2, _ := New("test2", o2, nil, router.NewRouter(), nil) + c2, _ := New("test2", o2, nil, lm.NewRouter(), nil) b := Backends{"test1": c1} _, err := b.StartHealthChecks(nil) diff --git a/pkg/backends/irondb/url.go b/pkg/backends/irondb/url.go index 04f403237..f98de9c43 100644 --- a/pkg/backends/irondb/url.go +++ b/pkg/backends/irondb/url.go @@ -45,7 +45,7 @@ func (c *Client) FastForwardRequest(r *http.Request) (*http.Request, error) { rsc := request.GetResources(r) if rsc == nil || rsc.PathConfig == nil { - return nil, tkerr.ErrMissingPathconfig + return nil, tkerr.ErrMissingPathConfig } switch rsc.PathConfig.HandlerName { @@ -68,7 +68,7 @@ func (c *Client) ParseTimeRangeQuery( rsc := request.GetResources(r) if rsc == nil || rsc.PathConfig == nil { - return nil, nil, false, tkerr.ErrMissingPathconfig + return nil, nil, false, tkerr.ErrMissingPathConfig } var trq *timeseries.TimeRangeQuery diff --git a/pkg/backends/reverseproxy/routes.go b/pkg/backends/reverseproxy/routes.go index f40eafb05..d611c4332 100644 --- a/pkg/backends/reverseproxy/routes.go +++ b/pkg/backends/reverseproxy/routes.go @@ -48,7 +48,7 @@ func (c *Client) DefaultPathConfigs(o *bo.Options) map[string]*po.Options { "/-" + strings.Join(am, "-"): { Path: "/", HandlerName: "proxy", - Methods: methods.AllHTTPMethods(), + Methods: am, MatchType: matching.PathMatchTypePrefix, MatchTypeName: "prefix", }, diff --git a/pkg/backends/timeseries_backend_test.go b/pkg/backends/timeseries_backend_test.go index 607abf564..81880e6a0 100644 --- a/pkg/backends/timeseries_backend_test.go +++ b/pkg/backends/timeseries_backend_test.go @@ -20,11 +20,11 @@ import ( "testing" bo "github.com/trickstercache/trickster/v2/pkg/backends/options" - "github.com/trickstercache/trickster/v2/pkg/router" + "github.com/trickstercache/trickster/v2/pkg/router/lm" ) func TestNewTimeseriesBackend(t *testing.T) { - tb, _ := NewTimeseriesBackend("test1", bo.New(), nil, router.NewRouter(), nil, nil) + tb, _ := NewTimeseriesBackend("test1", bo.New(), nil, lm.NewRouter(), nil, nil) if tb.Name() != "test1" { t.Error("expected test1 got", tb.Name()) } diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index f193dcb5a..7db8e111e 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -24,7 +24,7 @@ const ( // DefaultHealthHandlerPath defines the default path for the Health Handler DefaultHealthHandlerPath = "/trickster/health" // DefaultPurgeKeyHandlerPath defines the default path for the Cache Purge (by Key) Handler - DefaultPurgeKeyHandlerPath = "/trickster/purge/key/{backend}/{key}" + DefaultPurgeKeyHandlerPath = "/trickster/purge/key/" // DefaultPurgePathHandlerPath defines the default path for the Cache Purge (by Path) Handler // Requires ?backend={backend}&path={path} DefaultPurgePathHandlerPath = "/trickster/purge/path" diff --git a/pkg/config/errors.go b/pkg/config/errors.go deleted file mode 100644 index a1e00a412..000000000 --- a/pkg/config/errors.go +++ /dev/null @@ -1,5 +0,0 @@ -package config - -import "errors" - -var ErrNoValidBackends = errors.New("no valid backends configured") diff --git a/pkg/config/loader.go b/pkg/config/loader.go index ef48206e5..fd8e16621 100644 --- a/pkg/config/loader.go +++ b/pkg/config/loader.go @@ -22,6 +22,7 @@ import ( bo "github.com/trickstercache/trickster/v2/pkg/backends/options" "github.com/trickstercache/trickster/v2/pkg/cache/negative" + "github.com/trickstercache/trickster/v2/pkg/errors" ) // Load returns the Application Configuration, starting with a default config, @@ -71,7 +72,7 @@ func Load(applicationName string, applicationVersion string, arguments []string) } if len(c.Backends) == 0 { - return nil, flags, ErrNoValidBackends + return nil, flags, errors.ErrNoValidBackends } ncl, err := negative.ConfigLookup(c.NegativeCacheConfigs).Validate() diff --git a/pkg/config/loader_test.go b/pkg/config/loader_test.go index 8a128cb92..69d7edf8d 100644 --- a/pkg/config/loader_test.go +++ b/pkg/config/loader_test.go @@ -24,6 +24,7 @@ import ( "time" "github.com/trickstercache/trickster/v2/pkg/cache/evictionmethods" + "github.com/trickstercache/trickster/v2/pkg/errors" tlstest "github.com/trickstercache/trickster/v2/pkg/testutil/tls" ) @@ -620,7 +621,7 @@ func TestLoadConfigurationWarning2(t *testing.T) { func TestLoadEmptyArgs(t *testing.T) { a := []string{} _, _, err := Load("trickster-test", "0", a) - if err != ErrNoValidBackends { - t.Error("expected error:", ErrNoValidBackends) + if err != errors.ErrNoValidBackends { + t.Error("expected error:", errors.ErrNoValidBackends) } } diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index fafc47b8a..433e378f9 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -25,4 +25,13 @@ var ErrNilWriter = errors.New("nil writer") var ErrInvalidOptions = errors.New("invalid options") // ErrMissingPathconfig is an error for when a configuration is missing a path value -var ErrMissingPathconfig = errors.New("missing path config") +var ErrMissingPathConfig = errors.New("missing path config") + +// ErrInvalidPath is an error for when a configuration's path is invalid +var ErrInvalidPath = errors.New("invalid path value in config") + +// ErrInvalidMethod is an error for when a configuration's method is invalid +var ErrInvalidMethod = errors.New("invalid method value in config") + +// ErrNoValidBackends is an error for when not valid backends have been configured +var ErrNoValidBackends = errors.New("no valid backends configured") diff --git a/pkg/httpserver/httpserver.go b/pkg/httpserver/httpserver.go index de3d995b2..700b30f70 100644 --- a/pkg/httpserver/httpserver.go +++ b/pkg/httpserver/httpserver.go @@ -41,7 +41,7 @@ import ( "github.com/trickstercache/trickster/v2/pkg/observability/metrics" tr "github.com/trickstercache/trickster/v2/pkg/observability/tracing/registration" "github.com/trickstercache/trickster/v2/pkg/proxy/handlers" - "github.com/trickstercache/trickster/v2/pkg/router" + "github.com/trickstercache/trickster/v2/pkg/router/lm" "github.com/trickstercache/trickster/v2/pkg/routing" ) @@ -130,10 +130,14 @@ func applyConfig(conf, oldConf *config.Config, wg *sync.WaitGroup, logger *tl.Lo } // every config (re)load is a new router - r := router.NewRouter() - mr := http.NewServeMux() + r := lm.NewRouter() + mr := lm.NewRouter() + mr.SetMatchingScheme(0) // metrics router is exact-match only + + r.RegisterRoute(conf.Main.PingHandlerPath, nil, + []string{http.MethodGet, http.MethodHead}, false, + http.HandlerFunc((handlers.PingHandleFunc(conf)))) - r.HandleFunc(conf.Main.PingHandlerPath, handlers.PingHandleFunc(conf)).Methods(http.MethodGet) var caches = applyCachingConfig(conf, oldConf, logger, oldCaches) rh := handlers.ReloadHandleFunc(Serve, conf, wg, logger, caches, args) @@ -144,7 +148,12 @@ func applyConfig(conf, oldConf *config.Config, wg *sync.WaitGroup, logger *tl.Lo return err } - r.HandleFunc(conf.Main.PurgeKeyHandlerPath, handlers.PurgeKeyHandleFunc(conf, o)).Methods(http.MethodDelete) + if !strings.HasSuffix(conf.Main.PurgeKeyHandlerPath, "/") { + conf.Main.PurgeKeyHandlerPath += "/" + } + r.RegisterRoute(conf.Main.PurgeKeyHandlerPath, nil, + []string{http.MethodDelete}, true, + http.HandlerFunc(handlers.PurgeKeyHandleFunc(conf, o))) if hc != nil { hc.Shutdown() @@ -320,8 +329,9 @@ func validateConfig(conf *config.Config) error { caches[k] = nil } - r := router.NewRouter() - mr := http.NewServeMux() + r := lm.NewRouter() + mr := lm.NewRouter() + mr.SetMatchingScheme(0) // metrics router is exact-match only logger := tl.ConsoleLogger(conf.Logging.LogLevel) tracers, err := tr.RegisterAll(conf, logger, true) diff --git a/pkg/httpserver/listeners.go b/pkg/httpserver/listeners.go index 89b2f3e7b..23390812d 100644 --- a/pkg/httpserver/listeners.go +++ b/pkg/httpserver/listeners.go @@ -30,13 +30,15 @@ import ( "github.com/trickstercache/trickster/v2/pkg/proxy/handlers" "github.com/trickstercache/trickster/v2/pkg/proxy/listener" ttls "github.com/trickstercache/trickster/v2/pkg/proxy/tls" + "github.com/trickstercache/trickster/v2/pkg/router" + "github.com/trickstercache/trickster/v2/pkg/router/lm" "github.com/trickstercache/trickster/v2/pkg/routing" ) var lg = listener.NewListenerGroup() func applyListenerConfigs(conf, oldConf *config.Config, - router, reloadHandler http.Handler, metricsRouter *http.ServeMux, + router, reloadHandler http.Handler, metricsRouter router.Router, log *tl.Logger, tracers tracing.Tracers, o backends.Backends, wg *sync.WaitGroup, errorFunc func()) { @@ -137,8 +139,10 @@ func applyListenerConfigs(conf, oldConf *config.Config, (!hasOldMC || (conf.Metrics.ListenAddress != oldConf.Metrics.ListenAddress || conf.Metrics.ListenPort != oldConf.Metrics.ListenPort)) { lg.DrainAndClose("metricsListener", 0) - metricsRouter.Handle("/metrics", metrics.Handler()) - metricsRouter.HandleFunc(conf.Main.ConfigHandlerPath, handlers.ConfigHandleFunc(conf)) + metricsRouter.RegisterRoute("/metrics", nil, nil, + false, metrics.Handler()) + metricsRouter.RegisterRoute(conf.Main.ConfigHandlerPath, nil, nil, + false, http.HandlerFunc(handlers.ConfigHandleFunc(conf))) if conf.Main.PprofServer == "both" || conf.Main.PprofServer == "metrics" { routing.RegisterPprofRoutes("metrics", metricsRouter, log) } @@ -147,12 +151,15 @@ func applyListenerConfigs(conf, oldConf *config.Config, conf.Metrics.ListenAddress, conf.Metrics.ListenPort, conf.Frontend.ConnectionsLimit, nil, metricsRouter, wg, nil, errorFunc, 0, log) } else { - metricsRouter.Handle("/metrics", metrics.Handler()) - metricsRouter.HandleFunc(conf.Main.ConfigHandlerPath, handlers.ConfigHandleFunc(conf)) + metricsRouter.RegisterRoute("/metrics", nil, nil, + false, metrics.Handler()) + metricsRouter.RegisterRoute(conf.Main.ConfigHandlerPath, nil, nil, + false, http.HandlerFunc(handlers.ConfigHandleFunc(conf))) lg.UpdateRouter("metricsListener", metricsRouter) } - rr := http.NewServeMux() // serveMux router for the Reload port + rr := lm.NewRouter() // router for the Reload port + rr.SetMatchingScheme(0) // reload router is exact-match only // if the Reload HTTP port is configured, then set up the http listener instance if conf.ReloadConfig != nil && conf.ReloadConfig.ListenPort > 0 && @@ -160,9 +167,12 @@ func applyListenerConfigs(conf, oldConf *config.Config, conf.ReloadConfig.ListenPort != oldConf.ReloadConfig.ListenPort)) { wg.Add(1) lg.DrainAndClose("reloadListener", time.Millisecond*500) - rr.HandleFunc(conf.Main.ConfigHandlerPath, handlers.ConfigHandleFunc(conf)) - rr.Handle(conf.ReloadConfig.HandlerPath, reloadHandler) - rr.HandleFunc(conf.Main.PurgePathHandlerPath, handlers.PurgePathHandlerFunc(conf, &o)) + rr.RegisterRoute(conf.Main.ConfigHandlerPath, nil, nil, + false, http.HandlerFunc(handlers.ConfigHandleFunc(conf))) + rr.RegisterRoute(conf.ReloadConfig.HandlerPath, nil, nil, + false, reloadHandler) + rr.RegisterRoute(conf.Main.PurgePathHandlerPath, nil, nil, + false, http.HandlerFunc(handlers.PurgePathHandlerFunc(conf, &o))) if conf.Main.PprofServer == "both" || conf.Main.PprofServer == "reload" { routing.RegisterPprofRoutes("reload", rr, log) } @@ -170,9 +180,12 @@ func applyListenerConfigs(conf, oldConf *config.Config, conf.ReloadConfig.ListenAddress, conf.ReloadConfig.ListenPort, conf.Frontend.ConnectionsLimit, nil, rr, wg, nil, errorFunc, 0, log) } else { - rr.HandleFunc(conf.Main.ConfigHandlerPath, handlers.ConfigHandleFunc(conf)) - rr.Handle(conf.ReloadConfig.HandlerPath, reloadHandler) - rr.HandleFunc(conf.Main.PurgePathHandlerPath, handlers.PurgePathHandlerFunc(conf, &o)) + rr.RegisterRoute(conf.Main.ConfigHandlerPath, nil, nil, + false, http.HandlerFunc(handlers.ConfigHandleFunc(conf))) + rr.RegisterRoute(conf.ReloadConfig.HandlerPath, nil, nil, + false, reloadHandler) + rr.RegisterRoute(conf.Main.PurgePathHandlerPath, nil, nil, + false, http.HandlerFunc(handlers.PurgePathHandlerFunc(conf, &o))) lg.UpdateRouter("reloadListener", rr) } } diff --git a/pkg/proxy/handlers/clickhouse/clickhouse.go b/pkg/proxy/handlers/clickhouse/clickhouse.go index af438c63b..95595ae49 100644 --- a/pkg/proxy/handlers/clickhouse/clickhouse.go +++ b/pkg/proxy/handlers/clickhouse/clickhouse.go @@ -24,7 +24,7 @@ import ( bo "github.com/trickstercache/trickster/v2/pkg/backends/options" co "github.com/trickstercache/trickster/v2/pkg/cache/options" "github.com/trickstercache/trickster/v2/pkg/cache/registration" - "github.com/trickstercache/trickster/v2/pkg/router" + "github.com/trickstercache/trickster/v2/pkg/router/lm" "github.com/trickstercache/trickster/v2/pkg/routing" ) @@ -57,8 +57,8 @@ func NewAcceleratorWithOptions(baseURL string, o *bo.Options, c *co.Options) (ht o.Scheme = u.Scheme o.Host = u.Host o.PathPrefix = u.Path - r := router.NewRouter() - cl, err := clickhouse.NewClient("default", o, router.NewRouter(), cache, nil, nil) + r := lm.NewRouter() + cl, err := clickhouse.NewClient("default", o, lm.NewRouter(), cache, nil, nil) if err != nil { return nil, err } diff --git a/pkg/proxy/handlers/influxdb/influxdb.go b/pkg/proxy/handlers/influxdb/influxdb.go index c79aeace1..94e5d0150 100644 --- a/pkg/proxy/handlers/influxdb/influxdb.go +++ b/pkg/proxy/handlers/influxdb/influxdb.go @@ -24,7 +24,7 @@ import ( bo "github.com/trickstercache/trickster/v2/pkg/backends/options" co "github.com/trickstercache/trickster/v2/pkg/cache/options" "github.com/trickstercache/trickster/v2/pkg/cache/registration" - "github.com/trickstercache/trickster/v2/pkg/router" + "github.com/trickstercache/trickster/v2/pkg/router/lm" "github.com/trickstercache/trickster/v2/pkg/routing" ) @@ -57,8 +57,8 @@ func NewAcceleratorWithOptions(baseURL string, o *bo.Options, c *co.Options) (ht o.Scheme = u.Scheme o.Host = u.Host o.PathPrefix = u.Path - r := router.NewRouter() - cl, err := influxdb.NewClient("default", o, router.NewRouter(), cache, nil, nil) + r := lm.NewRouter() + cl, err := influxdb.NewClient("default", o, lm.NewRouter(), cache, nil, nil) if err != nil { return nil, err } diff --git a/pkg/proxy/handlers/prometheus/prometheus.go b/pkg/proxy/handlers/prometheus/prometheus.go index ac0c70bdc..ce30cdea4 100644 --- a/pkg/proxy/handlers/prometheus/prometheus.go +++ b/pkg/proxy/handlers/prometheus/prometheus.go @@ -24,7 +24,7 @@ import ( "github.com/trickstercache/trickster/v2/pkg/backends/prometheus" co "github.com/trickstercache/trickster/v2/pkg/cache/options" "github.com/trickstercache/trickster/v2/pkg/cache/registration" - "github.com/trickstercache/trickster/v2/pkg/router" + "github.com/trickstercache/trickster/v2/pkg/router/lm" "github.com/trickstercache/trickster/v2/pkg/routing" ) @@ -57,8 +57,8 @@ func NewAcceleratorWithOptions(baseURL string, o *bo.Options, c *co.Options) (ht o.Scheme = u.Scheme o.Host = u.Host o.PathPrefix = u.Path - r := router.NewRouter() - cl, err := prometheus.NewClient("default", o, router.NewRouter(), cache, nil, nil) + r := lm.NewRouter() + cl, err := prometheus.NewClient("default", o, lm.NewRouter(), cache, nil, nil) if err != nil { return nil, err } diff --git a/pkg/proxy/handlers/purge.go b/pkg/proxy/handlers/purge.go index 50fa05921..15491671e 100644 --- a/pkg/proxy/handlers/purge.go +++ b/pkg/proxy/handlers/purge.go @@ -18,6 +18,7 @@ package handlers import ( "net/http" + "strings" "github.com/trickstercache/trickster/v2/pkg/backends" "github.com/trickstercache/trickster/v2/pkg/checksum/md5" @@ -25,14 +26,19 @@ import ( "github.com/trickstercache/trickster/v2/pkg/observability/logging" "github.com/trickstercache/trickster/v2/pkg/proxy/headers" "github.com/trickstercache/trickster/v2/pkg/proxy/request" - "github.com/trickstercache/trickster/v2/pkg/router" ) // PurgeHandleFunc purges an object from a cache based on key. func PurgeKeyHandleFunc(conf *config.Config, from backends.Backends) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, req *http.Request) { - params := router.Vars(req) - purgeFrom, purgeKey := params["backend"], params["key"] + vals := strings.Replace(req.URL.Path, conf.Main.PurgeKeyHandlerPath, "", 1) + parts := strings.Split(vals, "/") + if len(parts) != 2 { + http.NotFound(w, req) + return + } + purgeFrom := parts[0] + purgeKey := parts[1] fromBackend := from.Get(purgeFrom) if fromBackend == nil { w.Header().Set(headers.NameContentType, headers.ValueTextPlain) diff --git a/pkg/proxy/handlers/rpc/rpc.go b/pkg/proxy/handlers/rpc/rpc.go index 2489ac498..13c6bc314 100644 --- a/pkg/proxy/handlers/rpc/rpc.go +++ b/pkg/proxy/handlers/rpc/rpc.go @@ -24,7 +24,7 @@ import ( rpc "github.com/trickstercache/trickster/v2/pkg/backends/reverseproxycache" co "github.com/trickstercache/trickster/v2/pkg/cache/options" "github.com/trickstercache/trickster/v2/pkg/cache/registration" - "github.com/trickstercache/trickster/v2/pkg/router" + "github.com/trickstercache/trickster/v2/pkg/router/lm" "github.com/trickstercache/trickster/v2/pkg/routing" ) @@ -57,7 +57,7 @@ func NewWithOptions(baseURL string, o *bo.Options, c *co.Options) (http.Handler, o.Scheme = u.Scheme o.Host = u.Host o.PathPrefix = u.Path - r := router.NewRouter() + r := lm.NewRouter() cl, err := rpc.NewClient("default", o, r, cache, nil, nil) if err != nil { return nil, err diff --git a/pkg/proxy/headers/headers.go b/pkg/proxy/headers/headers.go index 68595a184..61724efc7 100644 --- a/pkg/proxy/headers/headers.go +++ b/pkg/proxy/headers/headers.go @@ -36,6 +36,8 @@ const ( ValueApplicationFlux = "application/vnd.flux" // ValueChunked represents the HTTP Header Value of "chunked" ValueChunked = "chunked" + // ValueClose represents the HTTP Header Value of "close" + ValueClose = "close" // ValueMaxAge represents the HTTP Header Value of "max-age" ValueMaxAge = "max-age" // ValueMultipartFormData represents the HTTP Header Value of "multipart/form-data" diff --git a/pkg/proxy/listener/listener_test.go b/pkg/proxy/listener/listener_test.go index a3c6e0f93..b935ff591 100644 --- a/pkg/proxy/listener/listener_test.go +++ b/pkg/proxy/listener/listener_test.go @@ -35,7 +35,7 @@ import ( "github.com/trickstercache/trickster/v2/pkg/observability/tracing/exporters/stdout" "github.com/trickstercache/trickster/v2/pkg/proxy/errors" ph "github.com/trickstercache/trickster/v2/pkg/proxy/handlers" - "github.com/trickstercache/trickster/v2/pkg/router" + "github.com/trickstercache/trickster/v2/pkg/router/lm" testutil "github.com/trickstercache/trickster/v2/pkg/testutil" tlstest "github.com/trickstercache/trickster/v2/pkg/testutil/tls" ) @@ -224,7 +224,7 @@ func TestListenerConnectionLimitWorks(t *testing.T) { } go func() { - http.Serve(l, router.NewRouter()) + http.Serve(l, lm.NewRouter()) }() if err != nil { diff --git a/pkg/proxy/methods/methods.go b/pkg/proxy/methods/methods.go index e7792a71e..78fbf89bf 100644 --- a/pkg/proxy/methods/methods.go +++ b/pkg/proxy/methods/methods.go @@ -17,7 +17,10 @@ // Package methods provides functionality for handling HTTP methods package methods -import "net/http" +import ( + "net/http" + "strings" +) const ( get uint16 = 1 << iota @@ -108,3 +111,9 @@ func MethodMask(methods ...string) uint16 { } return i } + +// IsValidMethod returns true if the provided method is recognized in methodsMap +func IsValidMethod(method string) bool { + _, ok := methodsMap[strings.ToUpper(method)] + return ok +} diff --git a/pkg/proxy/paths/options/options.go b/pkg/proxy/paths/options/options.go index e66e0ba1e..8895a62a9 100644 --- a/pkg/proxy/paths/options/options.go +++ b/pkg/proxy/paths/options/options.go @@ -220,7 +220,7 @@ func SetDefaults( p.ReqRewriter = ri } if len(p.Methods) == 0 { - p.Methods = []string{http.MethodGet, http.MethodHead} + p.Methods = []string{http.MethodGet} } p.Custom = make([]string, 0) for _, pm := range pathMembers { @@ -248,7 +248,7 @@ func SetDefaults( p.MatchType = matching.PathMatchTypeExact p.MatchTypeName = p.MatchType.String() } - paths[p.Path+"-"+strings.Join(p.Methods, "-")] = p + paths[p.Path] = p } return nil } diff --git a/pkg/router/LICENSE b/pkg/router/LICENSE deleted file mode 100644 index 1e604d6e7..000000000 --- a/pkg/router/LICENSE +++ /dev/null @@ -1,31 +0,0 @@ -Originally based on https://github.com/gorilla/mux (archived) - ------------------------------------------------------------------------- - -Copyright (c) 2012-2018 The Gorilla Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/router/lm/lm.go b/pkg/router/lm/lm.go new file mode 100644 index 000000000..d0aa50de5 --- /dev/null +++ b/pkg/router/lm/lm.go @@ -0,0 +1,242 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// package lm represents a simple Longest Match router +package lm + +import ( + "net/http" + "slices" + "sort" + "strings" + + "github.com/trickstercache/trickster/v2/pkg/errors" + "github.com/trickstercache/trickster/v2/pkg/proxy/headers" + meth "github.com/trickstercache/trickster/v2/pkg/proxy/methods" + "github.com/trickstercache/trickster/v2/pkg/router" + "github.com/trickstercache/trickster/v2/pkg/router/route" +) + +var _ router.Router = &lmRouter{} + +type lmRouter struct { + matchScheme router.MatchingScheme + routes route.HostRouteSetLookup +} + +func NewRouter() router.Router { + return &lmRouter{ + matchScheme: router.DefaultMatchingScheme, + routes: make(route.HostRouteSetLookup), + } +} + +var emptyHost = []string{""} +var defaultMethods = []string{http.MethodGet, http.MethodHead} + +func (rt *lmRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.RequestURI == "*" { + if r.ProtoAtLeast(1, 1) { + w.Header().Set(headers.NameConnection, headers.ValueClose) + } + w.WriteHeader(http.StatusBadRequest) + return + } + rt.Handler(r).ServeHTTP(w, r) +} + +func (rt *lmRouter) RegisterRoute(path string, hosts, methods []string, + matchPrefix bool, handler http.Handler) error { + pl := len(path) + if pl == 0 { + return errors.ErrInvalidPath + } + if len(methods) == 0 { + methods = defaultMethods + } else { + for i, m := range methods { + if !meth.IsValidMethod(m) { + return errors.ErrInvalidMethod + } + methods[i] = strings.ToUpper(m) + } + } + if hosts == nil { + hosts = emptyHost + } + for _, h := range hosts { + hrc, ok := rt.routes[h] + if !ok || hrc == nil { + hrc = &route.HostRouteSet{ + ExactMatchRoutes: make(route.RouteLookupLookup), + PrefixMatchRoutes: make(route.PrefixRouteSets, 0, 16), + PrefixMatchRoutesLkp: make(route.PrefixRouteSetLookup), + } + rt.routes[h] = hrc + } + if !matchPrefix { + rl, ok := hrc.ExactMatchRoutes[path] + if rl == nil || !ok { + rl = make(route.RouteLookup) + hrc.ExactMatchRoutes[path] = rl + } + for _, m := range methods { + rl[m] = &route.Route{ + ExactMatch: true, + Method: m, + Host: h, + Path: path, + Handler: handler, + } + if m == http.MethodGet { + if _, ok := rl[http.MethodHead]; !ok { + rl[http.MethodHead] = &route.Route{ + ExactMatch: true, + Method: http.MethodHead, + Host: h, + Path: path, + Handler: handler, + } + } + } + } + continue + } + prc, ok := hrc.PrefixMatchRoutesLkp[path] + if prc == nil || !ok { + prc = &route.PrefixRouteSet{ + Path: path, + PathLen: pl, + RoutesByMethod: make(route.RouteLookup), + } + hrc.PrefixMatchRoutesLkp[path] = prc + if len(hrc.PrefixMatchRoutes) == 0 { + hrc.PrefixMatchRoutes = make(route.PrefixRouteSets, 0, 16) + } + hrc.PrefixMatchRoutes = append(hrc.PrefixMatchRoutes, prc) + } + for _, m := range methods { + prc.RoutesByMethod[m] = &route.Route{ + ExactMatch: true, + Method: m, + Host: h, + Path: path, + Handler: handler, + } + if m == http.MethodGet { + if _, ok := prc.RoutesByMethod[http.MethodHead]; !ok { + prc.RoutesByMethod[http.MethodHead] = &route.Route{ + ExactMatch: true, + Method: http.MethodHead, + Host: h, + Path: path, + Handler: handler, + } + } + } + } + } + rt.sort() + return nil +} + +// this sorts the prefix-match paths longest to shortest +func (rt *lmRouter) sort() { + for _, hrc := range rt.routes { + if len(hrc.PrefixMatchRoutes) == 0 { + continue + } + prs := prefixRouteSets(hrc.PrefixMatchRoutes) + sort.Sort(prs) + slices.Reverse(prs) + hrc.PrefixMatchRoutes = route.PrefixRouteSets(prs) + } +} + +func (rt *lmRouter) Handler(r *http.Request) http.Handler { + if rt.matchScheme&router.MatchHostname == router.MatchHostname { + host := r.Host + i := strings.Index(host, ":") + if i >= 0 { + host = host[0:i] + } + h := rt.matchByHost(r.Method, host, r.URL.Path) + if h != nil { + return h + } + } + h := rt.matchByHost(r.Method, "", r.URL.Path) + if h != nil { + return h + } + return notFoundHandler +} + +func (rt *lmRouter) matchByHost(method, host, path string) http.Handler { + if hrc, ok := rt.routes[host]; ok && hrc != nil { + if rs, ok := hrc.ExactMatchRoutes[path]; ok && rs != nil { + r, ok := rs[method] + if !ok || r == nil { + return methodNotAllowedHandler + } + return r.Handler + } + if !(rt.matchScheme&router.MatchPathPrefix == router.MatchPathPrefix) { + return nil + } + lp := len(path) + for _, prc := range hrc.PrefixMatchRoutes { + if prc.PathLen > lp { + continue + } + if strings.HasPrefix(path, prc.Path) { + r, ok := prc.RoutesByMethod[method] + if !ok || r == nil { + return methodNotAllowedHandler + } + return r.Handler + } + } + } + return nil +} + +func (rt *lmRouter) SetMatchingScheme(s router.MatchingScheme) { + rt.matchScheme = s +} + +func MethodNotAllowed(w http.ResponseWriter, r *http.Request) { + http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed) +} + +var methodNotAllowedHandler = http.HandlerFunc(MethodNotAllowed) +var notFoundHandler = http.HandlerFunc(http.NotFound) + +// prefixRouteSets allows the route.PrefixRouteSets to be sorted by path from +// longest-to-shortest using sort.Interface +type prefixRouteSets route.PrefixRouteSets + +func (prs prefixRouteSets) Len() int { + return len(prs) +} + +func (prs prefixRouteSets) Swap(i, j int) { + prs[i], prs[j] = prs[j], prs[i] +} + +func (prs prefixRouteSets) Less(i, j int) bool { + return prs[i].PathLen < prs[j].PathLen +} diff --git a/pkg/router/lm/lm_test.go b/pkg/router/lm/lm_test.go new file mode 100644 index 000000000..756d42922 --- /dev/null +++ b/pkg/router/lm/lm_test.go @@ -0,0 +1,234 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package lm + +import ( + "net/http" + "strings" + "testing" + + "github.com/trickstercache/trickster/v2/pkg/errors" + "github.com/trickstercache/trickster/v2/pkg/testutil/writer" +) + +const testPathExact1 = "/path/exact" +const testPathExact2 = "/path/exact/2" +const testPathPrefix1 = "/path/prefix" +const testPathPrefix2 = "/path/prefix/2" + +func TestRegisterRoute(t *testing.T) { + + const testPathExact1 = "/path1/exact" + + r := NewRouter().(*lmRouter) + r.RegisterRoute(testPathExact1, nil, nil, false, notFoundHandler) + + hrs, ok := r.routes[""] + if !ok || hrs == nil { + t.Fatal("expected non-nil route set") + } + rll, ok := hrs.ExactMatchRoutes[testPathExact1] + if !ok || rll == nil { + t.Fatal("expected non-nil route lookup") + } + + err := r.RegisterRoute("", nil, nil, false, notFoundHandler) + if err != errors.ErrInvalidPath { + t.Fatal("expected error for invalid path") + } + + err = r.RegisterRoute(testPathPrefix1, nil, []string{"invalidMethod"}, + false, notFoundHandler) + if err != errors.ErrInvalidMethod { + t.Fatal("expected error for invalid method") + } + + err = r.RegisterRoute(testPathPrefix1, nil, []string{http.MethodGet}, + true, notFoundHandler) + if err != nil { + t.Fatal(err) + } +} + +func TestHandler(t *testing.T) { + r := NewRouter().(*lmRouter) + r.RegisterRoute(testPathExact1, nil, nil, false, testResponse1Handler) + r.RegisterRoute(testPathPrefix2, []string{"example.com"}, nil, true, + testResponse2Handler) + r.RegisterRoute(testPathPrefix1, []string{"example.com"}, nil, true, + testResponse1Handler) + + req, _ := http.NewRequest(http.MethodGet, testPathExact1, nil) + req.Host = "example.com:8080" + h := r.Handler(req) + w := writer.NewWriter().(*writer.TestResponseWriter) + if h == nil { + t.Fatal("expected non-nil handler") + } + ok := serveAndVerifyTestResponse1(h, w, req) + if !ok { + t.Fatal("expected test response 1 handler") + } + + // POST request should fail with Method Not Allowed + req, _ = http.NewRequest(http.MethodPost, testPathExact1, nil) + req.Host = "example.com:8080" + h = r.Handler(req) + w.Reset() + if h == nil { + t.Fatal("expected non-nil handler") + } + ok = verifyMethodNotAllowed(h, w, req) + if !ok { + t.Fatal("expected method not allowed handler") + } + + // request should fail with 404 Not Found + req, _ = http.NewRequest(http.MethodPost, testPathExact2, nil) + h = r.Handler(req) + w.Reset() + if h == nil { + t.Fatal("expected non-nil handler") + } + ok = verifyNotFound(h, w, req) + if !ok { + t.Fatal("expected 404 not found handler") + } + + // Prefix Route 1 should pass with test response 1 + req, _ = http.NewRequest(http.MethodGet, testPathPrefix1+"/more/path", nil) + req.Host = "example.com:8080" + h = r.Handler(req) + w.Reset() + if h == nil { + t.Fatal("expected non-nil handler") + } + ok = serveAndVerifyTestResponse1(h, w, req) + if !ok { + t.Fatal("expected test response 1 handler") + } + + // POST on Prefix Route 1 should fail with Method Not Allowed + req, _ = http.NewRequest(http.MethodPost, testPathPrefix1+"/more/path", nil) + req.Host = "example.com:8080" + h = r.Handler(req) + w.Reset() + if h == nil { + t.Fatal("expected non-nil handler") + } + ok = verifyMethodNotAllowed(h, w, req) + if !ok { + t.Fatal("expected method not allowed handler") + } + + r.RegisterRoute(testPathExact2, []string{"example.com"}, nil, false, + testResponse2Handler) + req, _ = http.NewRequest(http.MethodGet, testPathExact2, nil) + req.Host = "example.com:8080" + h = r.Handler(req) + w.Reset() + if h == nil { + t.Fatal("expected non-nil handler") + } + ok = verifyTestResponse2(h, w, req) + if !ok { + t.Fatal("expected test response 2 handler") + } + + r.SetMatchingScheme(0) + req, _ = http.NewRequest(http.MethodConnect, testPathPrefix1, nil) + req.Host = "example.com:8080" + h = r.Handler(req) + w.Reset() + if h == nil { + t.Fatal("expected non-nil handler") + } + ok = verifyNotFound(h, w, req) + if !ok { + t.Fatal("expected 404 not found handler") + } + +} + +func TestServeHTTP(t *testing.T) { + r := NewRouter().(*lmRouter) + r.RegisterRoute("/", nil, nil, true, testResponse1Handler) + w := writer.NewWriter().(*writer.TestResponseWriter) + req, _ := http.NewRequest(http.MethodGet, testPathPrefix1, nil) + req.RequestURI = "*" + r.ServeHTTP(w, req) + ok := verifyBadRequest(w) + if !ok { + t.Fatal("expected 400 bad request handler") + } + req, _ = http.NewRequest(http.MethodGet, testPathPrefix1, nil) + w.Reset() + r.ServeHTTP(w, req) + ok = verifyTestResponse1(w) + if !ok { + t.Fatal("expected test response 1 handler") + } +} + +func verifyNotFound(h http.Handler, w *writer.TestResponseWriter, + r *http.Request) bool { + h.ServeHTTP(w, r) + return w.StatusCode == http.StatusNotFound +} + +func verifyMethodNotAllowed(h http.Handler, w *writer.TestResponseWriter, + r *http.Request) bool { + h.ServeHTTP(w, r) + return w.StatusCode == http.StatusMethodNotAllowed +} + +const testResponse1Text = "test response 1" +const testResponse2Text = "test response 2" + +func testResponse1(w http.ResponseWriter, r *http.Request) { + http.Error(w, testResponse1Text, http.StatusOK) +} + +var testResponse1Handler = http.HandlerFunc(testResponse1) + +func serveAndVerifyTestResponse1(h http.Handler, w *writer.TestResponseWriter, + r *http.Request) bool { + h.ServeHTTP(w, r) + return verifyTestResponse1(w) +} + +func verifyTestResponse1(w *writer.TestResponseWriter) bool { + return w.StatusCode == http.StatusOK && + strings.TrimSpace(string(w.Bytes)) == testResponse1Text +} + +func testResponse2(w http.ResponseWriter, r *http.Request) { + http.Error(w, testResponse2Text, http.StatusOK) +} + +var testResponse2Handler = http.HandlerFunc(testResponse2) + +func verifyTestResponse2(h http.Handler, w *writer.TestResponseWriter, + r *http.Request) bool { + h.ServeHTTP(w, r) + return w.StatusCode == http.StatusOK && + strings.TrimSpace(string(w.Bytes)) == testResponse2Text +} + +func verifyBadRequest(w *writer.TestResponseWriter) bool { + return w.StatusCode == http.StatusBadRequest +} diff --git a/pkg/router/middleware.go b/pkg/router/middleware.go deleted file mode 100644 index 70856defe..000000000 --- a/pkg/router/middleware.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) 2012-2018 The Gorilla Authors. All rights reserved. -// https://github.com/gorilla/mux/blob/master/LICENSE -// Gorilla Mux was archived in December 2022--this is a duplicate of its source to use in Trickster. -package router - -import ( - "net/http" - "strings" -) - -// MiddlewareFunc is a function which receives an http.Handler and returns another http.Handler. -// Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed -// to it, and then calls the handler passed as parameter to the MiddlewareFunc. -type MiddlewareFunc func(http.Handler) http.Handler - -// middleware interface is anything which implements a MiddlewareFunc named Middleware. -type middleware interface { - Middleware(handler http.Handler) http.Handler -} - -// Middleware allows MiddlewareFunc to implement the middleware interface. -func (mw MiddlewareFunc) Middleware(handler http.Handler) http.Handler { - return mw(handler) -} - -// Use appends a MiddlewareFunc to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router. -func (r *router) Use(mwf ...MiddlewareFunc) { - for _, fn := range mwf { - r.middlewares = append(r.middlewares, fn) - } -} - -// CORSMethodMiddleware automatically sets the Access-Control-Allow-Methods response header -// on requests for routes that have an OPTIONS method matcher to all the method matchers on -// the route. Routes that do not explicitly handle OPTIONS requests will not be processed -// by the middleware. See examples for usage. -func CORSMethodMiddleware(r *router) MiddlewareFunc { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - allMethods, err := getAllMethodsForRoute(r, req) - if err == nil { - for _, v := range allMethods { - if v == http.MethodOptions { - w.Header().Set("Access-Control-Allow-Methods", strings.Join(allMethods, ",")) - } - } - } - - next.ServeHTTP(w, req) - }) - } -} - -// getAllMethodsForRoute returns all the methods from method matchers matching a given -// request. -func getAllMethodsForRoute(r *router, req *http.Request) ([]string, error) { - var allMethods []string - - for _, route := range r.routes { - var match RouteMatch - if route.Match(req, &match) || match.MatchErr == ErrMethodMismatch { - methods, err := route.GetMethods() - if err != nil { - return nil, err - } - - allMethods = append(allMethods, methods...) - } - } - - return allMethods, nil -} diff --git a/pkg/router/regexp.go b/pkg/router/regexp.go deleted file mode 100644 index 7e9ba1b6f..000000000 --- a/pkg/router/regexp.go +++ /dev/null @@ -1,387 +0,0 @@ -// Copyright (c) 2012-2018 The Gorilla Authors. All rights reserved. -// https://github.com/gorilla/mux/blob/master/LICENSE -// Gorilla Mux was archived in December 2022--this is a duplicate of its source to use in Trickster. -package router - -import ( - "bytes" - "fmt" - "net/http" - "net/url" - "regexp" - "strconv" - "strings" -) - -type routeRegexpOptions struct { - strictSlash bool - useEncodedPath bool -} - -type regexpType int - -const ( - regexpTypePath regexpType = iota - regexpTypeHost - regexpTypePrefix - regexpTypeQuery -) - -// newRouteRegexp parses a route template and returns a routeRegexp, -// used to match a host, a path or a query string. -// -// It will extract named variables, assemble a regexp to be matched, create -// a "reverse" template to build URLs and compile regexps to validate variable -// values used in URL building. -// -// Previously we accepted only Python-like identifiers for variable -// names ([a-zA-Z_][a-zA-Z0-9_]*), but currently the only restriction is that -// name and pattern can't be empty, and names can't contain a colon. -func newRouteRegexp(tpl string, typ regexpType, options routeRegexpOptions) (*routeRegexp, error) { - // Check if it is well-formed. - idxs, errBraces := braceIndices(tpl) - if errBraces != nil { - return nil, errBraces - } - // Backup the original. - template := tpl - // Now let's parse it. - defaultPattern := "[^/]+" - if typ == regexpTypeQuery { - defaultPattern = ".*" - } else if typ == regexpTypeHost { - defaultPattern = "[^.]+" - } - // Only match strict slash if not matching - if typ != regexpTypePath { - options.strictSlash = false - } - // Set a flag for strictSlash. - endSlash := false - if options.strictSlash && strings.HasSuffix(tpl, "/") { - tpl = tpl[:len(tpl)-1] - endSlash = true - } - varsN := make([]string, len(idxs)/2) - varsR := make([]*regexp.Regexp, len(idxs)/2) - pattern := bytes.NewBufferString("") - pattern.WriteByte('^') - reverse := bytes.NewBufferString("") - var end int - var err error - for i := 0; i < len(idxs); i += 2 { - // Set all values we are interested in. - raw := tpl[end:idxs[i]] - end = idxs[i+1] - parts := strings.SplitN(tpl[idxs[i]+1:end-1], ":", 2) - name := parts[0] - patt := defaultPattern - if len(parts) == 2 { - patt = parts[1] - } - // Name or pattern can't be empty. - if name == "" || patt == "" { - return nil, fmt.Errorf("router: missing name or pattern in %q", - tpl[idxs[i]:end]) - } - // Build the regexp pattern. - fmt.Fprintf(pattern, "%s(?P<%s>%s)", regexp.QuoteMeta(raw), varGroupName(i/2), patt) - - // Build the reverse template. - fmt.Fprintf(reverse, "%s%%s", raw) - - // Append variable name and compiled pattern. - varsN[i/2] = name - varsR[i/2], err = regexp.Compile(fmt.Sprintf("^%s$", patt)) - if err != nil { - return nil, err - } - } - // Add the remaining. - raw := tpl[end:] - pattern.WriteString(regexp.QuoteMeta(raw)) - if options.strictSlash { - pattern.WriteString("[/]?") - } - if typ == regexpTypeQuery { - // Add the default pattern if the query value is empty - if queryVal := strings.SplitN(template, "=", 2)[1]; queryVal == "" { - pattern.WriteString(defaultPattern) - } - } - if typ != regexpTypePrefix { - pattern.WriteByte('$') - } - - var wildcardHostPort bool - if typ == regexpTypeHost { - if !strings.Contains(pattern.String(), ":") { - wildcardHostPort = true - } - } - reverse.WriteString(raw) - if endSlash { - reverse.WriteByte('/') - } - // Compile full regexp. - reg, errCompile := regexp.Compile(pattern.String()) - if errCompile != nil { - return nil, errCompile - } - - // Check for capturing groups which used to work in older versions - if reg.NumSubexp() != len(idxs)/2 { - panic(fmt.Sprintf("route %s contains capture groups in its regexp. ", template) + - "Only non-capturing groups are accepted: e.g. (?:pattern) instead of (pattern)") - } - - // Done! - return &routeRegexp{ - template: template, - regexpType: typ, - options: options, - regexp: reg, - reverse: reverse.String(), - varsN: varsN, - varsR: varsR, - wildcardHostPort: wildcardHostPort, - }, nil -} - -// routeRegexp stores a regexp to match a host or path and information to -// collect and validate route variables. -type routeRegexp struct { - // The unmodified template. - template string - // The type of match - regexpType regexpType - // Options for matching - options routeRegexpOptions - // Expanded regexp. - regexp *regexp.Regexp - // Reverse template. - reverse string - // Variable names. - varsN []string - // Variable regexps (validators). - varsR []*regexp.Regexp - // Wildcard host-port (no strict port match in hostname) - wildcardHostPort bool -} - -// Match matches the regexp against the URL host or path. -func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool { - if r.regexpType == regexpTypeHost { - host := getHost(req) - if r.wildcardHostPort { - // Don't be strict on the port match - if i := strings.Index(host, ":"); i != -1 { - host = host[:i] - } - } - return r.regexp.MatchString(host) - } - - if r.regexpType == regexpTypeQuery { - return r.matchQueryString(req) - } - path := req.URL.Path - if r.options.useEncodedPath { - path = req.URL.EscapedPath() - } - return r.regexp.MatchString(path) -} - -// url builds a URL part using the given values. -func (r *routeRegexp) url(values map[string]string) (string, error) { - urlValues := make([]interface{}, len(r.varsN)) - for k, v := range r.varsN { - value, ok := values[v] - if !ok { - return "", fmt.Errorf("router: missing route variable %q", v) - } - if r.regexpType == regexpTypeQuery { - value = url.QueryEscape(value) - } - urlValues[k] = value - } - rv := fmt.Sprintf(r.reverse, urlValues...) - if !r.regexp.MatchString(rv) { - // The URL is checked against the full regexp, instead of checking - // individual variables. This is faster but to provide a good error - // message, we check individual regexps if the URL doesn't match. - for k, v := range r.varsN { - if !r.varsR[k].MatchString(values[v]) { - return "", fmt.Errorf( - "router: variable %q doesn't match, expected %q", values[v], - r.varsR[k].String()) - } - } - } - return rv, nil -} - -// getURLQuery returns a single query parameter from a request URL. -// For a URL with foo=bar&baz=ding, we return only the relevant key -// value pair for the routeRegexp. -func (r *routeRegexp) getURLQuery(req *http.Request) string { - if r.regexpType != regexpTypeQuery { - return "" - } - templateKey := strings.SplitN(r.template, "=", 2)[0] - val, ok := findFirstQueryKey(req.URL.RawQuery, templateKey) - if ok { - return templateKey + "=" + val - } - return "" -} - -// findFirstQueryKey returns the same result as (*url.URL).Query()[key][0]. -// If key was not found, empty string and false is returned. -func findFirstQueryKey(rawQuery, key string) (value string, ok bool) { - query := []byte(rawQuery) - for len(query) > 0 { - foundKey := query - if i := bytes.IndexAny(foundKey, "&;"); i >= 0 { - foundKey, query = foundKey[:i], foundKey[i+1:] - } else { - query = query[:0] - } - if len(foundKey) == 0 { - continue - } - var value []byte - if i := bytes.IndexByte(foundKey, '='); i >= 0 { - foundKey, value = foundKey[:i], foundKey[i+1:] - } - if len(foundKey) < len(key) { - // Cannot possibly be key. - continue - } - keyString, err := url.QueryUnescape(string(foundKey)) - if err != nil { - continue - } - if keyString != key { - continue - } - valueString, err := url.QueryUnescape(string(value)) - if err != nil { - continue - } - return valueString, true - } - return "", false -} - -func (r *routeRegexp) matchQueryString(req *http.Request) bool { - return r.regexp.MatchString(r.getURLQuery(req)) -} - -// braceIndices returns the first level curly brace indices from a string. -// It returns an error in case of unbalanced braces. -func braceIndices(s string) ([]int, error) { - var level, idx int - var idxs []int - for i := 0; i < len(s); i++ { - switch s[i] { - case '{': - if level++; level == 1 { - idx = i - } - case '}': - if level--; level == 0 { - idxs = append(idxs, idx, i+1) - } else if level < 0 { - return nil, fmt.Errorf("router: unbalanced braces in %q", s) - } - } - } - if level != 0 { - return nil, fmt.Errorf("router: unbalanced braces in %q", s) - } - return idxs, nil -} - -// varGroupName builds a capturing group name for the indexed variable. -func varGroupName(idx int) string { - return "v" + strconv.Itoa(idx) -} - -// ---------------------------------------------------------------------------- -// routeRegexpGroup -// ---------------------------------------------------------------------------- - -// routeRegexpGroup groups the route matchers that carry variables. -type routeRegexpGroup struct { - host *routeRegexp - path *routeRegexp - queries []*routeRegexp -} - -// setMatch extracts the variables from the URL once a route matches. -func (v routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) { - // Store host variables. - if v.host != nil { - host := getHost(req) - if v.host.wildcardHostPort { - // Don't be strict on the port match - if i := strings.Index(host, ":"); i != -1 { - host = host[:i] - } - } - matches := v.host.regexp.FindStringSubmatchIndex(host) - if len(matches) > 0 { - extractVars(host, matches, v.host.varsN, m.Vars) - } - } - path := req.URL.Path - if r.useEncodedPath { - path = req.URL.EscapedPath() - } - // Store path variables. - if v.path != nil { - matches := v.path.regexp.FindStringSubmatchIndex(path) - if len(matches) > 0 { - extractVars(path, matches, v.path.varsN, m.Vars) - // Check if we should redirect. - if v.path.options.strictSlash { - p1 := strings.HasSuffix(path, "/") - p2 := strings.HasSuffix(v.path.template, "/") - if p1 != p2 { - u, _ := url.Parse(req.URL.String()) - if p1 { - u.Path = u.Path[:len(u.Path)-1] - } else { - u.Path += "/" - } - m.Handler = http.RedirectHandler(u.String(), http.StatusMovedPermanently) - } - } - } - } - // Store query string variables. - for _, q := range v.queries { - queryURL := q.getURLQuery(req) - matches := q.regexp.FindStringSubmatchIndex(queryURL) - if len(matches) > 0 { - extractVars(queryURL, matches, q.varsN, m.Vars) - } - } -} - -// getHost tries its best to return the request host. -// According to section 14.23 of RFC 2616 the Host header -// can include the port number if the default value of 80 is not used. -func getHost(r *http.Request) string { - if r.URL.IsAbs() { - return r.URL.Host - } - return r.Host -} - -func extractVars(input string, matches []int, names []string, output map[string]string) { - for i, name := range names { - output[name] = input[matches[2*i+2]:matches[2*i+3]] - } -} diff --git a/pkg/router/regexp_test.go b/pkg/router/regexp_test.go deleted file mode 100644 index d1351bcf5..000000000 --- a/pkg/router/regexp_test.go +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright (c) 2012-2018 The Gorilla Authors. All rights reserved. -// https://github.com/gorilla/mux/blob/master/LICENSE -// Gorilla Mux was archived in December 2022--this is a duplicate of its source to use in Trickster. -package router - -import ( - "net/url" - "reflect" - "strconv" - "testing" -) - -func Test_findFirstQueryKey(t *testing.T) { - tests := []string{ - "a=1&b=2", - "a=1&a=2&a=banana", - "ascii=%3Ckey%3A+0x90%3E", - "a=1;b=2", - "a=1&a=2;a=banana", - "a==", - "a=%2", - "a=20&%20%3F&=%23+%25%21%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09:%2F@$%27%28%29%2A%2C%3B&a=30", - "a=1& ?&=#+%!<>#\"{}|\\^[]`☺\t:/@$'()*,;&a=5", - "a=xxxxxxxxxxxxxxxx&b=YYYYYYYYYYYYYYY&c=ppppppppppppppppppp&f=ttttttttttttttttt&a=uuuuuuuuuuuuu", - } - for _, query := range tests { - t.Run(query, func(t *testing.T) { - // Check against url.ParseQuery, ignoring the error. - all, _ := url.ParseQuery(query) - for key, want := range all { - t.Run(key, func(t *testing.T) { - got, ok := findFirstQueryKey(query, key) - if !ok { - t.Error("Did not get expected key", key) - } - if !reflect.DeepEqual(got, want[0]) { - t.Errorf("findFirstQueryKey(%s,%s) = %v, want %v", query, key, got, want[0]) - } - }) - } - }) - } -} - -func Benchmark_findQueryKey(b *testing.B) { - tests := []string{ - "a=1&b=2", - "ascii=%3Ckey%3A+0x90%3E", - "a=20&%20%3F&=%23+%25%21%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09:%2F@$%27%28%29%2A%2C%3B&a=30", - "a=xxxxxxxxxxxxxxxx&bbb=YYYYYYYYYYYYYYY&cccc=ppppppppppppppppppp&ddddd=ttttttttttttttttt&a=uuuuuuuuuuuuu", - "a=;b=;c=;d=;e=;f=;g=;h=;i=,j=;k=", - } - for i, query := range tests { - b.Run(strconv.Itoa(i), func(b *testing.B) { - // Check against url.ParseQuery, ignoring the error. - all, _ := url.ParseQuery(query) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - for key := range all { - _, _ = findFirstQueryKey(query, key) - } - } - }) - } -} - -func Benchmark_findQueryKeyGoLib(b *testing.B) { - tests := []string{ - "a=1&b=2", - "ascii=%3Ckey%3A+0x90%3E", - "a=20&%20%3F&=%23+%25%21%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09:%2F@$%27%28%29%2A%2C%3B&a=30", - "a=xxxxxxxxxxxxxxxx&bbb=YYYYYYYYYYYYYYY&cccc=ppppppppppppppppppp&ddddd=ttttttttttttttttt&a=uuuuuuuuuuuuu", - "a=;b=;c=;d=;e=;f=;g=;h=;i=,j=;k=", - } - for i, query := range tests { - b.Run(strconv.Itoa(i), func(b *testing.B) { - // Check against url.ParseQuery, ignoring the error. - all, _ := url.ParseQuery(query) - var u url.URL - u.RawQuery = query - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - for key := range all { - v := u.Query()[key] - if len(v) > 0 { - _ = v[0] - } - } - } - }) - } -} diff --git a/pkg/router/route.go b/pkg/router/route.go deleted file mode 100644 index cce6938eb..000000000 --- a/pkg/router/route.go +++ /dev/null @@ -1,734 +0,0 @@ -// Copyright (c) 2012-2018 The Gorilla Authors. All rights reserved. -// https://github.com/gorilla/mux/blob/master/LICENSE -// Gorilla Mux was archived in December 2022--this is a duplicate of its source to use in Trickster. -package router - -import ( - "errors" - "fmt" - "net/http" - "net/url" - "regexp" - "strings" -) - -// Route stores information to match a request and build URLs. -type Route struct { - // Request handler for the route. - handler http.Handler - // If true, this route never matches: it is only used to build URLs. - buildOnly bool - // The name used to build URLs. - name string - // Error resulted from building a route. - err error - - // "global" reference to all named routes - namedRoutes map[string]*Route - - // config possibly passed in from `Router` - routeConf -} - -// SkipClean reports whether path cleaning is enabled for this route via -// Router.SkipClean. -func (r *Route) SkipClean() bool { - return r.skipClean -} - -// Match matches the route against the request. -func (r *Route) Match(req *http.Request, match *RouteMatch) bool { - if r.buildOnly || r.err != nil { - return false - } - - var matchErr error - - // Match everything. - for _, m := range r.matchers { - if matched := m.Match(req, match); !matched { - if _, ok := m.(methodMatcher); ok { - matchErr = ErrMethodMismatch - continue - } - - // Ignore ErrNotFound errors. These errors arise from match call - // to Subrouters. - // - // This prevents subsequent matching subrouters from failing to - // run middleware. If not ignored, the middleware would see a - // non-nil MatchErr and be skipped, even when there was a - // matching route. - if match.MatchErr == ErrNotFound { - match.MatchErr = nil - } - - matchErr = nil - return false - } - } - - if matchErr != nil { - match.MatchErr = matchErr - return false - } - - if match.MatchErr == ErrMethodMismatch && r.handler != nil { - // We found a route which matches request method, clear MatchErr - match.MatchErr = nil - // Then override the mis-matched handler - match.Handler = r.handler - } - - // Yay, we have a match. Let's collect some info about it. - if match.Route == nil { - match.Route = r - } - if match.Handler == nil { - match.Handler = r.handler - } - if match.Vars == nil { - match.Vars = make(map[string]string) - } - - // Set variables. - r.regexp.setMatch(req, match, r) - return true -} - -// ---------------------------------------------------------------------------- -// Route attributes -// ---------------------------------------------------------------------------- - -// GetError returns an error resulted from building the route, if any. -func (r *Route) GetError() error { - return r.err -} - -// BuildOnly sets the route to never match: it is only used to build URLs. -func (r *Route) BuildOnly() *Route { - r.buildOnly = true - return r -} - -// Handler -------------------------------------------------------------------- - -// Handler sets a handler for the route. -func (r *Route) Handler(handler http.Handler) *Route { - if r.err == nil { - r.handler = handler - } - return r -} - -// HandlerFunc sets a handler function for the route. -func (r *Route) HandlerFunc(f func(http.ResponseWriter, *http.Request)) *Route { - return r.Handler(http.HandlerFunc(f)) -} - -// GetHandler returns the handler for the route, if any. -func (r *Route) GetHandler() http.Handler { - return r.handler -} - -// Name ----------------------------------------------------------------------- - -// Name sets the name for the route, used to build URLs. -// It is an error to call Name more than once on a route. -func (r *Route) Name(name string) *Route { - if r.name != "" { - r.err = fmt.Errorf("router: route already has name %q, can't set %q", - r.name, name) - } - if r.err == nil { - r.name = name - r.namedRoutes[name] = r - } - return r -} - -// GetName returns the name for the route, if any. -func (r *Route) GetName() string { - return r.name -} - -// ---------------------------------------------------------------------------- -// Matchers -// ---------------------------------------------------------------------------- - -// matcher types try to match a request. -type matcher interface { - Match(*http.Request, *RouteMatch) bool -} - -// addMatcher adds a matcher to the route. -func (r *Route) addMatcher(m matcher) *Route { - if r.err == nil { - r.matchers = append(r.matchers, m) - } - return r -} - -// addRegexpMatcher adds a host or path matcher and builder to a route. -func (r *Route) addRegexpMatcher(tpl string, typ regexpType) error { - if r.err != nil { - return r.err - } - if typ == regexpTypePath || typ == regexpTypePrefix { - if len(tpl) > 0 && tpl[0] != '/' { - return fmt.Errorf("router: path must start with a slash, got %q", tpl) - } - if r.regexp.path != nil { - tpl = strings.TrimRight(r.regexp.path.template, "/") + tpl - } - } - rr, err := newRouteRegexp(tpl, typ, routeRegexpOptions{ - strictSlash: r.strictSlash, - useEncodedPath: r.useEncodedPath, - }) - if err != nil { - return err - } - for _, q := range r.regexp.queries { - if err = uniqueVars(rr.varsN, q.varsN); err != nil { - return err - } - } - if typ == regexpTypeHost { - if r.regexp.path != nil { - if err = uniqueVars(rr.varsN, r.regexp.path.varsN); err != nil { - return err - } - } - r.regexp.host = rr - } else { - if r.regexp.host != nil { - if err = uniqueVars(rr.varsN, r.regexp.host.varsN); err != nil { - return err - } - } - if typ == regexpTypeQuery { - r.regexp.queries = append(r.regexp.queries, rr) - } else { - r.regexp.path = rr - } - } - r.addMatcher(rr) - return nil -} - -// Headers -------------------------------------------------------------------- - -// headerMatcher matches the request against header values. -type headerMatcher map[string]string - -func (m headerMatcher) Match(r *http.Request, match *RouteMatch) bool { - return matchMapWithString(m, r.Header, true) -} - -// Headers adds a matcher for request header values. -// It accepts a sequence of key/value pairs to be matched. For example: -// -// r := router.NewRouter() -// r.Headers("Content-Type", "application/json", -// "X-Requested-With", "XMLHttpRequest") -// -// The above route will only match if both request header values match. -// If the value is an empty string, it will match any value if the key is set. -func (r *Route) Headers(pairs ...string) *Route { - if r.err == nil { - var headers map[string]string - headers, r.err = mapFromPairsToString(pairs...) - return r.addMatcher(headerMatcher(headers)) - } - return r -} - -// headerRegexMatcher matches the request against the route given a regex for the header -type headerRegexMatcher map[string]*regexp.Regexp - -func (m headerRegexMatcher) Match(r *http.Request, match *RouteMatch) bool { - return matchMapWithRegex(m, r.Header, true) -} - -// HeadersRegexp accepts a sequence of key/value pairs, where the value has regex -// support. For example: -// -// r := router.NewRouter() -// r.HeadersRegexp("Content-Type", "application/(text|json)", -// "X-Requested-With", "XMLHttpRequest") -// -// The above route will only match if both the request header matches both regular expressions. -// If the value is an empty string, it will match any value if the key is set. -// Use the start and end of string anchors (^ and $) to match an exact value. -func (r *Route) HeadersRegexp(pairs ...string) *Route { - if r.err == nil { - var headers map[string]*regexp.Regexp - headers, r.err = mapFromPairsToRegex(pairs...) - return r.addMatcher(headerRegexMatcher(headers)) - } - return r -} - -// Host ----------------------------------------------------------------------- - -// Host adds a matcher for the URL host. -// It accepts a template with zero or more URL variables enclosed by {}. -// Variables can define an optional regexp pattern to be matched: -// -// - {name} matches anything until the next dot. -// -// - {name:pattern} matches the given regexp pattern. -// -// For example: -// -// r := router.NewRouter() -// r.Host("www.example.com") -// r.Host("{subdomain}.domain.com") -// r.Host("{subdomain:[a-z]+}.domain.com") -// -// Variable names must be unique in a given route. They can be retrieved -// calling router.Vars(request). -func (r *Route) Host(tpl string) *Route { - r.err = r.addRegexpMatcher(tpl, regexpTypeHost) - return r -} - -// MatcherFunc ---------------------------------------------------------------- - -// MatcherFunc is the function signature used by custom matchers. -type MatcherFunc func(*http.Request, *RouteMatch) bool - -// Match returns the match for a given request. -func (m MatcherFunc) Match(r *http.Request, match *RouteMatch) bool { - return m(r, match) -} - -// MatcherFunc adds a custom function to be used as request matcher. -func (r *Route) MatcherFunc(f MatcherFunc) *Route { - return r.addMatcher(f) -} - -// Methods -------------------------------------------------------------------- - -// methodMatcher matches the request against HTTP methods. -type methodMatcher []string - -func (m methodMatcher) Match(r *http.Request, match *RouteMatch) bool { - return matchInArray(m, r.Method) -} - -// Methods adds a matcher for HTTP methods. -// It accepts a sequence of one or more methods to be matched, e.g.: -// "GET", "POST", "PUT". -func (r *Route) Methods(methods ...string) *Route { - for k, v := range methods { - methods[k] = strings.ToUpper(v) - } - return r.addMatcher(methodMatcher(methods)) -} - -// Path ----------------------------------------------------------------------- - -// Path adds a matcher for the URL path. -// It accepts a template with zero or more URL variables enclosed by {}. The -// template must start with a "/". -// Variables can define an optional regexp pattern to be matched: -// -// - {name} matches anything until the next slash. -// -// - {name:pattern} matches the given regexp pattern. -// -// For example: -// -// r := router.NewRouter() -// r.Path("/products/").Handler(ProductsHandler) -// r.Path("/products/{key}").Handler(ProductsHandler) -// r.Path("/articles/{category}/{id:[0-9]+}"). -// Handler(ArticleHandler) -// -// Variable names must be unique in a given route. They can be retrieved -// calling router.Vars(request). -func (r *Route) Path(tpl string) *Route { - r.err = r.addRegexpMatcher(tpl, regexpTypePath) - return r -} - -// PathPrefix ----------------------------------------------------------------- - -// PathPrefix adds a matcher for the URL path prefix. This matches if the given -// template is a prefix of the full URL path. See Route.Path() for details on -// the tpl argument. -// -// Note that it does not treat slashes specially ("/foobar/" will be matched by -// the prefix "/foo") so you may want to use a trailing slash here. -// -// Also note that the setting of Router.StrictSlash() has no effect on routes -// with a PathPrefix matcher. -func (r *Route) PathPrefix(tpl string) *Route { - r.err = r.addRegexpMatcher(tpl, regexpTypePrefix) - return r -} - -// Query ---------------------------------------------------------------------- - -// Queries adds a matcher for URL query values. -// It accepts a sequence of key/value pairs. Values may define variables. -// For example: -// -// r := router.NewRouter() -// r.Queries("foo", "bar", "id", "{id:[0-9]+}") -// -// The above route will only match if the URL contains the defined queries -// values, e.g.: ?foo=bar&id=42. -// -// If the value is an empty string, it will match any value if the key is set. -// -// Variables can define an optional regexp pattern to be matched: -// -// - {name} matches anything until the next slash. -// -// - {name:pattern} matches the given regexp pattern. -func (r *Route) Queries(pairs ...string) *Route { - length := len(pairs) - if length%2 != 0 { - r.err = fmt.Errorf( - "router: number of parameters must be multiple of 2, got %v", pairs) - return nil - } - for i := 0; i < length; i += 2 { - if r.err = r.addRegexpMatcher(pairs[i]+"="+pairs[i+1], regexpTypeQuery); r.err != nil { - return r - } - } - - return r -} - -// Schemes -------------------------------------------------------------------- - -// schemeMatcher matches the request against URL schemes. -type schemeMatcher []string - -func (m schemeMatcher) Match(r *http.Request, match *RouteMatch) bool { - scheme := r.URL.Scheme - // https://golang.org/pkg/net/http/#Request - // "For [most] server requests, fields other than Path and RawQuery will be - // empty." - // Since we're an http muxer, the scheme is either going to be http or https - // though, so we can just set it based on the tls termination state. - if scheme == "" { - if r.TLS == nil { - scheme = "http" - } else { - scheme = "https" - } - } - return matchInArray(m, scheme) -} - -// Schemes adds a matcher for URL schemes. -// It accepts a sequence of schemes to be matched, e.g.: "http", "https". -// If the request's URL has a scheme set, it will be matched against. -// Generally, the URL scheme will only be set if a previous handler set it. -// If unset, the scheme will be determined based on the request's TLS -// termination state. -// The first argument to Schemes will be used when constructing a route URL. -func (r *Route) Schemes(schemes ...string) *Route { - for k, v := range schemes { - schemes[k] = strings.ToLower(v) - } - if len(schemes) > 0 { - r.buildScheme = schemes[0] - } - return r.addMatcher(schemeMatcher(schemes)) -} - -// BuildVarsFunc -------------------------------------------------------------- - -// BuildVarsFunc is the function signature used by custom build variable -// functions (which can modify route variables before a route's URL is built). -type BuildVarsFunc func(map[string]string) map[string]string - -// BuildVarsFunc adds a custom function to be used to modify build variables -// before a route's URL is built. -func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route { - if r.buildVarsFunc != nil { - // compose the old and new functions - old := r.buildVarsFunc - r.buildVarsFunc = func(m map[string]string) map[string]string { - return f(old(m)) - } - } else { - r.buildVarsFunc = f - } - return r -} - -// Subrouter ------------------------------------------------------------------ - -// Subrouter creates a subrouter for the route. -// -// It will test the inner routes only if the parent route matched. For example: -// -// r := router.NewRouter() -// s := r.Host("www.example.com").Subrouter() -// s.HandleFunc("/products/", ProductsHandler) -// s.HandleFunc("/products/{key}", ProductHandler) -// s.HandleFunc("/articles/{category}/{id:[0-9]+}"), ArticleHandler) -// -// Here, the routes registered in the subrouter won't be tested if the host -// doesn't match. -func (r *Route) Subrouter() http.Handler { - // initialize a subrouter with a copy of the parent route's configuration - router := &router{routeConf: copyRouteConf(r.routeConf), namedRoutes: r.namedRoutes} - r.addMatcher(router) - return router -} - -// ---------------------------------------------------------------------------- -// URL building -// ---------------------------------------------------------------------------- - -// URL builds a URL for the route. -// -// It accepts a sequence of key/value pairs for the route variables. For -// example, given this route: -// -// r := router.NewRouter() -// r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler). -// Name("article") -// -// ...a URL for it can be built using: -// -// url, err := r.Get("article").URL("category", "technology", "id", "42") -// -// ...which will return an url.URL with the following path: -// -// "/articles/technology/42" -// -// This also works for host variables: -// -// r := router.NewRouter() -// r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler). -// Host("{subdomain}.domain.com"). -// Name("article") -// -// // url.String() will be "http://news.domain.com/articles/technology/42" -// url, err := r.Get("article").URL("subdomain", "news", -// "category", "technology", -// "id", "42") -// -// The scheme of the resulting url will be the first argument that was passed to Schemes: -// -// // url.String() will be "https://example.com" -// r := router.NewRouter() -// url, err := r.Host("example.com") -// .Schemes("https", "http").URL() -// -// All variables defined in the route are required, and their values must -// conform to the corresponding patterns. -func (r *Route) URL(pairs ...string) (*url.URL, error) { - if r.err != nil { - return nil, r.err - } - values, err := r.prepareVars(pairs...) - if err != nil { - return nil, err - } - var scheme, host, path string - queries := make([]string, 0, len(r.regexp.queries)) - if r.regexp.host != nil { - if host, err = r.regexp.host.url(values); err != nil { - return nil, err - } - scheme = "http" - if r.buildScheme != "" { - scheme = r.buildScheme - } - } - if r.regexp.path != nil { - if path, err = r.regexp.path.url(values); err != nil { - return nil, err - } - } - for _, q := range r.regexp.queries { - var query string - if query, err = q.url(values); err != nil { - return nil, err - } - queries = append(queries, query) - } - return &url.URL{ - Scheme: scheme, - Host: host, - Path: path, - RawQuery: strings.Join(queries, "&"), - }, nil -} - -// URLHost builds the host part of the URL for a route. See Route.URL(). -// -// The route must have a host defined. -func (r *Route) URLHost(pairs ...string) (*url.URL, error) { - if r.err != nil { - return nil, r.err - } - if r.regexp.host == nil { - return nil, errors.New("router: route doesn't have a host") - } - values, err := r.prepareVars(pairs...) - if err != nil { - return nil, err - } - host, err := r.regexp.host.url(values) - if err != nil { - return nil, err - } - u := &url.URL{ - Scheme: "http", - Host: host, - } - if r.buildScheme != "" { - u.Scheme = r.buildScheme - } - return u, nil -} - -// URLPath builds the path part of the URL for a route. See Route.URL(). -// -// The route must have a path defined. -func (r *Route) URLPath(pairs ...string) (*url.URL, error) { - if r.err != nil { - return nil, r.err - } - if r.regexp.path == nil { - return nil, errors.New("router: route doesn't have a path") - } - values, err := r.prepareVars(pairs...) - if err != nil { - return nil, err - } - path, err := r.regexp.path.url(values) - if err != nil { - return nil, err - } - return &url.URL{ - Path: path, - }, nil -} - -// GetPathTemplate returns the template used to build the -// route match. -// This is useful for building simple REST API documentation and for instrumentation -// against third-party services. -// An error will be returned if the route does not define a path. -func (r *Route) GetPathTemplate() (string, error) { - if r.err != nil { - return "", r.err - } - if r.regexp.path == nil { - return "", errors.New("router: route doesn't have a path") - } - return r.regexp.path.template, nil -} - -// GetPathRegexp returns the expanded regular expression used to match route path. -// This is useful for building simple REST API documentation and for instrumentation -// against third-party services. -// An error will be returned if the route does not define a path. -func (r *Route) GetPathRegexp() (string, error) { - if r.err != nil { - return "", r.err - } - if r.regexp.path == nil { - return "", errors.New("router: route does not have a path") - } - return r.regexp.path.regexp.String(), nil -} - -// GetQueriesRegexp returns the expanded regular expressions used to match the -// route queries. -// This is useful for building simple REST API documentation and for instrumentation -// against third-party services. -// An error will be returned if the route does not have queries. -func (r *Route) GetQueriesRegexp() ([]string, error) { - if r.err != nil { - return nil, r.err - } - if r.regexp.queries == nil { - return nil, errors.New("router: route doesn't have queries") - } - queries := make([]string, 0, len(r.regexp.queries)) - for _, query := range r.regexp.queries { - queries = append(queries, query.regexp.String()) - } - return queries, nil -} - -// GetQueriesTemplates returns the templates used to build the -// query matching. -// This is useful for building simple REST API documentation and for instrumentation -// against third-party services. -// An error will be returned if the route does not define queries. -func (r *Route) GetQueriesTemplates() ([]string, error) { - if r.err != nil { - return nil, r.err - } - if r.regexp.queries == nil { - return nil, errors.New("router: route doesn't have queries") - } - queries := make([]string, 0, len(r.regexp.queries)) - for _, query := range r.regexp.queries { - queries = append(queries, query.template) - } - return queries, nil -} - -// GetMethods returns the methods the route matches against -// This is useful for building simple REST API documentation and for instrumentation -// against third-party services. -// An error will be returned if route does not have methods. -func (r *Route) GetMethods() ([]string, error) { - if r.err != nil { - return nil, r.err - } - for _, m := range r.matchers { - if methods, ok := m.(methodMatcher); ok { - return []string(methods), nil - } - } - return nil, errors.New("router: route doesn't have methods") -} - -// GetHostTemplate returns the template used to build the -// route match. -// This is useful for building simple REST API documentation and for instrumentation -// against third-party services. -// An error will be returned if the route does not define a host. -func (r *Route) GetHostTemplate() (string, error) { - if r.err != nil { - return "", r.err - } - if r.regexp.host == nil { - return "", errors.New("router: route doesn't have a host") - } - return r.regexp.host.template, nil -} - -// prepareVars converts the route variable pairs into a map. If the route has a -// BuildVarsFunc, it is invoked. -func (r *Route) prepareVars(pairs ...string) (map[string]string, error) { - m, err := mapFromPairsToString(pairs...) - if err != nil { - return nil, err - } - return r.buildVars(m), nil -} - -func (r *Route) buildVars(m map[string]string) map[string]string { - if r.buildVarsFunc != nil { - m = r.buildVarsFunc(m) - } - return m -} diff --git a/pkg/router/route/route.go b/pkg/router/route/route.go new file mode 100644 index 000000000..9a1a66a2e --- /dev/null +++ b/pkg/router/route/route.go @@ -0,0 +1,50 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// package route provides a Route data structure for Request Routing +package route + +import "net/http" + +type Route struct { + ExactMatch bool + Method string + Host string + Path string + Handler http.Handler +} + +type Routes []*Route + +type RouteLookup map[string]*Route +type RouteLookupLookup map[string]RouteLookup + +type PrefixRouteSet struct { + Path string + PathLen int + RoutesByMethod RouteLookup +} + +type PrefixRouteSets []*PrefixRouteSet +type PrefixRouteSetLookup map[string]*PrefixRouteSet + +type HostRouteSet struct { + ExactMatchRoutes RouteLookupLookup + PrefixMatchRoutes PrefixRouteSets + PrefixMatchRoutesLkp PrefixRouteSetLookup +} + +type HostRouteSetLookup map[string]*HostRouteSet diff --git a/pkg/router/router.go b/pkg/router/router.go index d9875c86b..487610c37 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -1,612 +1,46 @@ -// Copyright (c) 2012-2018 The Gorilla Authors. All rights reserved. -// https://github.com/gorilla/mux/blob/master/LICENSE -// Gorilla Mux was archived in December 2022--this is a duplicate of its source to use in Trickster. +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// package router provides an interface for routing HTTP requests to handlers package router import ( - "context" - "errors" - "fmt" "net/http" - "path" - "regexp" ) -// Router registers routes to be matched and dispatches a handler. -// -// It implements the http.Handler interface, so it can be registered to serve -// requests: -// -// var router = router.NewRouter() -// -// func main() { -// http.Handle("/", router) -// } -// -// Or, for Google App Engine, register it in a init() function: -// -// func init() { -// http.Handle("/", router) -// } -// -// This will send all incoming requests to the router. type Router interface { + // ServeHTTP services the provided HTTP Request and write the Response ServeHTTP(http.ResponseWriter, *http.Request) - PathPrefix(string) *Route - Handle(string, http.Handler) *Route - HandleFunc(string, func(http.ResponseWriter, *http.Request)) *Route + // RegisterRoute registers a handler for the provided path/host/method(s) + // If hosts is nil, the route uses global-routing instead of host-based + // If methods is nil, the route is applicable to GET and HEAD requests + // If methods includes GET but not HEAD, HEAD is automatically included + RegisterRoute(path string, hosts, methods []string, matchPrefix bool, + handler http.Handler) error + // Handler returns the handler matching the method/host/path in the Request + Handler(*http.Request) http.Handler + // SetMatchingScheme specifies the ways the Router matches requests + SetMatchingScheme(MatchingScheme) } -var ( - // ErrMethodMismatch is returned when the method in the request does not match - // the method defined against the route. - ErrMethodMismatch = errors.New("method is not allowed") - // ErrNotFound is returned when no route match is found. - ErrNotFound = errors.New("no matching route was found") -) - -// NewRouter returns a new router instance. -func NewRouter() Router { - return &router{namedRoutes: make(map[string]*Route)} -} - -type router struct { - // Configurable Handler to be used when no route matches. - NotFoundHandler http.Handler - - // Configurable Handler to be used when the request method does not match the route. - MethodNotAllowedHandler http.Handler - - // Routes to be matched, in order. - routes []*Route - - // Routes by name for URL building. - namedRoutes map[string]*Route - - // If true, do not clear the request context after handling the request. - // - // Deprecated: No effect, since the context is stored on the request itself. - KeepContext bool - - // Slice of middlewares to be called after a match is found - middlewares []middleware - - // configuration shared with `Route` - routeConf -} - -// common route configuration shared between `Router` and `Route` -type routeConf struct { - // If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to" - useEncodedPath bool - - // If true, when the path pattern is "/path/", accessing "/path" will - // redirect to the former and vice versa. - strictSlash bool - - // If true, when the path pattern is "/path//to", accessing "/path//to" - // will not redirect - skipClean bool - - // Manager for the variables from host and path. - regexp routeRegexpGroup - - // List of matchers. - matchers []matcher - - // The scheme used when building URLs. - buildScheme string - - buildVarsFunc BuildVarsFunc -} - -// returns an effective deep copy of `routeConf` -func copyRouteConf(r routeConf) routeConf { - c := r - - if r.regexp.path != nil { - c.regexp.path = copyRouteRegexp(r.regexp.path) - } - - if r.regexp.host != nil { - c.regexp.host = copyRouteRegexp(r.regexp.host) - } - - c.regexp.queries = make([]*routeRegexp, 0, len(r.regexp.queries)) - for _, q := range r.regexp.queries { - c.regexp.queries = append(c.regexp.queries, copyRouteRegexp(q)) - } - - c.matchers = make([]matcher, len(r.matchers)) - copy(c.matchers, r.matchers) - - return c -} - -func copyRouteRegexp(r *routeRegexp) *routeRegexp { - c := *r - return &c -} - -// Match attempts to match the given request against the router's registered routes. -// -// If the request matches a route of this router or one of its subrouters the Route, -// Handler, and Vars fields of the the match argument are filled and this function -// returns true. -// -// If the request does not match any of this router's or its subrouters' routes -// then this function returns false. If available, a reason for the match failure -// will be filled in the match argument's MatchErr field. If the match failure type -// (eg: not found) has a registered handler, the handler is assigned to the Handler -// field of the match argument. -func (r *router) Match(req *http.Request, match *RouteMatch) bool { - for _, route := range r.routes { - if route.Match(req, match) { - // Build middleware chain if no error was found - if match.MatchErr == nil { - for i := len(r.middlewares) - 1; i >= 0; i-- { - match.Handler = r.middlewares[i].Middleware(match.Handler) - } - } - return true - } - } - - if match.MatchErr == ErrMethodMismatch { - if r.MethodNotAllowedHandler != nil { - match.Handler = r.MethodNotAllowedHandler - return true - } - - return false - } - - // Closest match for a router (includes sub-routers) - if r.NotFoundHandler != nil { - match.Handler = r.NotFoundHandler - match.MatchErr = ErrNotFound - return true - } - - match.MatchErr = ErrNotFound - return false -} - -// ServeHTTP dispatches the handler registered in the matched route. -// -// When there is a match, the route variables can be retrieved calling -// router.Vars(request). -func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if !r.skipClean { - path := req.URL.Path - if r.useEncodedPath { - path = req.URL.EscapedPath() - } - // Clean path to canonical form and redirect. - if p := cleanPath(path); p != path { - - // Added 3 lines (Philip Schlump) - It was dropping the query string and #whatever from query. - // This matches with fix in go 1.2 r.c. 4 for same problem. Go Issue: - // http://code.google.com/p/go/issues/detail?id=5252 - url := *req.URL - url.Path = p - p = url.String() - - w.Header().Set("Location", p) - w.WriteHeader(http.StatusMovedPermanently) - return - } - } - var match RouteMatch - var handler http.Handler - if r.Match(req, &match) { - handler = match.Handler - req = requestWithVars(req, match.Vars) - req = requestWithRoute(req, match.Route) - } - - if handler == nil && match.MatchErr == ErrMethodMismatch { - handler = methodNotAllowedHandler() - } - - if handler == nil { - handler = http.NotFoundHandler() - } - - handler.ServeHTTP(w, req) -} - -// Get returns a route registered with the given name. -func (r *router) Get(name string) *Route { - return r.namedRoutes[name] -} - -// GetRoute returns a route registered with the given name. This method -// was renamed to Get() and remains here for backwards compatibility. -func (r *router) GetRoute(name string) *Route { - return r.namedRoutes[name] -} - -// StrictSlash defines the trailing slash behavior for new routes. The initial -// value is false. -// -// When true, if the route path is "/path/", accessing "/path" will perform a redirect -// to the former and vice versa. In other words, your application will always -// see the path as specified in the route. -// -// When false, if the route path is "/path", accessing "/path/" will not match -// this route and vice versa. -// -// The re-direct is a HTTP 301 (Moved Permanently). Note that when this is set for -// routes with a non-idempotent method (e.g. POST, PUT), the subsequent re-directed -// request will be made as a GET by most clients. Use middleware or client settings -// to modify this behaviour as needed. -// -// Special case: when a route sets a path prefix using the PathPrefix() method, -// strict slash is ignored for that route because the redirect behavior can't -// be determined from a prefix alone. However, any subrouters created from that -// route inherit the original StrictSlash setting. -func (r *router) StrictSlash(value bool) http.Handler { - r.strictSlash = value - return r -} - -// SkipClean defines the path cleaning behaviour for new routes. The initial -// value is false. Users should be careful about which routes are not cleaned -// -// When true, if the route path is "/path//to", it will remain with the double -// slash. This is helpful if you have a route like: /fetch/http://xkcd.com/534/ -// -// When false, the path will be cleaned, so /fetch/http://xkcd.com/534/ will -// become /fetch/http/xkcd.com/534 -func (r *router) SkipClean(value bool) http.Handler { - r.skipClean = value - return r -} - -// UseEncodedPath tells the router to match the encoded original path -// to the routes. -// For eg. "/path/foo%2Fbar/to" will match the path "/path/{var}/to". -// -// If not called, the router will match the unencoded path to the routes. -// For eg. "/path/foo%2Fbar/to" will match the path "/path/foo/bar/to" -func (r *router) UseEncodedPath() http.Handler { - r.useEncodedPath = true - return r -} - -// ---------------------------------------------------------------------------- -// Route factories -// ---------------------------------------------------------------------------- - -// NewRoute registers an empty route. -func (r *router) NewRoute() *Route { - // initialize a route with a copy of the parent router's configuration - route := &Route{routeConf: copyRouteConf(r.routeConf), namedRoutes: r.namedRoutes} - r.routes = append(r.routes, route) - return route -} - -// Name registers a new route with a name. -// See Route.Name(). -func (r *router) Name(name string) *Route { - return r.NewRoute().Name(name) -} - -// Handle registers a new route with a matcher for the URL path. -// See Route.Path() and Route.Handler(). -func (r *router) Handle(path string, handler http.Handler) *Route { - return r.NewRoute().Path(path).Handler(handler) -} - -// HandleFunc registers a new route with a matcher for the URL path. -// See Route.Path() and Route.HandlerFunc(). -func (r *router) HandleFunc(path string, f func(http.ResponseWriter, - *http.Request)) *Route { - return r.NewRoute().Path(path).HandlerFunc(f) -} - -// Headers registers a new route with a matcher for request header values. -// See Route.Headers(). -func (r *router) Headers(pairs ...string) *Route { - return r.NewRoute().Headers(pairs...) -} - -// Host registers a new route with a matcher for the URL host. -// See Route.Host(). -func (r *router) Host(tpl string) *Route { - return r.NewRoute().Host(tpl) -} - -// MatcherFunc registers a new route with a custom matcher function. -// See Route.MatcherFunc(). -func (r *router) MatcherFunc(f MatcherFunc) *Route { - return r.NewRoute().MatcherFunc(f) -} - -// Methods registers a new route with a matcher for HTTP methods. -// See Route.Methods(). -func (r *router) Methods(methods ...string) *Route { - return r.NewRoute().Methods(methods...) -} - -// Path registers a new route with a matcher for the URL path. -// See Route.Path(). -func (r *router) Path(tpl string) *Route { - return r.NewRoute().Path(tpl) -} - -// PathPrefix registers a new route with a matcher for the URL path prefix. -// See Route.PathPrefix(). -func (r *router) PathPrefix(tpl string) *Route { - return r.NewRoute().PathPrefix(tpl) -} - -// Queries registers a new route with a matcher for URL query values. -// See Route.Queries(). -func (r *router) Queries(pairs ...string) *Route { - return r.NewRoute().Queries(pairs...) -} - -// Schemes registers a new route with a matcher for URL schemes. -// See Route.Schemes(). -func (r *router) Schemes(schemes ...string) *Route { - return r.NewRoute().Schemes(schemes...) -} - -// BuildVarsFunc registers a new route with a custom function for modifying -// route variables before building a URL. -func (r *router) BuildVarsFunc(f BuildVarsFunc) *Route { - return r.NewRoute().BuildVarsFunc(f) -} - -// Walk walks the router and all its sub-routers, calling walkFn for each route -// in the tree. The routes are walked in the order they were added. Sub-routers -// are explored depth-first. -func (r *router) Walk(walkFn WalkFunc) error { - return r.walk(walkFn, []*Route{}) -} - -// ErrSkipRouter is used as a return value from WalkFuncs to indicate that the -// router that walk is about to descend down to should be skipped. -var ErrSkipRouter = errors.New("skip this router") - -// WalkFunc is the type of the function called for each route visited by Walk. -// At every invocation, it is given the current route, and the current router, -// and a list of ancestor routes that lead to the current route. -type WalkFunc func(route *Route, router *router, ancestors []*Route) error - -func (r *router) walk(walkFn WalkFunc, ancestors []*Route) error { - for _, t := range r.routes { - err := walkFn(t, r, ancestors) - if err == ErrSkipRouter { - continue - } - if err != nil { - return err - } - for _, sr := range t.matchers { - if h, ok := sr.(*router); ok { - ancestors = append(ancestors, t) - err := h.walk(walkFn, ancestors) - if err != nil { - return err - } - ancestors = ancestors[:len(ancestors)-1] - } - } - if h, ok := t.handler.(*router); ok { - ancestors = append(ancestors, t) - err := h.walk(walkFn, ancestors) - if err != nil { - return err - } - ancestors = ancestors[:len(ancestors)-1] - } - } - return nil -} - -// ---------------------------------------------------------------------------- -// Context -// ---------------------------------------------------------------------------- - -// RouteMatch stores information about a matched route. -type RouteMatch struct { - Route *Route - Handler http.Handler - Vars map[string]string - - // MatchErr is set to appropriate matching error - // It is set to ErrMethodMismatch if there is a mismatch in - // the request method and route method - MatchErr error -} - -type contextKey int +type MatchingScheme int const ( - varsKey contextKey = iota - routeKey -) + MatchHostname MatchingScheme = 1 + MatchPathPrefix MatchingScheme = 2 -// Vars returns the route variables for the current request, if any. -func Vars(r *http.Request) map[string]string { - if rv := r.Context().Value(varsKey); rv != nil { - return rv.(map[string]string) - } - return nil -} - -// CurrentRoute returns the matched route for the current request, if any. -// This only works when called inside the handler of the matched route -// because the matched route is stored in the request context which is cleared -// after the handler returns. -func CurrentRoute(r *http.Request) *Route { - if rv := r.Context().Value(routeKey); rv != nil { - return rv.(*Route) - } - return nil -} - -func requestWithVars(r *http.Request, vars map[string]string) *http.Request { - ctx := context.WithValue(r.Context(), varsKey, vars) - return r.WithContext(ctx) -} - -func requestWithRoute(r *http.Request, route *Route) *http.Request { - ctx := context.WithValue(r.Context(), routeKey, route) - return r.WithContext(ctx) -} - -// ---------------------------------------------------------------------------- -// Helpers -// ---------------------------------------------------------------------------- - -// cleanPath returns the canonical path for p, eliminating . and .. elements. -// Borrowed from the net/http package. -func cleanPath(p string) string { - if p == "" { - return "/" - } - if p[0] != '/' { - p = "/" + p - } - np := path.Clean(p) - // path.Clean removes trailing slash except for root; - // put the trailing slash back if necessary. - if p[len(p)-1] == '/' && np != "/" { - np += "/" - } - - return np -} - -// uniqueVars returns an error if two slices contain duplicated strings. -func uniqueVars(s1, s2 []string) error { - for _, v1 := range s1 { - for _, v2 := range s2 { - if v1 == v2 { - return fmt.Errorf("router: duplicated route variable %q", v2) - } - } - } - return nil -} - -// checkPairs returns the count of strings passed in, and an error if -// the count is not an even number. -func checkPairs(pairs ...string) (int, error) { - length := len(pairs) - if length%2 != 0 { - return length, fmt.Errorf( - "router: number of parameters must be multiple of 2, got %v", pairs) - } - return length, nil -} - -// mapFromPairsToString converts variadic string parameters to a -// string to string map. -func mapFromPairsToString(pairs ...string) (map[string]string, error) { - length, err := checkPairs(pairs...) - if err != nil { - return nil, err - } - m := make(map[string]string, length/2) - for i := 0; i < length; i += 2 { - m[pairs[i]] = pairs[i+1] - } - return m, nil -} - -// mapFromPairsToRegex converts variadic string parameters to a -// string to regex map. -func mapFromPairsToRegex(pairs ...string) (map[string]*regexp.Regexp, error) { - length, err := checkPairs(pairs...) - if err != nil { - return nil, err - } - m := make(map[string]*regexp.Regexp, length/2) - for i := 0; i < length; i += 2 { - regex, err := regexp.Compile(pairs[i+1]) - if err != nil { - return nil, err - } - m[pairs[i]] = regex - } - return m, nil -} - -// matchInArray returns true if the given string value is in the array. -func matchInArray(arr []string, value string) bool { - for _, v := range arr { - if v == value { - return true - } - } - return false -} - -// matchMapWithString returns true if the given key/value pairs exist in a given map. -func matchMapWithString(toCheck map[string]string, toMatch map[string][]string, canonicalKey bool) bool { - for k, v := range toCheck { - // Check if key exists. - if canonicalKey { - k = http.CanonicalHeaderKey(k) - } - if values := toMatch[k]; values == nil { - return false - } else if v != "" { - // If value was defined as an empty string we only check that the - // key exists. Otherwise we also check for equality. - valueExists := false - for _, value := range values { - if v == value { - valueExists = true - break - } - } - if !valueExists { - return false - } - } - } - return true -} - -// matchMapWithRegex returns true if the given key/value pairs exist in a given map compiled against -// the given regex -func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]string, canonicalKey bool) bool { - for k, v := range toCheck { - // Check if key exists. - if canonicalKey { - k = http.CanonicalHeaderKey(k) - } - if values := toMatch[k]; values == nil { - return false - } else if v != nil { - // If value was defined as an empty string we only check that the - // key exists. Otherwise we also check for equality. - valueExists := false - for _, value := range values { - if v.MatchString(value) { - valueExists = true - break - } - } - if !valueExists { - return false - } - } - } - return true -} - -// methodNotAllowed replies to the request with an HTTP status code 405. -func methodNotAllowed(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusMethodNotAllowed) -} - -// methodNotAllowedHandler returns a simple request handler -// that replies to each request with a status code 405. -func methodNotAllowedHandler() http.Handler { return http.HandlerFunc(methodNotAllowed) } + DefaultMatchingScheme MatchingScheme = MatchHostname | MatchPathPrefix +) diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go deleted file mode 100644 index c782a4552..000000000 --- a/pkg/router/router_test.go +++ /dev/null @@ -1,1590 +0,0 @@ -// Copyright (c) 2012-2018 The Gorilla Authors. All rights reserved. -// https://github.com/gorilla/mux/blob/master/LICENSE -// Gorilla Mux was archived in December 2022--this is a duplicate of its source to use in Trickster. -package router - -import ( - "bufio" - "bytes" - "fmt" - "net/http" - "net/http/httptest" - "net/url" - "reflect" - "strings" - "testing" -) - -func (r *Route) GoString() string { - matchers := make([]string, len(r.matchers)) - for i, m := range r.matchers { - matchers[i] = fmt.Sprintf("%#v", m) - } - return fmt.Sprintf("&Route{matchers:[]matcher{%s}}", strings.Join(matchers, ", ")) -} - -func (r *routeRegexp) GoString() string { - return fmt.Sprintf("&routeRegexp{template: %q, regexpType: %v, options: %v, regexp: regexp.MustCompile(%q), reverse: %q, varsN: %v, varsR: %v", r.template, r.regexpType, r.options, r.regexp.String(), r.reverse, r.varsN, r.varsR) -} - -type routeTest struct { - title string // title of the test - route *Route // the route being tested - request *http.Request // a request to test the route - vars map[string]string // the expected vars of the match - scheme string // the expected scheme of the built URL - host string // the expected host of the built URL - path string // the expected path of the built URL - query string // the expected query string of the built URL - pathTemplate string // the expected path template of the route - hostTemplate string // the expected host template of the route - queriesTemplate string // the expected query template of the route - methods []string // the expected route methods - pathRegexp string // the expected path regexp - queriesRegexp string // the expected query regexp - shouldMatch bool // whether the request is expected to match the route at all - shouldRedirect bool // whether the request should result in a redirect -} - -func TestHost(t *testing.T) { - - tests := []routeTest{ - { - title: "Host route match", - route: new(Route).Host("aaa.bbb.ccc"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{}, - host: "aaa.bbb.ccc", - path: "", - shouldMatch: true, - }, - { - title: "Host route, wrong host in request URL", - route: new(Route).Host("aaa.bbb.ccc"), - request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), - vars: map[string]string{}, - host: "aaa.bbb.ccc", - path: "", - shouldMatch: false, - }, - { - title: "Host route with port, match", - route: new(Route).Host("aaa.bbb.ccc:1234"), - request: newRequest("GET", "http://aaa.bbb.ccc:1234/111/222/333"), - vars: map[string]string{}, - host: "aaa.bbb.ccc:1234", - path: "", - shouldMatch: true, - }, - { - title: "Host route with port, wrong port in request URL", - route: new(Route).Host("aaa.bbb.ccc:1234"), - request: newRequest("GET", "http://aaa.bbb.ccc:9999/111/222/333"), - vars: map[string]string{}, - host: "aaa.bbb.ccc:1234", - path: "", - shouldMatch: false, - }, - { - title: "Host route, match with host in request header", - route: new(Route).Host("aaa.bbb.ccc"), - request: newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc"), - vars: map[string]string{}, - host: "aaa.bbb.ccc", - path: "", - shouldMatch: true, - }, - { - title: "Host route, wrong host in request header", - route: new(Route).Host("aaa.bbb.ccc"), - request: newRequestHost("GET", "/111/222/333", "aaa.222.ccc"), - vars: map[string]string{}, - host: "aaa.bbb.ccc", - path: "", - shouldMatch: false, - }, - { - title: "Host route with port, match with request header", - route: new(Route).Host("aaa.bbb.ccc:1234"), - request: newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc:1234"), - vars: map[string]string{}, - host: "aaa.bbb.ccc:1234", - path: "", - shouldMatch: true, - }, - { - title: "Host route with port, wrong host in request header", - route: new(Route).Host("aaa.bbb.ccc:1234"), - request: newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc:9999"), - vars: map[string]string{}, - host: "aaa.bbb.ccc:1234", - path: "", - shouldMatch: false, - }, - { - title: "Host route with pattern, match with request header", - route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc:1{v2:(?:23|4)}"), - request: newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc:123"), - vars: map[string]string{"v1": "bbb", "v2": "23"}, - host: "aaa.bbb.ccc:123", - path: "", - hostTemplate: `aaa.{v1:[a-z]{3}}.ccc:1{v2:(?:23|4)}`, - shouldMatch: true, - }, - { - title: "Host route with pattern, match", - route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v1": "bbb"}, - host: "aaa.bbb.ccc", - path: "", - hostTemplate: `aaa.{v1:[a-z]{3}}.ccc`, - shouldMatch: true, - }, - { - title: "Host route with pattern, additional capturing group, match", - route: new(Route).Host("aaa.{v1:[a-z]{2}(?:b|c)}.ccc"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v1": "bbb"}, - host: "aaa.bbb.ccc", - path: "", - hostTemplate: `aaa.{v1:[a-z]{2}(?:b|c)}.ccc`, - shouldMatch: true, - }, - { - title: "Host route with pattern, wrong host in request URL", - route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc"), - request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), - vars: map[string]string{"v1": "bbb"}, - host: "aaa.bbb.ccc", - path: "", - hostTemplate: `aaa.{v1:[a-z]{3}}.ccc`, - shouldMatch: false, - }, - { - title: "Host route with multiple patterns, match", - route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc"}, - host: "aaa.bbb.ccc", - path: "", - hostTemplate: `{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}`, - shouldMatch: true, - }, - { - title: "Host route with multiple patterns, wrong host in request URL", - route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}"), - request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), - vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc"}, - host: "aaa.bbb.ccc", - path: "", - hostTemplate: `{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}`, - shouldMatch: false, - }, - { - title: "Host route with hyphenated name and pattern, match", - route: new(Route).Host("aaa.{v-1:[a-z]{3}}.ccc"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v-1": "bbb"}, - host: "aaa.bbb.ccc", - path: "", - hostTemplate: `aaa.{v-1:[a-z]{3}}.ccc`, - shouldMatch: true, - }, - { - title: "Host route with hyphenated name and pattern, additional capturing group, match", - route: new(Route).Host("aaa.{v-1:[a-z]{2}(?:b|c)}.ccc"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v-1": "bbb"}, - host: "aaa.bbb.ccc", - path: "", - hostTemplate: `aaa.{v-1:[a-z]{2}(?:b|c)}.ccc`, - shouldMatch: true, - }, - { - title: "Host route with multiple hyphenated names and patterns, match", - route: new(Route).Host("{v-1:[a-z]{3}}.{v-2:[a-z]{3}}.{v-3:[a-z]{3}}"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v-1": "aaa", "v-2": "bbb", "v-3": "ccc"}, - host: "aaa.bbb.ccc", - path: "", - hostTemplate: `{v-1:[a-z]{3}}.{v-2:[a-z]{3}}.{v-3:[a-z]{3}}`, - shouldMatch: true, - }, - } - for _, test := range tests { - t.Run(test.title, func(t *testing.T) { - testRoute(t, test) - testTemplate(t, test) - }) - } -} - -func TestPath(t *testing.T) { - tests := []routeTest{ - { - title: "Path route, match", - route: new(Route).Path("/111/222/333"), - request: newRequest("GET", "http://localhost/111/222/333"), - vars: map[string]string{}, - host: "", - path: "/111/222/333", - shouldMatch: true, - }, - { - title: "Path route, match with trailing slash in request and path", - route: new(Route).Path("/111/"), - request: newRequest("GET", "http://localhost/111/"), - vars: map[string]string{}, - host: "", - path: "/111/", - shouldMatch: true, - }, - { - title: "Path route, do not match with trailing slash in path", - route: new(Route).Path("/111/"), - request: newRequest("GET", "http://localhost/111"), - vars: map[string]string{}, - host: "", - path: "/111", - pathTemplate: `/111/`, - pathRegexp: `^/111/$`, - shouldMatch: false, - }, - { - title: "Path route, do not match with trailing slash in request", - route: new(Route).Path("/111"), - request: newRequest("GET", "http://localhost/111/"), - vars: map[string]string{}, - host: "", - path: "/111/", - pathTemplate: `/111`, - shouldMatch: false, - }, - { - title: "Path route, match root with no host", - route: new(Route).Path("/"), - request: newRequest("GET", "/"), - vars: map[string]string{}, - host: "", - path: "/", - pathTemplate: `/`, - pathRegexp: `^/$`, - shouldMatch: true, - }, - { - title: "Path route, match root with no host, App Engine format", - route: new(Route).Path("/"), - request: func() *http.Request { - r := newRequest("GET", "http://localhost/") - r.RequestURI = "/" - return r - }(), - vars: map[string]string{}, - host: "", - path: "/", - pathTemplate: `/`, - shouldMatch: true, - }, - { - title: "Path route, wrong path in request in request URL", - route: new(Route).Path("/111/222/333"), - request: newRequest("GET", "http://localhost/1/2/3"), - vars: map[string]string{}, - host: "", - path: "/111/222/333", - shouldMatch: false, - }, - { - title: "Path route with pattern, match", - route: new(Route).Path("/111/{v1:[0-9]{3}}/333"), - request: newRequest("GET", "http://localhost/111/222/333"), - vars: map[string]string{"v1": "222"}, - host: "", - path: "/111/222/333", - pathTemplate: `/111/{v1:[0-9]{3}}/333`, - shouldMatch: true, - }, - { - title: "Path route with pattern, URL in request does not match", - route: new(Route).Path("/111/{v1:[0-9]{3}}/333"), - request: newRequest("GET", "http://localhost/111/aaa/333"), - vars: map[string]string{"v1": "222"}, - host: "", - path: "/111/222/333", - pathTemplate: `/111/{v1:[0-9]{3}}/333`, - pathRegexp: `^/111/(?P[0-9]{3})/333$`, - shouldMatch: false, - }, - { - title: "Path route with multiple patterns, match", - route: new(Route).Path("/{v1:[0-9]{3}}/{v2:[0-9]{3}}/{v3:[0-9]{3}}"), - request: newRequest("GET", "http://localhost/111/222/333"), - vars: map[string]string{"v1": "111", "v2": "222", "v3": "333"}, - host: "", - path: "/111/222/333", - pathTemplate: `/{v1:[0-9]{3}}/{v2:[0-9]{3}}/{v3:[0-9]{3}}`, - pathRegexp: `^/(?P[0-9]{3})/(?P[0-9]{3})/(?P[0-9]{3})$`, - shouldMatch: true, - }, - { - title: "Path route with multiple patterns, URL in request does not match", - route: new(Route).Path("/{v1:[0-9]{3}}/{v2:[0-9]{3}}/{v3:[0-9]{3}}"), - request: newRequest("GET", "http://localhost/111/aaa/333"), - vars: map[string]string{"v1": "111", "v2": "222", "v3": "333"}, - host: "", - path: "/111/222/333", - pathTemplate: `/{v1:[0-9]{3}}/{v2:[0-9]{3}}/{v3:[0-9]{3}}`, - pathRegexp: `^/(?P[0-9]{3})/(?P[0-9]{3})/(?P[0-9]{3})$`, - shouldMatch: false, - }, - { - title: "Path route with multiple patterns with pipe, match", - route: new(Route).Path("/{category:a|(?:b/c)}/{product}/{id:[0-9]+}"), - request: newRequest("GET", "http://localhost/a/product_name/1"), - vars: map[string]string{"category": "a", "product": "product_name", "id": "1"}, - host: "", - path: "/a/product_name/1", - pathTemplate: `/{category:a|(?:b/c)}/{product}/{id:[0-9]+}`, - pathRegexp: `^/(?Pa|(?:b/c))/(?P[^/]+)/(?P[0-9]+)$`, - shouldMatch: true, - }, - { - title: "Path route with hyphenated name and pattern, match", - route: new(Route).Path("/111/{v-1:[0-9]{3}}/333"), - request: newRequest("GET", "http://localhost/111/222/333"), - vars: map[string]string{"v-1": "222"}, - host: "", - path: "/111/222/333", - pathTemplate: `/111/{v-1:[0-9]{3}}/333`, - pathRegexp: `^/111/(?P[0-9]{3})/333$`, - shouldMatch: true, - }, - { - title: "Path route with multiple hyphenated names and patterns, match", - route: new(Route).Path("/{v-1:[0-9]{3}}/{v-2:[0-9]{3}}/{v-3:[0-9]{3}}"), - request: newRequest("GET", "http://localhost/111/222/333"), - vars: map[string]string{"v-1": "111", "v-2": "222", "v-3": "333"}, - host: "", - path: "/111/222/333", - pathTemplate: `/{v-1:[0-9]{3}}/{v-2:[0-9]{3}}/{v-3:[0-9]{3}}`, - pathRegexp: `^/(?P[0-9]{3})/(?P[0-9]{3})/(?P[0-9]{3})$`, - shouldMatch: true, - }, - { - title: "Path route with multiple hyphenated names and patterns with pipe, match", - route: new(Route).Path("/{product-category:a|(?:b/c)}/{product-name}/{product-id:[0-9]+}"), - request: newRequest("GET", "http://localhost/a/product_name/1"), - vars: map[string]string{"product-category": "a", "product-name": "product_name", "product-id": "1"}, - host: "", - path: "/a/product_name/1", - pathTemplate: `/{product-category:a|(?:b/c)}/{product-name}/{product-id:[0-9]+}`, - pathRegexp: `^/(?Pa|(?:b/c))/(?P[^/]+)/(?P[0-9]+)$`, - shouldMatch: true, - }, - { - title: "Path route with multiple hyphenated names and patterns with pipe and case insensitive, match", - route: new(Route).Path("/{type:(?i:daily|mini|variety)}-{date:\\d{4,4}-\\d{2,2}-\\d{2,2}}"), - request: newRequest("GET", "http://localhost/daily-2016-01-01"), - vars: map[string]string{"type": "daily", "date": "2016-01-01"}, - host: "", - path: "/daily-2016-01-01", - pathTemplate: `/{type:(?i:daily|mini|variety)}-{date:\d{4,4}-\d{2,2}-\d{2,2}}`, - pathRegexp: `^/(?P(?i:daily|mini|variety))-(?P\d{4,4}-\d{2,2}-\d{2,2})$`, - shouldMatch: true, - }, - { - title: "Path route with empty match right after other match", - route: new(Route).Path(`/{v1:[0-9]*}{v2:[a-z]*}/{v3:[0-9]*}`), - request: newRequest("GET", "http://localhost/111/222"), - vars: map[string]string{"v1": "111", "v2": "", "v3": "222"}, - host: "", - path: "/111/222", - pathTemplate: `/{v1:[0-9]*}{v2:[a-z]*}/{v3:[0-9]*}`, - pathRegexp: `^/(?P[0-9]*)(?P[a-z]*)/(?P[0-9]*)$`, - shouldMatch: true, - }, - { - title: "Path route with single pattern with pipe, match", - route: new(Route).Path("/{category:a|b/c}"), - request: newRequest("GET", "http://localhost/a"), - vars: map[string]string{"category": "a"}, - host: "", - path: "/a", - pathTemplate: `/{category:a|b/c}`, - shouldMatch: true, - }, - { - title: "Path route with single pattern with pipe, match", - route: new(Route).Path("/{category:a|b/c}"), - request: newRequest("GET", "http://localhost/b/c"), - vars: map[string]string{"category": "b/c"}, - host: "", - path: "/b/c", - pathTemplate: `/{category:a|b/c}`, - shouldMatch: true, - }, - { - title: "Path route with multiple patterns with pipe, match", - route: new(Route).Path("/{category:a|b/c}/{product}/{id:[0-9]+}"), - request: newRequest("GET", "http://localhost/a/product_name/1"), - vars: map[string]string{"category": "a", "product": "product_name", "id": "1"}, - host: "", - path: "/a/product_name/1", - pathTemplate: `/{category:a|b/c}/{product}/{id:[0-9]+}`, - shouldMatch: true, - }, - { - title: "Path route with multiple patterns with pipe, match", - route: new(Route).Path("/{category:a|b/c}/{product}/{id:[0-9]+}"), - request: newRequest("GET", "http://localhost/b/c/product_name/1"), - vars: map[string]string{"category": "b/c", "product": "product_name", "id": "1"}, - host: "", - path: "/b/c/product_name/1", - pathTemplate: `/{category:a|b/c}/{product}/{id:[0-9]+}`, - shouldMatch: true, - }, - } - - for _, test := range tests { - t.Run(test.title, func(t *testing.T) { - testRoute(t, test) - testTemplate(t, test) - testUseEscapedRoute(t, test) - testRegexp(t, test) - }) - } -} - -func TestPathPrefix(t *testing.T) { - tests := []routeTest{ - { - title: "PathPrefix route, match", - route: new(Route).PathPrefix("/111"), - request: newRequest("GET", "http://localhost/111/222/333"), - vars: map[string]string{}, - host: "", - path: "/111", - shouldMatch: true, - }, - { - title: "PathPrefix route, match substring", - route: new(Route).PathPrefix("/1"), - request: newRequest("GET", "http://localhost/111/222/333"), - vars: map[string]string{}, - host: "", - path: "/1", - shouldMatch: true, - }, - { - title: "PathPrefix route, URL prefix in request does not match", - route: new(Route).PathPrefix("/111"), - request: newRequest("GET", "http://localhost/1/2/3"), - vars: map[string]string{}, - host: "", - path: "/111", - shouldMatch: false, - }, - { - title: "PathPrefix route with pattern, match", - route: new(Route).PathPrefix("/111/{v1:[0-9]{3}}"), - request: newRequest("GET", "http://localhost/111/222/333"), - vars: map[string]string{"v1": "222"}, - host: "", - path: "/111/222", - pathTemplate: `/111/{v1:[0-9]{3}}`, - shouldMatch: true, - }, - { - title: "PathPrefix route with pattern, URL prefix in request does not match", - route: new(Route).PathPrefix("/111/{v1:[0-9]{3}}"), - request: newRequest("GET", "http://localhost/111/aaa/333"), - vars: map[string]string{"v1": "222"}, - host: "", - path: "/111/222", - pathTemplate: `/111/{v1:[0-9]{3}}`, - shouldMatch: false, - }, - { - title: "PathPrefix route with multiple patterns, match", - route: new(Route).PathPrefix("/{v1:[0-9]{3}}/{v2:[0-9]{3}}"), - request: newRequest("GET", "http://localhost/111/222/333"), - vars: map[string]string{"v1": "111", "v2": "222"}, - host: "", - path: "/111/222", - pathTemplate: `/{v1:[0-9]{3}}/{v2:[0-9]{3}}`, - shouldMatch: true, - }, - { - title: "PathPrefix route with multiple patterns, URL prefix in request does not match", - route: new(Route).PathPrefix("/{v1:[0-9]{3}}/{v2:[0-9]{3}}"), - request: newRequest("GET", "http://localhost/111/aaa/333"), - vars: map[string]string{"v1": "111", "v2": "222"}, - host: "", - path: "/111/222", - pathTemplate: `/{v1:[0-9]{3}}/{v2:[0-9]{3}}`, - shouldMatch: false, - }, - } - - for _, test := range tests { - t.Run(test.title, func(t *testing.T) { - testRoute(t, test) - testTemplate(t, test) - testUseEscapedRoute(t, test) - }) - } -} - -func TestSchemeHostPath(t *testing.T) { - tests := []routeTest{ - { - title: "Host and Path route, match", - route: new(Route).Host("aaa.bbb.ccc").Path("/111/222/333"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{}, - scheme: "http", - host: "aaa.bbb.ccc", - path: "/111/222/333", - pathTemplate: `/111/222/333`, - hostTemplate: `aaa.bbb.ccc`, - shouldMatch: true, - }, - { - title: "Scheme, Host, and Path route, match", - route: new(Route).Schemes("https").Host("aaa.bbb.ccc").Path("/111/222/333"), - request: newRequest("GET", "https://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{}, - scheme: "https", - host: "aaa.bbb.ccc", - path: "/111/222/333", - pathTemplate: `/111/222/333`, - hostTemplate: `aaa.bbb.ccc`, - shouldMatch: true, - }, - { - title: "Host and Path route, wrong host in request URL", - route: new(Route).Host("aaa.bbb.ccc").Path("/111/222/333"), - request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), - vars: map[string]string{}, - scheme: "http", - host: "aaa.bbb.ccc", - path: "/111/222/333", - pathTemplate: `/111/222/333`, - hostTemplate: `aaa.bbb.ccc`, - shouldMatch: false, - }, - { - title: "Host and Path route with pattern, match", - route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc").Path("/111/{v2:[0-9]{3}}/333"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v1": "bbb", "v2": "222"}, - scheme: "http", - host: "aaa.bbb.ccc", - path: "/111/222/333", - pathTemplate: `/111/{v2:[0-9]{3}}/333`, - hostTemplate: `aaa.{v1:[a-z]{3}}.ccc`, - shouldMatch: true, - }, - { - title: "Scheme, Host, and Path route with host and path patterns, match", - route: new(Route).Schemes("ftp", "ssss").Host("aaa.{v1:[a-z]{3}}.ccc").Path("/111/{v2:[0-9]{3}}/333"), - request: newRequest("GET", "ssss://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v1": "bbb", "v2": "222"}, - scheme: "ftp", - host: "aaa.bbb.ccc", - path: "/111/222/333", - pathTemplate: `/111/{v2:[0-9]{3}}/333`, - hostTemplate: `aaa.{v1:[a-z]{3}}.ccc`, - shouldMatch: true, - }, - { - title: "Host and Path route with pattern, URL in request does not match", - route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc").Path("/111/{v2:[0-9]{3}}/333"), - request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), - vars: map[string]string{"v1": "bbb", "v2": "222"}, - scheme: "http", - host: "aaa.bbb.ccc", - path: "/111/222/333", - pathTemplate: `/111/{v2:[0-9]{3}}/333`, - hostTemplate: `aaa.{v1:[a-z]{3}}.ccc`, - shouldMatch: false, - }, - { - title: "Host and Path route with multiple patterns, match", - route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}").Path("/{v4:[0-9]{3}}/{v5:[0-9]{3}}/{v6:[0-9]{3}}"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc", "v4": "111", "v5": "222", "v6": "333"}, - scheme: "http", - host: "aaa.bbb.ccc", - path: "/111/222/333", - pathTemplate: `/{v4:[0-9]{3}}/{v5:[0-9]{3}}/{v6:[0-9]{3}}`, - hostTemplate: `{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}`, - shouldMatch: true, - }, - { - title: "Host and Path route with multiple patterns, URL in request does not match", - route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}").Path("/{v4:[0-9]{3}}/{v5:[0-9]{3}}/{v6:[0-9]{3}}"), - request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), - vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc", "v4": "111", "v5": "222", "v6": "333"}, - scheme: "http", - host: "aaa.bbb.ccc", - path: "/111/222/333", - pathTemplate: `/{v4:[0-9]{3}}/{v5:[0-9]{3}}/{v6:[0-9]{3}}`, - hostTemplate: `{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}`, - shouldMatch: false, - }, - } - - for _, test := range tests { - t.Run(test.title, func(t *testing.T) { - testRoute(t, test) - testTemplate(t, test) - testUseEscapedRoute(t, test) - }) - } -} - -func TestHeaders(t *testing.T) { - // newRequestHeaders creates a new request with a method, url, and headers - newRequestHeaders := func(method, url string, headers map[string]string) *http.Request { - req, err := http.NewRequest(method, url, nil) - if err != nil { - panic(err) - } - for k, v := range headers { - req.Header.Add(k, v) - } - return req - } - - tests := []routeTest{ - { - title: "Headers route, match", - route: new(Route).Headers("foo", "bar", "baz", "ding"), - request: newRequestHeaders("GET", "http://localhost", map[string]string{"foo": "bar", "baz": "ding"}), - vars: map[string]string{}, - host: "", - path: "", - shouldMatch: true, - }, - { - title: "Headers route, bad header values", - route: new(Route).Headers("foo", "bar", "baz", "ding"), - request: newRequestHeaders("GET", "http://localhost", map[string]string{"foo": "bar", "baz": "dong"}), - vars: map[string]string{}, - host: "", - path: "", - shouldMatch: false, - }, - { - title: "Headers route, regex header values to match", - route: new(Route).HeadersRegexp("foo", "ba[zr]"), - request: newRequestHeaders("GET", "http://localhost", map[string]string{"foo": "baw"}), - vars: map[string]string{}, - host: "", - path: "", - shouldMatch: false, - }, - { - title: "Headers route, regex header values to match", - route: new(Route).HeadersRegexp("foo", "ba[zr]"), - request: newRequestHeaders("GET", "http://localhost", map[string]string{"foo": "baz"}), - vars: map[string]string{}, - host: "", - path: "", - shouldMatch: true, - }, - } - - for _, test := range tests { - t.Run(test.title, func(t *testing.T) { - testRoute(t, test) - testTemplate(t, test) - }) - } -} - -func TestMethods(t *testing.T) { - tests := []routeTest{ - { - title: "Methods route, match GET", - route: new(Route).Methods("GET", "POST"), - request: newRequest("GET", "http://localhost"), - vars: map[string]string{}, - host: "", - path: "", - methods: []string{"GET", "POST"}, - shouldMatch: true, - }, - { - title: "Methods route, match POST", - route: new(Route).Methods("GET", "POST"), - request: newRequest("POST", "http://localhost"), - vars: map[string]string{}, - host: "", - path: "", - methods: []string{"GET", "POST"}, - shouldMatch: true, - }, - { - title: "Methods route, bad method", - route: new(Route).Methods("GET", "POST"), - request: newRequest("PUT", "http://localhost"), - vars: map[string]string{}, - host: "", - path: "", - methods: []string{"GET", "POST"}, - shouldMatch: false, - }, - { - title: "Route without methods", - route: new(Route), - request: newRequest("PUT", "http://localhost"), - vars: map[string]string{}, - host: "", - path: "", - methods: []string{}, - shouldMatch: true, - }, - } - - for _, test := range tests { - t.Run(test.title, func(t *testing.T) { - testRoute(t, test) - testTemplate(t, test) - testMethods(t, test) - }) - } -} - -func TestQueries(t *testing.T) { - tests := []routeTest{ - { - title: "Queries route, match", - route: new(Route).Queries("foo", "bar", "baz", "ding"), - request: newRequest("GET", "http://localhost?foo=bar&baz=ding"), - vars: map[string]string{}, - host: "", - path: "", - query: "foo=bar&baz=ding", - queriesTemplate: "foo=bar,baz=ding", - queriesRegexp: "^foo=bar$,^baz=ding$", - shouldMatch: true, - }, - { - title: "Queries route, match with a query string", - route: new(Route).Host("www.example.com").Path("/api").Queries("foo", "bar", "baz", "ding"), - request: newRequest("GET", "http://www.example.com/api?foo=bar&baz=ding"), - vars: map[string]string{}, - host: "", - path: "", - query: "foo=bar&baz=ding", - pathTemplate: `/api`, - hostTemplate: `www.example.com`, - queriesTemplate: "foo=bar,baz=ding", - queriesRegexp: "^foo=bar$,^baz=ding$", - shouldMatch: true, - }, - { - title: "Queries route, match with a query string out of order", - route: new(Route).Host("www.example.com").Path("/api").Queries("foo", "bar", "baz", "ding"), - request: newRequest("GET", "http://www.example.com/api?baz=ding&foo=bar"), - vars: map[string]string{}, - host: "", - path: "", - query: "foo=bar&baz=ding", - pathTemplate: `/api`, - hostTemplate: `www.example.com`, - queriesTemplate: "foo=bar,baz=ding", - queriesRegexp: "^foo=bar$,^baz=ding$", - shouldMatch: true, - }, - { - title: "Queries route, bad query", - route: new(Route).Queries("foo", "bar", "baz", "ding"), - request: newRequest("GET", "http://localhost?foo=bar&baz=dong"), - vars: map[string]string{}, - host: "", - path: "", - queriesTemplate: "foo=bar,baz=ding", - queriesRegexp: "^foo=bar$,^baz=ding$", - shouldMatch: false, - }, - { - title: "Queries route with pattern, match", - route: new(Route).Queries("foo", "{v1}"), - request: newRequest("GET", "http://localhost?foo=bar"), - vars: map[string]string{"v1": "bar"}, - host: "", - path: "", - query: "foo=bar", - queriesTemplate: "foo={v1}", - queriesRegexp: "^foo=(?P.*)$", - shouldMatch: true, - }, - { - title: "Queries route with multiple patterns, match", - route: new(Route).Queries("foo", "{v1}", "baz", "{v2}"), - request: newRequest("GET", "http://localhost?foo=bar&baz=ding"), - vars: map[string]string{"v1": "bar", "v2": "ding"}, - host: "", - path: "", - query: "foo=bar&baz=ding", - queriesTemplate: "foo={v1},baz={v2}", - queriesRegexp: "^foo=(?P.*)$,^baz=(?P.*)$", - shouldMatch: true, - }, - { - title: "Queries route with regexp pattern, match", - route: new(Route).Queries("foo", "{v1:[0-9]+}"), - request: newRequest("GET", "http://localhost?foo=10"), - vars: map[string]string{"v1": "10"}, - host: "", - path: "", - query: "foo=10", - queriesTemplate: "foo={v1:[0-9]+}", - queriesRegexp: "^foo=(?P[0-9]+)$", - shouldMatch: true, - }, - { - title: "Queries route with regexp pattern, regexp does not match", - route: new(Route).Queries("foo", "{v1:[0-9]+}"), - request: newRequest("GET", "http://localhost?foo=a"), - vars: map[string]string{}, - host: "", - path: "", - queriesTemplate: "foo={v1:[0-9]+}", - queriesRegexp: "^foo=(?P[0-9]+)$", - shouldMatch: false, - }, - { - title: "Queries route with regexp pattern with quantifier, match", - route: new(Route).Queries("foo", "{v1:[0-9]{1}}"), - request: newRequest("GET", "http://localhost?foo=1"), - vars: map[string]string{"v1": "1"}, - host: "", - path: "", - query: "foo=1", - queriesTemplate: "foo={v1:[0-9]{1}}", - queriesRegexp: "^foo=(?P[0-9]{1})$", - shouldMatch: true, - }, - { - title: "Queries route with regexp pattern with quantifier, additional variable in query string, match", - route: new(Route).Queries("foo", "{v1:[0-9]{1}}"), - request: newRequest("GET", "http://localhost?bar=2&foo=1"), - vars: map[string]string{"v1": "1"}, - host: "", - path: "", - query: "foo=1", - queriesTemplate: "foo={v1:[0-9]{1}}", - queriesRegexp: "^foo=(?P[0-9]{1})$", - shouldMatch: true, - }, - { - title: "Queries route with regexp pattern with quantifier, regexp does not match", - route: new(Route).Queries("foo", "{v1:[0-9]{1}}"), - request: newRequest("GET", "http://localhost?foo=12"), - vars: map[string]string{}, - host: "", - path: "", - queriesTemplate: "foo={v1:[0-9]{1}}", - queriesRegexp: "^foo=(?P[0-9]{1})$", - shouldMatch: false, - }, - { - title: "Queries route with regexp pattern with quantifier, additional capturing group", - route: new(Route).Queries("foo", "{v1:[0-9]{1}(?:a|b)}"), - request: newRequest("GET", "http://localhost?foo=1a"), - vars: map[string]string{"v1": "1a"}, - host: "", - path: "", - query: "foo=1a", - queriesTemplate: "foo={v1:[0-9]{1}(?:a|b)}", - queriesRegexp: "^foo=(?P[0-9]{1}(?:a|b))$", - shouldMatch: true, - }, - { - title: "Queries route with regexp pattern with quantifier, additional variable in query string, regexp does not match", - route: new(Route).Queries("foo", "{v1:[0-9]{1}}"), - request: newRequest("GET", "http://localhost?foo=12"), - vars: map[string]string{}, - host: "", - path: "", - queriesTemplate: "foo={v1:[0-9]{1}}", - queriesRegexp: "^foo=(?P[0-9]{1})$", - shouldMatch: false, - }, - { - title: "Queries route with hyphenated name, match", - route: new(Route).Queries("foo", "{v-1}"), - request: newRequest("GET", "http://localhost?foo=bar"), - vars: map[string]string{"v-1": "bar"}, - host: "", - path: "", - query: "foo=bar", - queriesTemplate: "foo={v-1}", - queriesRegexp: "^foo=(?P.*)$", - shouldMatch: true, - }, - { - title: "Queries route with multiple hyphenated names, match", - route: new(Route).Queries("foo", "{v-1}", "baz", "{v-2}"), - request: newRequest("GET", "http://localhost?foo=bar&baz=ding"), - vars: map[string]string{"v-1": "bar", "v-2": "ding"}, - host: "", - path: "", - query: "foo=bar&baz=ding", - queriesTemplate: "foo={v-1},baz={v-2}", - queriesRegexp: "^foo=(?P.*)$,^baz=(?P.*)$", - shouldMatch: true, - }, - { - title: "Queries route with hyphenate name and pattern, match", - route: new(Route).Queries("foo", "{v-1:[0-9]+}"), - request: newRequest("GET", "http://localhost?foo=10"), - vars: map[string]string{"v-1": "10"}, - host: "", - path: "", - query: "foo=10", - queriesTemplate: "foo={v-1:[0-9]+}", - queriesRegexp: "^foo=(?P[0-9]+)$", - shouldMatch: true, - }, - { - title: "Queries route with hyphenated name and pattern with quantifier, additional capturing group", - route: new(Route).Queries("foo", "{v-1:[0-9]{1}(?:a|b)}"), - request: newRequest("GET", "http://localhost?foo=1a"), - vars: map[string]string{"v-1": "1a"}, - host: "", - path: "", - query: "foo=1a", - queriesTemplate: "foo={v-1:[0-9]{1}(?:a|b)}", - queriesRegexp: "^foo=(?P[0-9]{1}(?:a|b))$", - shouldMatch: true, - }, - { - title: "Queries route with empty value, should match", - route: new(Route).Queries("foo", ""), - request: newRequest("GET", "http://localhost?foo=bar"), - vars: map[string]string{}, - host: "", - path: "", - query: "foo=", - queriesTemplate: "foo=", - queriesRegexp: "^foo=.*$", - shouldMatch: true, - }, - { - title: "Queries route with empty value and no parameter in request, should not match", - route: new(Route).Queries("foo", ""), - request: newRequest("GET", "http://localhost"), - vars: map[string]string{}, - host: "", - path: "", - queriesTemplate: "foo=", - queriesRegexp: "^foo=.*$", - shouldMatch: false, - }, - { - title: "Queries route with empty value and empty parameter in request, should match", - route: new(Route).Queries("foo", ""), - request: newRequest("GET", "http://localhost?foo="), - vars: map[string]string{}, - host: "", - path: "", - query: "foo=", - queriesTemplate: "foo=", - queriesRegexp: "^foo=.*$", - shouldMatch: true, - }, - { - title: "Queries route with overlapping value, should not match", - route: new(Route).Queries("foo", "bar"), - request: newRequest("GET", "http://localhost?foo=barfoo"), - vars: map[string]string{}, - host: "", - path: "", - queriesTemplate: "foo=bar", - queriesRegexp: "^foo=bar$", - shouldMatch: false, - }, - { - title: "Queries route with no parameter in request, should not match", - route: new(Route).Queries("foo", "{bar}"), - request: newRequest("GET", "http://localhost"), - vars: map[string]string{}, - host: "", - path: "", - queriesTemplate: "foo={bar}", - queriesRegexp: "^foo=(?P.*)$", - shouldMatch: false, - }, - { - title: "Queries route with empty parameter in request, should match", - route: new(Route).Queries("foo", "{bar}"), - request: newRequest("GET", "http://localhost?foo="), - vars: map[string]string{"foo": ""}, - host: "", - path: "", - query: "foo=", - queriesTemplate: "foo={bar}", - queriesRegexp: "^foo=(?P.*)$", - shouldMatch: true, - }, - { - title: "Queries route, bad submatch", - route: new(Route).Queries("foo", "bar", "baz", "ding"), - request: newRequest("GET", "http://localhost?fffoo=bar&baz=dingggg"), - vars: map[string]string{}, - host: "", - path: "", - queriesTemplate: "foo=bar,baz=ding", - queriesRegexp: "^foo=bar$,^baz=ding$", - shouldMatch: false, - }, - { - title: "Queries route with pattern, match, escaped value", - route: new(Route).Queries("foo", "{v1}"), - request: newRequest("GET", "http://localhost?foo=%25bar%26%20%2F%3D%3F"), - vars: map[string]string{"v1": "%bar& /=?"}, - host: "", - path: "", - query: "foo=%25bar%26+%2F%3D%3F", - queriesTemplate: "foo={v1}", - queriesRegexp: "^foo=(?P.*)$", - shouldMatch: true, - }, - } - - for _, test := range tests { - t.Run(test.title, func(t *testing.T) { - testTemplate(t, test) - testQueriesTemplates(t, test) - testUseEscapedRoute(t, test) - testQueriesRegexp(t, test) - }) - } -} - -func TestSchemes(t *testing.T) { - tests := []routeTest{ - // Schemes - { - title: "Schemes route, default scheme, match http, build http", - route: new(Route).Host("localhost"), - request: newRequest("GET", "http://localhost"), - scheme: "http", - host: "localhost", - shouldMatch: true, - }, - { - title: "Schemes route, match https, build https", - route: new(Route).Schemes("https", "ftp").Host("localhost"), - request: newRequest("GET", "https://localhost"), - scheme: "https", - host: "localhost", - shouldMatch: true, - }, - { - title: "Schemes route, match ftp, build https", - route: new(Route).Schemes("https", "ftp").Host("localhost"), - request: newRequest("GET", "ftp://localhost"), - scheme: "https", - host: "localhost", - shouldMatch: true, - }, - { - title: "Schemes route, match ftp, build ftp", - route: new(Route).Schemes("ftp", "https").Host("localhost"), - request: newRequest("GET", "ftp://localhost"), - scheme: "ftp", - host: "localhost", - shouldMatch: true, - }, - { - title: "Schemes route, bad scheme", - route: new(Route).Schemes("https", "ftp").Host("localhost"), - request: newRequest("GET", "http://localhost"), - scheme: "https", - host: "localhost", - shouldMatch: false, - }, - } - for _, test := range tests { - t.Run(test.title, func(t *testing.T) { - testRoute(t, test) - testTemplate(t, test) - }) - } -} - -func TestMatcherFunc(t *testing.T) { - m := func(r *http.Request, m *RouteMatch) bool { - return r.URL.Host == "aaa.bbb.ccc" - } - - tests := []routeTest{ - { - title: "MatchFunc route, match", - route: new(Route).MatcherFunc(m), - request: newRequest("GET", "http://aaa.bbb.ccc"), - vars: map[string]string{}, - host: "", - path: "", - shouldMatch: true, - }, - { - title: "MatchFunc route, non-match", - route: new(Route).MatcherFunc(m), - request: newRequest("GET", "http://aaa.222.ccc"), - vars: map[string]string{}, - host: "", - path: "", - shouldMatch: false, - }, - } - - for _, test := range tests { - t.Run(test.title, func(t *testing.T) { - testRoute(t, test) - testTemplate(t, test) - }) - } -} - -func TestBuildVarsFunc(t *testing.T) { - tests := []routeTest{ - { - title: "BuildVarsFunc set on route", - route: new(Route).Path(`/111/{v1:\d}{v2:.*}`).BuildVarsFunc(func(vars map[string]string) map[string]string { - vars["v1"] = "3" - vars["v2"] = "a" - return vars - }), - request: newRequest("GET", "http://localhost/111/2"), - path: "/111/3a", - pathTemplate: `/111/{v1:\d}{v2:.*}`, - shouldMatch: true, - }, - } - - for _, test := range tests { - t.Run(test.title, func(t *testing.T) { - testRoute(t, test) - testTemplate(t, test) - }) - } -} - -// ---------------------------------------------------------------------------- -// Helpers -// ---------------------------------------------------------------------------- - -func getRouteTemplate(route *Route) string { - host, err := route.GetHostTemplate() - if err != nil { - host = "none" - } - path, err := route.GetPathTemplate() - if err != nil { - path = "none" - } - return fmt.Sprintf("Host: %v, Path: %v", host, path) -} - -func testRoute(t *testing.T, test routeTest) { - request := test.request - route := test.route - vars := test.vars - shouldMatch := test.shouldMatch - query := test.query - shouldRedirect := test.shouldRedirect - uri := url.URL{ - Scheme: test.scheme, - Host: test.host, - Path: test.path, - } - if uri.Scheme == "" { - uri.Scheme = "http" - } - - var match RouteMatch - ok := route.Match(request, &match) - if ok != shouldMatch { - msg := "Should match" - if !shouldMatch { - msg = "Should not match" - } - t.Errorf("(%v) %v:\nRoute: %#v\nRequest: %#v\nVars: %v\n", test.title, msg, route, request, vars) - return - } - if shouldMatch { - if vars != nil && !stringMapEqual(vars, match.Vars) { - t.Errorf("(%v) Vars not equal: expected %v, got %v", test.title, vars, match.Vars) - return - } - if test.scheme != "" { - u, err := route.URL(mapToPairs(match.Vars)...) - if err != nil { - t.Fatalf("(%v) URL error: %v -- %v", test.title, err, getRouteTemplate(route)) - } - if uri.Scheme != u.Scheme { - t.Errorf("(%v) URLScheme not equal: expected %v, got %v", test.title, uri.Scheme, u.Scheme) - return - } - } - if test.host != "" { - u, err := test.route.URLHost(mapToPairs(match.Vars)...) - if err != nil { - t.Fatalf("(%v) URLHost error: %v -- %v", test.title, err, getRouteTemplate(route)) - } - if uri.Scheme != u.Scheme { - t.Errorf("(%v) URLHost scheme not equal: expected %v, got %v -- %v", test.title, uri.Scheme, u.Scheme, getRouteTemplate(route)) - return - } - if uri.Host != u.Host { - t.Errorf("(%v) URLHost host not equal: expected %v, got %v -- %v", test.title, uri.Host, u.Host, getRouteTemplate(route)) - return - } - } - if test.path != "" { - u, err := route.URLPath(mapToPairs(match.Vars)...) - if err != nil { - t.Fatalf("(%v) URLPath error: %v -- %v", test.title, err, getRouteTemplate(route)) - } - if uri.Path != u.Path { - t.Errorf("(%v) URLPath not equal: expected %v, got %v -- %v", test.title, uri.Path, u.Path, getRouteTemplate(route)) - return - } - } - if test.host != "" && test.path != "" { - u, err := route.URL(mapToPairs(match.Vars)...) - if err != nil { - t.Fatalf("(%v) URL error: %v -- %v", test.title, err, getRouteTemplate(route)) - } - if expected, got := uri.String(), u.String(); expected != got { - t.Errorf("(%v) URL not equal: expected %v, got %v -- %v", test.title, expected, got, getRouteTemplate(route)) - return - } - } - if query != "" { - u, err := route.URL(mapToPairs(match.Vars)...) - if err != nil { - t.Errorf("(%v) erred while creating url: %v", test.title, err) - return - } - if query != u.RawQuery { - t.Errorf("(%v) URL query not equal: expected %v, got %v", test.title, query, u.RawQuery) - return - } - } - if shouldRedirect && match.Handler == nil { - t.Errorf("(%v) Did not redirect", test.title) - return - } - if !shouldRedirect && match.Handler != nil { - t.Errorf("(%v) Unexpected redirect", test.title) - return - } - } -} - -func testUseEscapedRoute(t *testing.T, test routeTest) { - test.route.useEncodedPath = true - testRoute(t, test) -} - -func testTemplate(t *testing.T, test routeTest) { - route := test.route - pathTemplate := test.pathTemplate - if len(pathTemplate) == 0 { - pathTemplate = test.path - } - hostTemplate := test.hostTemplate - if len(hostTemplate) == 0 { - hostTemplate = test.host - } - - routePathTemplate, pathErr := route.GetPathTemplate() - if pathErr == nil && routePathTemplate != pathTemplate { - t.Errorf("(%v) GetPathTemplate not equal: expected %v, got %v", test.title, pathTemplate, routePathTemplate) - } - - routeHostTemplate, hostErr := route.GetHostTemplate() - if hostErr == nil && routeHostTemplate != hostTemplate { - t.Errorf("(%v) GetHostTemplate not equal: expected %v, got %v", test.title, hostTemplate, routeHostTemplate) - } -} - -func testMethods(t *testing.T, test routeTest) { - route := test.route - methods, _ := route.GetMethods() - if strings.Join(methods, ",") != strings.Join(test.methods, ",") { - t.Errorf("(%v) GetMethods not equal: expected %v, got %v", test.title, test.methods, methods) - } -} - -func testRegexp(t *testing.T, test routeTest) { - route := test.route - routePathRegexp, regexpErr := route.GetPathRegexp() - if test.pathRegexp != "" && regexpErr == nil && routePathRegexp != test.pathRegexp { - t.Errorf("(%v) GetPathRegexp not equal: expected %v, got %v", test.title, test.pathRegexp, routePathRegexp) - } -} - -func testQueriesRegexp(t *testing.T, test routeTest) { - route := test.route - queries, queriesErr := route.GetQueriesRegexp() - gotQueries := strings.Join(queries, ",") - if test.queriesRegexp != "" && queriesErr == nil && gotQueries != test.queriesRegexp { - t.Errorf("(%v) GetQueriesRegexp not equal: expected %v, got %v", test.title, test.queriesRegexp, gotQueries) - } -} - -func testQueriesTemplates(t *testing.T, test routeTest) { - route := test.route - queries, queriesErr := route.GetQueriesTemplates() - gotQueries := strings.Join(queries, ",") - if test.queriesTemplate != "" && queriesErr == nil && gotQueries != test.queriesTemplate { - t.Errorf("(%v) GetQueriesTemplates not equal: expected %v, got %v", test.title, test.queriesTemplate, gotQueries) - } -} - -type TestA301ResponseWriter struct { - hh http.Header - status int -} - -func (ho *TestA301ResponseWriter) Header() http.Header { - return ho.hh -} - -func (ho *TestA301ResponseWriter) Write(b []byte) (int, error) { - return 0, nil -} - -func (ho *TestA301ResponseWriter) WriteHeader(code int) { - ho.status = code -} - -func Test301Redirect(t *testing.T) { - m := make(http.Header) - - func1 := func(w http.ResponseWriter, r *http.Request) {} - func2 := func(w http.ResponseWriter, r *http.Request) {} - - r := NewRouter() - r.HandleFunc("/api/", func2).Name("func2") - r.HandleFunc("/", func1).Name("func1") - - req, _ := http.NewRequest("GET", "http://localhost//api/?abc=def", nil) - - res := TestA301ResponseWriter{ - hh: m, - status: 0, - } - r.ServeHTTP(&res, req) - - if "http://localhost/api/?abc=def" != res.hh["Location"][0] { - t.Errorf("Should have complete URL with query string") - } -} - -// methodsSubrouterTest models the data necessary for testing handler -// matching for subrouters created after HTTP methods matcher registration. -type methodsSubrouterTest struct { - title string - wantCode int - router *Router - // method is the input into the request and expected response - method string - // input request path - path string - // redirectTo is the expected location path for strict-slash matches - redirectTo string -} - -// methodHandler writes the method string in response. -func methodHandler(method string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(method)) - } -} - -// verify that copyRouteConf copies fields as expected. -func Test_copyRouteConf(t *testing.T) { - var ( - m MatcherFunc = func(*http.Request, *RouteMatch) bool { - return true - } - b BuildVarsFunc = func(i map[string]string) map[string]string { - return i - } - r, _ = newRouteRegexp("hi", regexpTypeHost, routeRegexpOptions{}) - ) - - tests := []struct { - name string - args routeConf - want routeConf - }{ - { - "empty", - routeConf{}, - routeConf{}, - }, - { - "full", - routeConf{ - useEncodedPath: true, - strictSlash: true, - skipClean: true, - regexp: routeRegexpGroup{host: r, path: r, queries: []*routeRegexp{r}}, - matchers: []matcher{m}, - buildScheme: "https", - buildVarsFunc: b, - }, - routeConf{ - useEncodedPath: true, - strictSlash: true, - skipClean: true, - regexp: routeRegexpGroup{host: r, path: r, queries: []*routeRegexp{r}}, - matchers: []matcher{m}, - buildScheme: "https", - buildVarsFunc: b, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // special case some incomparable fields of routeConf before delegating to reflect.DeepEqual - got := copyRouteConf(tt.args) - - // funcs not comparable, just compare length of slices - if len(got.matchers) != len(tt.want.matchers) { - t.Errorf("matchers different lengths: %v %v", len(got.matchers), len(tt.want.matchers)) - } - got.matchers, tt.want.matchers = nil, nil - - // deep equal treats nil slice differently to empty slice so check for zero len first - { - bothZero := len(got.regexp.queries) == 0 && len(tt.want.regexp.queries) == 0 - if !bothZero && !reflect.DeepEqual(got.regexp.queries, tt.want.regexp.queries) { - t.Errorf("queries unequal: %v %v", got.regexp.queries, tt.want.regexp.queries) - } - got.regexp.queries, tt.want.regexp.queries = nil, nil - } - - // funcs not comparable, just compare nullity - if (got.buildVarsFunc == nil) != (tt.want.buildVarsFunc == nil) { - t.Errorf("build vars funcs unequal: %v %v", got.buildVarsFunc == nil, tt.want.buildVarsFunc == nil) - } - got.buildVarsFunc, tt.want.buildVarsFunc = nil, nil - - // finish the deal - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("route confs unequal: %v %v", got, tt.want) - } - }) - } -} - -// mapToPairs converts a string map to a slice of string pairs -func mapToPairs(m map[string]string) []string { - var i int - p := make([]string, len(m)*2) - for k, v := range m { - p[i] = k - p[i+1] = v - i += 2 - } - return p -} - -// stringMapEqual checks the equality of two string maps -func stringMapEqual(m1, m2 map[string]string) bool { - nil1 := m1 == nil - nil2 := m2 == nil - if nil1 != nil2 || len(m1) != len(m2) { - return false - } - for k, v := range m1 { - if v != m2[k] { - return false - } - } - return true -} - -// stringHandler returns a handler func that writes a message 's' to the -// http.ResponseWriter. -func stringHandler(s string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(s)) - } -} - -// newRequest is a helper function to create a new request with a method and url. -// The request returned is a 'server' request as opposed to a 'client' one through -// simulated write onto the wire and read off of the wire. -// The differences between requests are detailed in the net/http package. -func newRequest(method, url string) *http.Request { - req, err := http.NewRequest(method, url, nil) - if err != nil { - panic(err) - } - // extract the escaped original host+path from url - // http://localhost/path/here?v=1#frag -> //localhost/path/here - opaque := "" - if i := len(req.URL.Scheme); i > 0 { - opaque = url[i+1:] - } - - if i := strings.LastIndex(opaque, "?"); i > -1 { - opaque = opaque[:i] - } - if i := strings.LastIndex(opaque, "#"); i > -1 { - opaque = opaque[:i] - } - - // Escaped host+path workaround as detailed in https://golang.org/pkg/net/url/#URL - // for < 1.5 client side workaround - req.URL.Opaque = opaque - - // Simulate writing to wire - var buff bytes.Buffer - req.Write(&buff) - ioreader := bufio.NewReader(&buff) - - // Parse request off of 'wire' - req, err = http.ReadRequest(ioreader) - if err != nil { - panic(err) - } - return req -} - -// create a new request with the provided headers -func newRequestWithHeaders(method, url string, headers ...string) *http.Request { - req := newRequest(method, url) - - if len(headers)%2 != 0 { - panic(fmt.Sprintf("Expected headers length divisible by 2 but got %v", len(headers))) - } - - for i := 0; i < len(headers); i += 2 { - req.Header.Set(headers[i], headers[i+1]) - } - - return req -} - -// newRequestHost a new request with a method, url, and host header -func newRequestHost(method, url, host string) *http.Request { - req := httptest.NewRequest(method, url, nil) - req.Host = host - return req -} diff --git a/pkg/routing/routing.go b/pkg/routing/routing.go index 48766aa8f..038487ca3 100644 --- a/pkg/routing/routing.go +++ b/pkg/routing/routing.go @@ -21,7 +21,6 @@ import ( "fmt" "net/http" "net/http/pprof" - "sort" "strings" "github.com/trickstercache/trickster/v2/pkg/backends" @@ -43,25 +42,32 @@ import ( po "github.com/trickstercache/trickster/v2/pkg/proxy/paths/options" "github.com/trickstercache/trickster/v2/pkg/proxy/request/rewriter" "github.com/trickstercache/trickster/v2/pkg/router" + "github.com/trickstercache/trickster/v2/pkg/router/lm" "github.com/trickstercache/trickster/v2/pkg/util/middleware" ) // RegisterPprofRoutes will register the Pprof Debugging endpoints to the provided router -func RegisterPprofRoutes(routerName string, h *http.ServeMux, logger interface{}) { +func RegisterPprofRoutes(routerName string, r router.Router, logger interface{}) { tl.Info(logger, "registering pprof /debug routes", tl.Pairs{"routerName": routerName}) - h.HandleFunc("/debug/pprof/", pprof.Index) - h.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - h.HandleFunc("/debug/pprof/profile", pprof.Profile) - h.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - h.HandleFunc("/debug/pprof/trace", pprof.Trace) + r.RegisterRoute("/debug/pprof/", nil, nil, + false, http.HandlerFunc(pprof.Index)) + r.RegisterRoute("/debug/pprof/cmdline", nil, nil, + false, http.HandlerFunc(pprof.Cmdline)) + r.RegisterRoute("/debug/pprof/profile", nil, nil, + false, http.HandlerFunc(pprof.Profile)) + r.RegisterRoute("/debug/pprof/symbol", nil, nil, + false, http.HandlerFunc(pprof.Symbol)) + r.RegisterRoute("/debug/pprof/trace", nil, nil, + false, http.HandlerFunc(pprof.Trace)) } // RegisterProxyRoutes iterates the Trickster Configuration and // registers the routes for the configured backends -func RegisterProxyRoutes(conf *config.Config, r router.Router, metricsRouter *http.ServeMux, - caches map[string]cache.Cache, tracers tracing.Tracers, - logger interface{}, dryRun bool) (backends.Backends, error) { +func RegisterProxyRoutes(conf *config.Config, r router.Router, + metricsRouter router.Router, caches map[string]cache.Cache, + tracers tracing.Tracers, logger interface{}, + dryRun bool) (backends.Backends, error) { // a fake "top-level" backend representing the main frontend, so rules can route // to it via the clients map @@ -146,13 +152,15 @@ var noCacheBackends = map[string]interface{}{ } // RegisterHealthHandler registers the main health handler -func RegisterHealthHandler(router *http.ServeMux, path string, hc healthcheck.HealthChecker) { - router.Handle(path, health.StatusHandler(hc)) +func RegisterHealthHandler(router router.Router, path string, + hc healthcheck.HealthChecker) { + router.RegisterRoute(path, nil, nil, false, health.StatusHandler(hc)) } -func registerBackendRoutes(r router.Router, metricsRouter *http.ServeMux, conf *config.Config, k string, - o *bo.Options, clients backends.Backends, caches map[string]cache.Cache, - tracers tracing.Tracers, logger interface{}, dryRun bool) error { +func registerBackendRoutes(r router.Router, metricsRouter router.Router, + conf *config.Config, k string, o *bo.Options, clients backends.Backends, + caches map[string]cache.Cache, tracers tracing.Tracers, logger interface{}, + dryRun bool) error { var client backends.Backend var c cache.Cache @@ -172,7 +180,7 @@ func registerBackendRoutes(r router.Router, metricsRouter *http.ServeMux, conf * cf := registration.SupportedProviders() if f, ok := cf[strings.ToLower(o.Provider)]; ok && f != nil { - client, err = f(k, o, router.NewRouter(), c, clients, cf) + client, err = f(k, o, lm.NewRouter(), c, clients, cf) } if err != nil { return err @@ -196,7 +204,9 @@ func registerBackendRoutes(r router.Router, metricsRouter *http.ServeMux, conf * tl.Pairs{"path": hp, "backendName": o.Name, "upstreamPath": o.HealthCheck.Path, "upstreamVerb": o.HealthCheck.Verb}) - metricsRouter.Handle(hp, http.Handler(middleware.WithResourcesContext(client, o, nil, nil, nil, logger, h))) + metricsRouter.RegisterRoute(hp, nil, nil, false, + http.Handler(middleware.WithResourcesContext(client, o, nil, + nil, nil, logger, h))) } } return nil @@ -207,7 +217,7 @@ func registerBackendRoutes(r router.Router, metricsRouter *http.ServeMux, conf * // the path routes to the appropriate handler from the provided handlers map func RegisterPathRoutes(r router.Router, handlers map[string]http.Handler, client backends.Backend, o *bo.Options, c cache.Cache, - defaultPaths map[string]*po.Options, tracers tracing.Tracers, + paths map[string]*po.Options, tracers tracing.Tracers, healthHandlerPath string, logger interface{}) { if o == nil { @@ -247,35 +257,24 @@ func RegisterPathRoutes(r router.Router, handlers map[string]http.Handler, return h } - // This takes the default paths, named like '/api/v1/query' and morphs the name - // into what the router wants, with methods like '/api/v1/query-0000011001', to help - // route sorting. the bitmap provides unique names multiple path entries of the same - // path but with different methods, without impacting true path sorting - pathsWithVerbs := make(map[string]*po.Options) - for _, p := range defaultPaths { - if len(p.Methods) == 0 { - p.Methods = methods.CacheableHTTPMethods() - } - pathsWithVerbs[p.Path+"-"+fmt.Sprintf("%010b", methods.MethodMask(p.Methods...))] = p - } - - // now we will iterate through the configured paths, and overlay them on those default paths. - // for rule & alb backend providers, only the default paths are used with no overlay or importable config + // now we will iterate through the configured paths, and overlay them on + // those default paths. for rule & alb backend providers, only the default + // paths are used with no overlay or importable config if !backends.IsVirtual(o.Provider) { for k, p := range o.Paths { - if p2, ok := pathsWithVerbs[k]; ok { + if p2, ok := paths[k]; ok { p2.Merge(p) continue } p3 := po.New() p3.Merge(p) - pathsWithVerbs[k] = p3 + paths[k] = p3 } } - plist := make([]string, 0, len(pathsWithVerbs)) - deletes := make([]string, 0, len(pathsWithVerbs)) - for k, p := range pathsWithVerbs { + plist := make([]string, 0, len(paths)) + deletes := make([]string, 0, len(paths)) + for k, p := range paths { if h, ok := handlers[p.HandlerName]; ok && h != nil { p.Handler = h plist = append(plist, k) @@ -286,19 +285,12 @@ func RegisterPathRoutes(r router.Router, handlers map[string]http.Handler, } } for _, p := range deletes { - delete(pathsWithVerbs, p) + delete(paths, p) } - - sort.Sort(ByLen(plist)) - for i := len(plist)/2 - 1; i >= 0; i-- { - opp := len(plist) - 1 - i - plist[i], plist[opp] = plist[opp], plist[i] - } - or := client.Router().(router.Router) for _, v := range plist { - p := pathsWithVerbs[v] + p := paths[v] pathPrefix := "/" + o.Name handledPath := pathPrefix + p.Path @@ -312,42 +304,27 @@ func RegisterPathRoutes(r router.Router, handlers map[string]http.Handler, if p.Methods[0] == "*" { p.Methods = methods.AllHTTPMethods() } - - switch p.MatchType { - case matching.PathMatchTypePrefix: - // Case where we path match by prefix - // Host Header Routing - for _, h := range o.Hosts { - r.PathPrefix(p.Path).Handler(decorate(p)).Methods(p.Methods...).Host(h) - } - if !o.PathRoutingDisabled { - // Path Routing - r.PathPrefix(handledPath).Handler(middleware.StripPathPrefix(pathPrefix, - decorate(p))).Methods(p.Methods...) - } - or.PathPrefix(p.Path).Handler(decorate(p)).Methods(p.Methods...) - default: - // default to exact match - // Host Header Routing - for _, h := range o.Hosts { - r.Handle(p.Path, decorate(p)).Methods(p.Methods...).Host(h) - } - if !o.PathRoutingDisabled { - // Path Routing - r.Handle(handledPath, middleware.StripPathPrefix(pathPrefix, - decorate(p))).Methods(p.Methods...) - } - or.Handle(p.Path, decorate(p)).Methods(p.Methods...) + if len(o.Hosts) > 0 { + r.RegisterRoute(p.Path, o.Hosts, p.Methods, + p.MatchType == matching.PathMatchTypePrefix, decorate(p)) } + if !o.PathRoutingDisabled { + r.RegisterRoute(handledPath, nil, p.Methods, + p.MatchType == matching.PathMatchTypePrefix, + middleware.StripPathPrefix(pathPrefix, decorate(p))) + } + or.RegisterRoute(p.Path, nil, p.Methods, + p.MatchType == matching.PathMatchTypePrefix, + decorate(p)) } } o.Router = or - o.Paths = pathsWithVerbs + o.Paths = paths } // RegisterDefaultBackendRoutes will iterate the Backends and register the default routes -func RegisterDefaultBackendRoutes(router router.Router, bknds backends.Backends, +func RegisterDefaultBackendRoutes(r router.Router, bknds backends.Backends, logger interface{}, tracers tracing.Tracers) { decorate := func(o *bo.Options, po *po.Options, tr *tracing.Tracer, @@ -384,50 +361,23 @@ func RegisterDefaultBackendRoutes(router router.Router, bknds backends.Backends, tl.Info(logger, "registering default backend handler paths", tl.Pairs{"backendName": o.Name}) - // Sort by key length(Path length) to ensure /api/v1/query_range appear before /api/v1 or / path in regex path matching - keylist := make([]string, 0, len(o.Paths)) - for key := range o.Paths { - keylist = append(keylist, key) - } - sort.Sort(ByLen(keylist)) - for i := len(keylist)/2 - 1; i >= 0; i-- { - opp := len(keylist) - 1 - i - keylist[i], keylist[opp] = keylist[opp], keylist[i] - } - - for _, k := range keylist { - var p = o.Paths[k] + for _, p := range o.Paths { if p.Handler != nil && len(p.Methods) > 0 { - tl.Debug(logger, "registering default backend handler paths", - tl.Pairs{"backendName": o.Name, "path": p.Path, "handlerName": p.HandlerName, - "matchType": p.MatchType}) - switch p.MatchType { - case matching.PathMatchTypePrefix: - // Case where we path match by prefix - router.PathPrefix(p.Path).Handler(decorate(o, p, tr, b.Cache(), b)).Methods(p.Methods...) - default: - // default to exact match - router.Handle(p.Path, decorate(o, p, tr, b.Cache(), b)).Methods(p.Methods...) + tl.Debug(logger, + "registering default backend handler paths", + tl.Pairs{"backendName": o.Name, "path": p.Path, + "handlerName": p.HandlerName, + "matchType": p.MatchType}) + + if p.MatchType == matching.PathMatchTypePrefix { + r.RegisterRoute(p.Path, nil, p.Methods, + true, decorate(o, p, tr, b.Cache(), b)) } - router.Handle(p.Path, decorate(o, p, tr, b.Cache(), b)).Methods(p.Methods...) + r.RegisterRoute(p.Path, nil, p.Methods, + false, decorate(o, p, tr, b.Cache(), b)) } } } } } - -// ByLen allows sorting of a string slice by string length -type ByLen []string - -func (a ByLen) Len() int { - return len(a) -} - -func (a ByLen) Less(i, j int) bool { - return len(a[i]) < len(a[j]) -} - -func (a ByLen) Swap(i, j int) { - a[i], a[j] = a[j], a[i] -} diff --git a/pkg/routing/routing_test.go b/pkg/routing/routing_test.go index 71dd56e8b..a5c71c18e 100644 --- a/pkg/routing/routing_test.go +++ b/pkg/routing/routing_test.go @@ -38,24 +38,24 @@ import ( "github.com/trickstercache/trickster/v2/pkg/proxy/methods" "github.com/trickstercache/trickster/v2/pkg/proxy/paths/matching" po "github.com/trickstercache/trickster/v2/pkg/proxy/paths/options" - "github.com/trickstercache/trickster/v2/pkg/router" + "github.com/trickstercache/trickster/v2/pkg/router/lm" testutil "github.com/trickstercache/trickster/v2/pkg/testutil" tlstest "github.com/trickstercache/trickster/v2/pkg/testutil/tls" ) func TestRegisterPprofRoutes(t *testing.T) { - router := http.NewServeMux() + router := lm.NewRouter() log := logging.ConsoleLogger("info") RegisterPprofRoutes("test", router, log) r, _ := http.NewRequest("GET", "http://0/debug/pprof", nil) - _, p := router.Handler(r) - if p != "/debug/pprof/" { - t.Error("expected pprof route path") + h := router.Handler(r) + if h == nil { + t.Error("expected non-nil handler") } } func TestRegisterHealthHandler(t *testing.T) { - router := http.NewServeMux() + router := lm.NewRouter() path := "/test" hc := healthcheck.New() RegisterHealthHandler(router, path, hc) @@ -73,7 +73,7 @@ func TestRegisterProxyRoutes(t *testing.T) { } caches := registration.LoadCachesFromConfig(conf, logging.ConsoleLogger("error")) defer registration.CloseCaches(caches) - proxyClients, err = RegisterProxyRoutes(conf, router.NewRouter(), http.NewServeMux(), caches, nil, log, false) + proxyClients, err = RegisterProxyRoutes(conf, lm.NewRouter(), lm.NewRouter(), caches, nil, log, false) if err != nil { t.Error(err) } @@ -88,7 +88,7 @@ func TestRegisterProxyRoutes(t *testing.T) { o.Hosts = []string{"test", "test2"} registration.LoadCachesFromConfig(conf, logging.ConsoleLogger("error")) - RegisterProxyRoutes(conf, router.NewRouter(), http.NewServeMux(), caches, tr, log, false) + RegisterProxyRoutes(conf, lm.NewRouter(), lm.NewRouter(), caches, tr, log, false) if len(proxyClients) == 0 { t.Errorf("expected %d got %d", 1, 0) @@ -108,34 +108,34 @@ func TestRegisterProxyRoutes(t *testing.T) { conf.Backends["2"] = o2 - router := router.NewRouter() - _, err = RegisterProxyRoutes(conf, router, http.NewServeMux(), caches, tr, log, false) + router := lm.NewRouter() + _, err = RegisterProxyRoutes(conf, router, lm.NewRouter(), caches, tr, log, false) if err == nil { t.Error("Expected error for too many default backends.") } o1.IsDefault = false o1.CacheName = "invalid" - _, err = RegisterProxyRoutes(conf, router, http.NewServeMux(), caches, tr, log, false) + _, err = RegisterProxyRoutes(conf, router, lm.NewRouter(), caches, tr, log, false) if err == nil { t.Errorf("Expected error for invalid cache name") } o1.CacheName = o2.CacheName - _, err = RegisterProxyRoutes(conf, router, http.NewServeMux(), caches, tr, log, false) + _, err = RegisterProxyRoutes(conf, router, lm.NewRouter(), caches, tr, log, false) if err != nil { t.Error(err) } o2.IsDefault = false o2.CacheName = "invalid" - _, err = RegisterProxyRoutes(conf, router, http.NewServeMux(), caches, tr, log, false) + _, err = RegisterProxyRoutes(conf, router, lm.NewRouter(), caches, tr, log, false) if err == nil { t.Errorf("Expected error for invalid cache name") } o2.CacheName = "default" - _, err = RegisterProxyRoutes(conf, router, http.NewServeMux(), caches, tr, log, false) + _, err = RegisterProxyRoutes(conf, router, lm.NewRouter(), caches, tr, log, false) if err != nil { t.Error(err) } @@ -148,9 +148,9 @@ func TestRegisterProxyRoutes(t *testing.T) { conf.Backends["1"] = o1 delete(conf.Backends, "default") - o1.Paths["/-0000000011"].Methods = nil + o1.Paths["/-GET-HEAD"].Methods = nil - _, err = RegisterProxyRoutes(conf, router, http.NewServeMux(), caches, tr, log, false) + _, err = RegisterProxyRoutes(conf, router, lm.NewRouter(), caches, tr, log, false) if err != nil { t.Error(err) } @@ -166,7 +166,7 @@ func TestRegisterProxyRoutesInflux(t *testing.T) { caches := registration.LoadCachesFromConfig(conf, logging.ConsoleLogger("error")) defer registration.CloseCaches(caches) - proxyClients, err := RegisterProxyRoutes(conf, router.NewRouter(), http.NewServeMux(), caches, + proxyClients, err := RegisterProxyRoutes(conf, lm.NewRouter(), lm.NewRouter(), caches, nil, logging.ConsoleLogger("info"), false) if err != nil { t.Error(err) @@ -187,7 +187,7 @@ func TestRegisterProxyRoutesReverseProxy(t *testing.T) { caches := registration.LoadCachesFromConfig(conf, logging.ConsoleLogger("error")) defer registration.CloseCaches(caches) - proxyClients, err := RegisterProxyRoutes(conf, router.NewRouter(), http.NewServeMux(), caches, + proxyClients, err := RegisterProxyRoutes(conf, lm.NewRouter(), lm.NewRouter(), caches, nil, logging.ConsoleLogger("info"), false) if err != nil { t.Error(err) @@ -209,7 +209,7 @@ func TestRegisterProxyRoutesClickHouse(t *testing.T) { caches := registration.LoadCachesFromConfig(conf, logging.ConsoleLogger("error")) defer registration.CloseCaches(caches) - proxyClients, err := RegisterProxyRoutes(conf, router.NewRouter(), http.NewServeMux(), caches, + proxyClients, err := RegisterProxyRoutes(conf, lm.NewRouter(), lm.NewRouter(), caches, nil, logging.ConsoleLogger("info"), false) if err != nil { t.Error(err) @@ -233,7 +233,7 @@ func TestRegisterProxyRoutesALB(t *testing.T) { caches := registration.LoadCachesFromConfig(conf, logging.ConsoleLogger("error")) defer registration.CloseCaches(caches) - proxyClients, err := RegisterProxyRoutes(conf, router.NewRouter(), http.NewServeMux(), caches, + proxyClients, err := RegisterProxyRoutes(conf, lm.NewRouter(), lm.NewRouter(), caches, nil, logging.ConsoleLogger("info"), false) if err != nil { t.Error(err) @@ -255,7 +255,7 @@ func TestRegisterProxyRoutesIRONdb(t *testing.T) { caches := registration.LoadCachesFromConfig(conf, logging.ConsoleLogger("error")) defer registration.CloseCaches(caches) - proxyClients, err := RegisterProxyRoutes(conf, router.NewRouter(), http.NewServeMux(), caches, + proxyClients, err := RegisterProxyRoutes(conf, lm.NewRouter(), lm.NewRouter(), caches, nil, logging.ConsoleLogger("info"), false) if err != nil { t.Error(err) @@ -280,7 +280,7 @@ func TestRegisterProxyRoutesWithReqRewriters(t *testing.T) { caches := registration.LoadCachesFromConfig(conf, logging.ConsoleLogger("error")) defer registration.CloseCaches(caches) - proxyClients, err := RegisterProxyRoutes(conf, router.NewRouter(), http.NewServeMux(), caches, + proxyClients, err := RegisterProxyRoutes(conf, lm.NewRouter(), lm.NewRouter(), caches, nil, logging.ConsoleLogger("info"), false) if err != nil { t.Error(err) @@ -302,7 +302,7 @@ func TestRegisterProxyRoutesMultipleDefaults(t *testing.T) { } caches := registration.LoadCachesFromConfig(conf, logging.ConsoleLogger("error")) defer registration.CloseCaches(caches) - _, err = RegisterProxyRoutes(conf, router.NewRouter(), http.NewServeMux(), caches, + _, err = RegisterProxyRoutes(conf, lm.NewRouter(), lm.NewRouter(), caches, nil, logging.ConsoleLogger("info"), false) if err == nil { t.Errorf("expected error `%s` got nothing", expected1) @@ -346,7 +346,7 @@ func TestRegisterProxyRoutesInvalidCert(t *testing.T) { } caches := registration.LoadCachesFromConfig(conf, logging.ConsoleLogger("error")) defer registration.CloseCaches(caches) - _, err = RegisterProxyRoutes(conf, router.NewRouter(), http.NewServeMux(), caches, + _, err = RegisterProxyRoutes(conf, lm.NewRouter(), lm.NewRouter(), caches, nil, logging.ConsoleLogger("info"), false) if err == nil { t.Errorf("expected error: %s", expected) @@ -376,7 +376,7 @@ func TestRegisterProxyRoutesBadProvider(t *testing.T) { } caches := registration.LoadCachesFromConfig(conf, logging.ConsoleLogger("error")) defer registration.CloseCaches(caches) - _, err = RegisterProxyRoutes(conf, router.NewRouter(), http.NewServeMux(), caches, + _, err = RegisterProxyRoutes(conf, lm.NewRouter(), lm.NewRouter(), caches, nil, logging.ConsoleLogger("info"), false) if err == nil { t.Errorf("expected error `%s` got nothing", expected) @@ -393,7 +393,7 @@ func TestRegisterMultipleBackends(t *testing.T) { } caches := registration.LoadCachesFromConfig(conf, logging.ConsoleLogger("error")) defer registration.CloseCaches(caches) - _, err = RegisterProxyRoutes(conf, router.NewRouter(), http.NewServeMux(), caches, + _, err = RegisterProxyRoutes(conf, lm.NewRouter(), lm.NewRouter(), caches, nil, logging.ConsoleLogger("info"), false) if err != nil { t.Error(err) @@ -408,7 +408,7 @@ func TestRegisterMultipleBackendsPlusDefault(t *testing.T) { } caches := registration.LoadCachesFromConfig(conf, logging.ConsoleLogger("error")) defer registration.CloseCaches(caches) - _, err = RegisterProxyRoutes(conf, router.NewRouter(), http.NewServeMux(), caches, + _, err = RegisterProxyRoutes(conf, lm.NewRouter(), lm.NewRouter(), caches, nil, logging.ConsoleLogger("info"), false) if err != nil { t.Error(err) @@ -429,7 +429,7 @@ func TestRegisterPathRoutes(t *testing.T) { } oo := conf.Backends["default"] - rpc, _ := reverseproxycache.NewClient("test", oo, router.NewRouter(), nil, nil, nil) + rpc, _ := reverseproxycache.NewClient("test", oo, lm.NewRouter(), nil, nil, nil) dpc := rpc.DefaultPathConfigs(oo) dpc["/-GET-HEAD"].Methods = nil @@ -438,7 +438,7 @@ func TestRegisterPathRoutes(t *testing.T) { RegisterPathRoutes(nil, handlers, rpc, oo, nil, dpc, nil, "", logging.ConsoleLogger("INFO")) - router := router.NewRouter() + router := lm.NewRouter() dpc = rpc.DefaultPathConfigs(oo) dpc["/-GET-HEAD"].Methods = []string{"*"} dpc["/-GET-HEAD"].Handler = testHandler @@ -470,7 +470,7 @@ func TestValidateRuleClients(t *testing.T) { o := conf.Backends["default"] o.Provider = "rule" - _, err = RegisterProxyRoutes(conf, router.NewRouter(), http.NewServeMux(), caches, + _, err = RegisterProxyRoutes(conf, lm.NewRouter(), lm.NewRouter(), caches, nil, logging.ConsoleLogger("info"), false) if err == nil { t.Error("expected error") @@ -482,7 +482,7 @@ func TestRegisterDefaultBackendRoutes(t *testing.T) { // successful passing of this test is no panic - r := router.NewRouter() + r := lm.NewRouter() conf := config.NewConfig() oo := conf.Backends["default"] w := httptest.NewRecorder() @@ -497,7 +497,7 @@ func TestRegisterDefaultBackendRoutes(t *testing.T) { oo.TracingConfigName = "testTracer" oo.Paths = map[string]*po.Options{"root": po1} oo.IsDefault = true - rpc, _ := reverseproxycache.NewClient("default", oo, router.NewRouter(), nil, nil, nil) + rpc, _ := reverseproxycache.NewClient("default", oo, lm.NewRouter(), nil, nil, nil) b := backends.Backends{"default": rpc} tr := tracing.Tracers{"testTracer": testutil.NewTestTracer()} @@ -507,7 +507,7 @@ func TestRegisterDefaultBackendRoutes(t *testing.T) { po1.ReqRewriter = ri RegisterDefaultBackendRoutes(r, b, logger, tr) - r = router.NewRouter() + r = lm.NewRouter() po1.MatchType = matching.PathMatchTypeExact RegisterDefaultBackendRoutes(r, b, logger, tr) diff --git a/pkg/testutil/writer/writer.go b/pkg/testutil/writer/writer.go new file mode 100644 index 000000000..bf99839b0 --- /dev/null +++ b/pkg/testutil/writer/writer.go @@ -0,0 +1,52 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// package writer represents a Test ResponseWriter for use in Unit Tests +package writer + +import "net/http" + +func NewWriter() http.ResponseWriter { + return &TestResponseWriter{ + Headers: make(http.Header), + Bytes: make([]byte, 0, 8192), + } +} + +type TestResponseWriter struct { + Headers http.Header + StatusCode int + Bytes []byte +} + +func (w *TestResponseWriter) Header() http.Header { + return w.Headers +} + +func (w *TestResponseWriter) WriteHeader(statusCode int) { + w.StatusCode = statusCode +} + +func (w *TestResponseWriter) Write(b []byte) (int, error) { + w.Bytes = append(w.Bytes, b...) + return len(b), nil +} + +func (w *TestResponseWriter) Reset() { + w.Headers = make(http.Header) + w.StatusCode = 0 + w.Bytes = make([]byte, 0, 8192) +}