From b6b93b5ecd9b060f43e6aa397582b724fec6fa75 Mon Sep 17 00:00:00 2001 From: Jacky Hu Date: Thu, 24 Oct 2024 16:59:49 -0700 Subject: [PATCH] [PECO-2050] Add custom auth headers into cloud fetch request --- internal/rows/arrowbased/batchloader.go | 6 ++++++ internal/rows/arrowbased/batchloader_test.go | 21 +++++++++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 45b067d..4f7ef0b 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -277,6 +277,12 @@ func fetchBatchBytes( return nil, err } + if link.HttpHeaders != nil { + for key, value := range link.HttpHeaders { + req.Header.Set(key, value) + } + } + client := http.DefaultClient res, err := client.Do(req) if err != nil { diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index e47eef0..d02d299 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -4,14 +4,16 @@ import ( "bytes" "context" "fmt" - dbsqlerr "github.com/databricks/databricks-sql-go/errors" - "github.com/databricks/databricks-sql-go/internal/cli_service" - "github.com/databricks/databricks-sql-go/internal/config" "net/http" "net/http/httptest" "testing" "time" + dbsqlerr "github.com/databricks/databricks-sql-go/errors" + "github.com/databricks/databricks-sql-go/internal/cli_service" + "github.com/databricks/databricks-sql-go/internal/config" + "github.com/pkg/errors" + "github.com/apache/arrow/go/v12/arrow" "github.com/apache/arrow/go/v12/arrow/array" "github.com/apache/arrow/go/v12/arrow/ipc" @@ -28,8 +30,19 @@ func TestCloudFetchIterator(t *testing.T) { defer server.Close() t.Run("should fetch all the links", func(t *testing.T) { + cloudFetchHeaders := map[string]string{ + "foo": "bar", + } + handler = func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) + for name, value := range cloudFetchHeaders { + if values, ok := r.Header[name]; ok { + if values[0] != value { + panic(errors.New("Missing auth headers")) + } + } + } _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) if err != nil { panic(err) @@ -44,12 +57,14 @@ func TestCloudFetchIterator(t *testing.T) { ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), StartRowOffset: startRowOffset, RowCount: 1, + HttpHeaders: cloudFetchHeaders, }, { FileLink: server.URL, ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), StartRowOffset: startRowOffset + 1, RowCount: 1, + HttpHeaders: cloudFetchHeaders, }, }