diff --git a/martian.go b/martian.go index 0955510..d6a64a6 100644 --- a/martian.go +++ b/martian.go @@ -27,6 +27,8 @@ import ( _ "github.com/google/martian/status" ) +type requestExecutorFactory func(*config.Backend) client.HTTPRequestExecutor + // NewBackendFactory creates a proxy.BackendFactory with the martian request executor wrapping the injected one. // If there is any problem parsing the extra config data, it just uses the injected request executor. func NewBackendFactory(logger logging.Logger, re client.HTTPRequestExecutor) proxy.BackendFactory { @@ -35,23 +37,31 @@ func NewBackendFactory(logger logging.Logger, re client.HTTPRequestExecutor) pro // NewConfiguredBackendFactory creates a proxy.BackendFactory with the martian request executor wrapping the injected one. // If there is any problem parsing the extra config data, it just uses the injected request executor. -func NewConfiguredBackendFactory(logger logging.Logger, ref func(*config.Backend) client.HTTPRequestExecutor) proxy.BackendFactory { - parse.Register("static.Modifier", staticModifierFromJSON) - +func NewConfiguredBackendFactory(logger logging.Logger, ref requestExecutorFactory) proxy.BackendFactory { + ref = NewRequestExecutorFactory(logger, ref) return func(remote *config.Backend) proxy.Proxy { + return proxy.NewHTTPProxyWithHTTPExecutor(remote, ref(remote), remote.Decoder) + } +} + +// NewRequestExecutorFactory creates a request executor factory that takes as input the config.Backend wrapping the injected one. +// If there is any problem parsing the extra config data, it just uses the injected request executor. +func NewRequestExecutorFactory(logger logging.Logger, ref requestExecutorFactory) requestExecutorFactory { + parse.Register("static.Modifier", staticModifierFromJSON) + return func(remote *config.Backend) client.HTTPRequestExecutor { re := ref(remote) result, ok := ConfigGetter(remote.ExtraConfig).(Result) if !ok { - return proxy.NewHTTPProxyWithHTTPExecutor(remote, re, remote.Decoder) + return re } switch result.Err { case nil: - return proxy.NewHTTPProxyWithHTTPExecutor(remote, HTTPRequestExecutor(result.Result, re), remote.Decoder) + return HTTPRequestExecutor(result.Result, re) case ErrEmptyValue: - return proxy.NewHTTPProxyWithHTTPExecutor(remote, re, remote.Decoder) + return re default: logger.Error(result, remote.ExtraConfig) - return proxy.NewHTTPProxyWithHTTPExecutor(remote, re, remote.Decoder) + return re } } } diff --git a/martian_test.go b/martian_test.go index 56ec6b8..556044c 100644 --- a/martian_test.go +++ b/martian_test.go @@ -15,6 +15,7 @@ import ( "github.com/luraproject/lura/config" "github.com/luraproject/lura/logging" "github.com/luraproject/lura/proxy" + "github.com/luraproject/lura/transport/http/client" ) func TestHTTPRequestExecutor_ok(t *testing.T) { @@ -188,6 +189,88 @@ func TestHTTPRequestExecutor_koErroredRequest(t *testing.T) { } } +func TestNewRequestExecutorFactory_noExtra(t *testing.T) { + expectedErr := fmt.Errorf("some error") + re := func(_ context.Context, _ *http.Request) (resp *http.Response, err error) { return nil, expectedErr } + ref := func(_ *config.Backend) client.HTTPRequestExecutor { return re } + buf := bytes.NewBuffer(make([]byte, 1024)) + l, _ := logging.NewLogger("DEBUG", buf, "") + ref = NewRequestExecutorFactory(l, ref) + re = ref(&config.Backend{}) + req, _ := http.NewRequest("GET", "http://example.com/", ioutil.NopCloser(bytes.NewBufferString(""))) + resp, err := re(context.Background(), req) + if resp != nil { + t.Error("unexpected response:", resp) + } + if err != expectedErr { + t.Error("unexpected error:", err) + } +} + +func TestNewRequestExecutorFactory_wrongExtra(t *testing.T) { + expectedErr := fmt.Errorf("some error") + re := func(_ context.Context, _ *http.Request) (resp *http.Response, err error) { return nil, expectedErr } + ref := func(_ *config.Backend) client.HTTPRequestExecutor { return re } + buf := bytes.NewBuffer(make([]byte, 1024)) + l, _ := logging.NewLogger("DEBUG", buf, "") + ref = NewRequestExecutorFactory(l, ref) + re = ref(&config.Backend{ + ExtraConfig: config.ExtraConfig{ + Namespace: 42, + }, + }) + req, _ := http.NewRequest("GET", "http://example.com/", ioutil.NopCloser(bytes.NewBufferString(""))) + resp, err := re(context.Background(), req) + if resp != nil { + t.Error("unexpected response:", resp) + } + if err != expectedErr { + t.Error("unexpected error:", err) + } +} + +func TestNewRequestExecutorFactory_ok(t *testing.T) { + expectedHdr := "ouh yeah!" + expectedErr := fmt.Errorf("some error") + re := func(_ context.Context, r *http.Request) (resp *http.Response, err error) { + if header := r.Header.Get("X-Martian"); header != expectedHdr { + t.Error("unexpected request header:", header) + } + return nil, expectedErr + } + ref := func(_ *config.Backend) client.HTTPRequestExecutor { return re } + buf := bytes.NewBuffer(make([]byte, 1024)) + l, _ := logging.NewLogger("DEBUG", buf, "") + ref = NewRequestExecutorFactory(l, ref) + re = ref(&config.Backend{ + ExtraConfig: config.ExtraConfig{ + Namespace: map[string]interface{}{ + "fifo.Group": map[string]interface{}{ + "scope": []interface{}{"request", "response"}, + "aggregateErrors": true, + "modifiers": []map[string]interface{}{ + { + "header.Modifier": map[string]interface{}{ + "scope": []interface{}{"request", "response"}, + "name": "X-Martian", + "value": expectedHdr, + }, + }, + }, + }, + }, + }, + }) + req, _ := http.NewRequest("GET", "http://example.com/", ioutil.NopCloser(bytes.NewBufferString(""))) + resp, err := re(context.Background(), req) + if resp != nil { + t.Error("unexpected response:", resp) + } + if err != expectedErr { + t.Error("unexpected error:", err) + } +} + func TestNewBackendFactory_noExtra(t *testing.T) { expectedErr := fmt.Errorf("some error") re := func(_ context.Context, _ *http.Request) (resp *http.Response, err error) { return nil, expectedErr }