From 8968507a3ad0551338ab7eb0b5b26c786e410bdb Mon Sep 17 00:00:00 2001 From: Nathan Yergler Date: Wed, 29 Nov 2023 17:39:57 -0800 Subject: [PATCH 1/2] Use the context logger, if available. This allows use of a non-global logger for finer grained control. --- connection.go | 9 ++++++--- connector.go | 2 +- doc.go | 12 ++++++++++++ internal/client/client.go | 2 +- internal/rows/rows.go | 6 +++--- internal/rows/rows_test.go | 12 ++++++------ logger/logger.go | 18 +++++++++++++++++- 7 files changed, 46 insertions(+), 15 deletions(-) diff --git a/connection.go b/connection.go index 87eb78a..5527b60 100644 --- a/connection.go +++ b/connection.go @@ -121,6 +121,9 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name stagingErr := c.execStagingOperation(exStmtResp, ctx) if exStmtResp != nil && exStmtResp.OperationHandle != nil { + // we have an operation id so update the logger + log, _ := client.LoggerAndContext(ctx, exStmtResp) + // since we have an operation handle we can close the operation if necessary alreadyClosed := exStmtResp.DirectResults != nil && exStmtResp.DirectResults.CloseOperation != nil newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId) @@ -175,7 +178,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam } corrId := driverctx.CorrelationIdFromContext(ctx) - rows, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) + rows, err := rows.NewRows(ctx, c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) return rows, err @@ -348,7 +351,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) { corrId := driverctx.CorrelationIdFromContext(ctx) - log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID)) + log := logger.AddContext(logger.Ctx(ctx), c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID)) var statusResp *cli_service.TGetOperationStatusResp ctx = driverctx.NewContextWithConnId(ctx, c.id) newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId) @@ -567,7 +570,7 @@ func (c *conn) execStagingOperation( } if len(driverctx.StagingPathsFromContext(ctx)) != 0 { - row, err = rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) + row, err = rows.NewRows(ctx, c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) if err != nil { return dbsqlerrint.NewDriverError(ctx, "error reading row.", err) } diff --git a/connector.go b/connector.go index 96a8831..ca86d7e 100644 --- a/connector.go +++ b/connector.go @@ -61,7 +61,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { client: tclient, session: session, } - log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "") + log := logger.AddContext(logger.Ctx(ctx), conn.id, driverctx.CorrelationIdFromContext(ctx), "") log.Info().Msgf("connect: host=%s port=%d httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath) diff --git a/doc.go b/doc.go index 9463d77..f6ea19f 100644 --- a/doc.go +++ b/doc.go @@ -154,6 +154,18 @@ The result log may look like this: {"level":"debug","connId":"01ed6545-5669-1ec7-8c7e-6d8a1ea0ab16","corrId":"workflow-example","queryId":"01ed6545-57cc-188a-bfc5-d9c0eaf8e189","time":1668558402,"message":"Run Main elapsed time: 1.298712292s"} +You may customize the log by passing it using Zerolog's context support. This allows customziation of the output, as well as inclusion of additionl metadata. + +For example, + + log := zerolog.New(DefaultLogOutput).With("service_id", "workflow-example")).Logger() + ctx = log.WithContext(context.Background()) + ... + db, err := sql.Open("databricks", "") + ... + rows, err := db.QueryContext(ctx, `select * from sometable`) + ... + # Programmatically Retrieving Connection and Query Id Use the driverctx package under driverctx/ctx.go to add callbacks to the query context to receive the connection id and query id. diff --git a/internal/client/client.go b/internal/client/client.go index fda1053..71d9096 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -403,7 +403,7 @@ func LoggerAndContext(ctx context.Context, c any) (*logger.DBSQLLogger, context. queryId = guidFromHasOpHandle(c) ctx = driverctx.NewContextWithQueryId(ctx, queryId) } - log := logger.WithContext(connId, corrId, queryId) + log := logger.AddContext(logger.Ctx(ctx), connId, corrId, queryId) return log, ctx } diff --git a/internal/rows/rows.go b/internal/rows/rows.go index c9581e2..ce34f3e 100644 --- a/internal/rows/rows.go +++ b/internal/rows/rows.go @@ -68,6 +68,7 @@ var _ driver.RowsColumnTypeLength = (*rows)(nil) var _ dbsqlrows.Rows = (*rows)(nil) func NewRows( + ctx context.Context, connId string, correlationId string, opHandle *cli_service.TOperationHandle, @@ -77,12 +78,11 @@ func NewRows( ) (driver.Rows, dbsqlerr.DBError) { var logger *dbsqllog.DBSQLLogger - var ctx context.Context if opHandle != nil { - logger = dbsqllog.WithContext(connId, correlationId, dbsqlclient.SprintGuid(opHandle.OperationId.GUID)) + logger = dbsqllog.AddContext(dbsqllog.Ctx(ctx), connId, correlationId, dbsqlclient.SprintGuid(opHandle.OperationId.GUID)) ctx = driverctx.NewContextWithQueryId(driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId), dbsqlclient.SprintGuid(opHandle.OperationId.GUID)) } else { - logger = dbsqllog.WithContext(connId, correlationId, "") + logger = dbsqllog.AddContext(dbsqllog.Ctx(ctx), connId, correlationId, "") ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId) } diff --git a/internal/rows/rows_test.go b/internal/rows/rows_test.go index c83fc41..f5168ab 100644 --- a/internal/rows/rows_test.go +++ b/internal/rows/rows_test.go @@ -413,7 +413,7 @@ func TestColumnsWithDirectResults(t *testing.T) { client := getRowsTestSimpleClient(&getMetadataCount, &fetchResultsCount) - d, err := NewRows("", "", nil, client, nil, nil) + d, err := NewRows(context.Background(), "", "", nil, client, nil, nil) assert.Nil(t, err) rowSet := d.(*rows) @@ -708,7 +708,7 @@ func TestRowsCloseOptimization(t *testing.T) { } opHandle := &cli_service.TOperationHandle{OperationId: &cli_service.THandleIdentifier{GUID: []byte{'f', 'o'}}} - rowSet, _ := NewRows("", "", opHandle, client, nil, nil) + rowSet, _ := NewRows(context.Background(), "", "", opHandle, client, nil, nil) // rowSet has no direct results calling Close should result in call to client to close operation err := rowSet.Close() @@ -721,7 +721,7 @@ func TestRowsCloseOptimization(t *testing.T) { ResultSet: &cli_service.TFetchResultsResp{Results: &cli_service.TRowSet{Columns: []*cli_service.TColumn{}}}, } closeCount = 0 - rowSet, _ = NewRows("", "", opHandle, client, nil, directResults) + rowSet, _ = NewRows(context.Background(), "", "", opHandle, client, nil, directResults) err = rowSet.Close() assert.Nil(t, err, "rows.Close should not throw an error") assert.Equal(t, 1, closeCount) @@ -734,7 +734,7 @@ func TestRowsCloseOptimization(t *testing.T) { ResultSetMetadata: &cli_service.TGetResultSetMetadataResp{Schema: &cli_service.TTableSchema{}}, ResultSet: &cli_service.TFetchResultsResp{Results: &cli_service.TRowSet{Columns: []*cli_service.TColumn{}}}, } - rowSet, _ = NewRows("", "", opHandle, client, nil, directResults) + rowSet, _ = NewRows(context.Background(), "", "", opHandle, client, nil, directResults) err = rowSet.Close() assert.Nil(t, err, "rows.Close should not throw an error") assert.Equal(t, 0, closeCount) @@ -799,7 +799,7 @@ func TestGetArrowBatches(t *testing.T) { client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1, fetchResp2}) cfg := config.WithDefaults() - rows, err := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults) + rows, err := NewRows(context.Background(), "connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) @@ -869,7 +869,7 @@ func TestGetArrowBatches(t *testing.T) { client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1, fetchResp2, fetchResp3}) cfg := config.WithDefaults() - rows, err := NewRows("connId", "corrId", nil, client, cfg, nil) + rows, err := NewRows(context.Background(), "connId", "corrId", nil, client, cfg, nil) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) diff --git a/logger/logger.go b/logger/logger.go index 683501a..6bbbd9b 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -1,6 +1,7 @@ package logger import ( + "context" "io" "os" "runtime" @@ -123,9 +124,24 @@ func Err(err error) *zerolog.Event { return Logger.Err(err) } +// Ctx returns a DBSQLLogger from the provided context. If no logger is found, +// the default logger is returned. +func Ctx(ctx context.Context) *DBSQLLogger { + l := zerolog.Ctx(ctx) + if l == zerolog.DefaultContextLogger { + return Logger + } + return &DBSQLLogger{*l} +} + +// AddContext sets connectionId, correlationId, and queryId as fields on the provided logger. +func AddContext(l *DBSQLLogger, connectionId string, correlationId string, queryId string) *DBSQLLogger { + return &DBSQLLogger{l.With().Str("connId", connectionId).Str("corrId", correlationId).Str("queryId", queryId).Logger()} +} + // WithContext sets connectionId, correlationId, and queryId to be used as fields. func WithContext(connectionId string, correlationId string, queryId string) *DBSQLLogger { - return &DBSQLLogger{Logger.With().Str("connId", connectionId).Str("corrId", correlationId).Str("queryId", queryId).Logger()} + return AddContext(Logger, connectionId, correlationId, queryId) } // Track is a convenience function to track time spent From e550b7fd02a1d188d4a77f0f196f1327269be335 Mon Sep 17 00:00:00 2001 From: Nathan Yergler Date: Mon, 13 May 2024 11:21:32 -0700 Subject: [PATCH 2/2] Update based on PR feedback. - Don't pass Context directly into NewRows - Rename function for retrieving logger from Context --- connection.go | 6 +++--- connector.go | 2 +- internal/client/client.go | 2 +- internal/rows/rows.go | 8 ++++---- internal/rows/rows_test.go | 16 ++++++++-------- logger/logger.go | 11 +++++++---- 6 files changed, 24 insertions(+), 21 deletions(-) diff --git a/connection.go b/connection.go index c5bb4ed..b3ee948 100644 --- a/connection.go +++ b/connection.go @@ -170,7 +170,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam } corrId := driverctx.CorrelationIdFromContext(ctx) - rows, err := rows.NewRows(ctx, c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) + rows, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, logger.FromContext(ctx)) return rows, err @@ -343,7 +343,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) { corrId := driverctx.CorrelationIdFromContext(ctx) - log := logger.AddContext(logger.Ctx(ctx), c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID)) + log := logger.AddContext(logger.FromContext(ctx), c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID)) var statusResp *cli_service.TGetOperationStatusResp ctx = driverctx.NewContextWithConnId(ctx, c.id) newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId) @@ -562,7 +562,7 @@ func (c *conn) execStagingOperation( } if len(driverctx.StagingPathsFromContext(ctx)) != 0 { - row, err = rows.NewRows(ctx, c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) + row, err = rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, logger.FromContext(ctx)) if err != nil { return dbsqlerrint.NewDriverError(ctx, "error reading row.", err) } diff --git a/connector.go b/connector.go index ca86d7e..a947030 100644 --- a/connector.go +++ b/connector.go @@ -61,7 +61,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { client: tclient, session: session, } - log := logger.AddContext(logger.Ctx(ctx), conn.id, driverctx.CorrelationIdFromContext(ctx), "") + log := logger.AddContext(logger.FromContext(ctx), conn.id, driverctx.CorrelationIdFromContext(ctx), "") log.Info().Msgf("connect: host=%s port=%d httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath) diff --git a/internal/client/client.go b/internal/client/client.go index 71d9096..36bdd54 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -403,7 +403,7 @@ func LoggerAndContext(ctx context.Context, c any) (*logger.DBSQLLogger, context. queryId = guidFromHasOpHandle(c) ctx = driverctx.NewContextWithQueryId(ctx, queryId) } - log := logger.AddContext(logger.Ctx(ctx), connId, corrId, queryId) + log := logger.AddContext(logger.FromContext(ctx), connId, corrId, queryId) return log, ctx } diff --git a/internal/rows/rows.go b/internal/rows/rows.go index ce34f3e..cc77c20 100644 --- a/internal/rows/rows.go +++ b/internal/rows/rows.go @@ -68,21 +68,21 @@ var _ driver.RowsColumnTypeLength = (*rows)(nil) var _ dbsqlrows.Rows = (*rows)(nil) func NewRows( - ctx context.Context, connId string, correlationId string, opHandle *cli_service.TOperationHandle, client cli_service.TCLIService, config *config.Config, directResults *cli_service.TSparkDirectResults, + logger *dbsqllog.DBSQLLogger, ) (driver.Rows, dbsqlerr.DBError) { - var logger *dbsqllog.DBSQLLogger + var ctx context.Context if opHandle != nil { - logger = dbsqllog.AddContext(dbsqllog.Ctx(ctx), connId, correlationId, dbsqlclient.SprintGuid(opHandle.OperationId.GUID)) + logger = dbsqllog.AddContext(logger, connId, correlationId, dbsqlclient.SprintGuid(opHandle.OperationId.GUID)) ctx = driverctx.NewContextWithQueryId(driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId), dbsqlclient.SprintGuid(opHandle.OperationId.GUID)) } else { - logger = dbsqllog.AddContext(dbsqllog.Ctx(ctx), connId, correlationId, "") + logger = dbsqllog.AddContext(logger, connId, correlationId, "") ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId) } diff --git a/internal/rows/rows_test.go b/internal/rows/rows_test.go index 1916bfe..e3d724f 100644 --- a/internal/rows/rows_test.go +++ b/internal/rows/rows_test.go @@ -413,7 +413,7 @@ func TestColumnsWithDirectResults(t *testing.T) { client := getRowsTestSimpleClient(&getMetadataCount, &fetchResultsCount) - d, err := NewRows(context.Background(), "", "", nil, client, nil, nil) + d, err := NewRows("", "", nil, client, nil, nil, nil) assert.Nil(t, err) rowSet := d.(*rows) @@ -708,7 +708,7 @@ func TestRowsCloseOptimization(t *testing.T) { } opHandle := &cli_service.TOperationHandle{OperationId: &cli_service.THandleIdentifier{GUID: []byte{'f', 'o'}}} - rowSet, _ := NewRows(context.Background(), "", "", opHandle, client, nil, nil) + rowSet, _ := NewRows("", "", opHandle, client, nil, nil, nil) // rowSet has no direct results calling Close should result in call to client to close operation err := rowSet.Close() @@ -721,7 +721,7 @@ func TestRowsCloseOptimization(t *testing.T) { ResultSet: &cli_service.TFetchResultsResp{Results: &cli_service.TRowSet{Columns: []*cli_service.TColumn{}}}, } closeCount = 0 - rowSet, _ = NewRows(context.Background(), "", "", opHandle, client, nil, directResults) + rowSet, _ = NewRows("", "", opHandle, client, nil, directResults, nil) err = rowSet.Close() assert.Nil(t, err, "rows.Close should not throw an error") assert.Equal(t, 1, closeCount) @@ -734,7 +734,7 @@ func TestRowsCloseOptimization(t *testing.T) { ResultSetMetadata: &cli_service.TGetResultSetMetadataResp{Schema: &cli_service.TTableSchema{}}, ResultSet: &cli_service.TFetchResultsResp{Results: &cli_service.TRowSet{Columns: []*cli_service.TColumn{}}}, } - rowSet, _ = NewRows(context.Background(), "", "", opHandle, client, nil, directResults) + rowSet, _ = NewRows("", "", opHandle, client, nil, directResults, nil) err = rowSet.Close() assert.Nil(t, err, "rows.Close should not throw an error") assert.Equal(t, 0, closeCount) @@ -799,7 +799,7 @@ func TestGetArrowBatches(t *testing.T) { client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1, fetchResp2}) cfg := config.WithDefaults() - rows, err := NewRows(context.Background(), "connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults) + rows, err := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults, nil) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) @@ -869,7 +869,7 @@ func TestGetArrowBatches(t *testing.T) { client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1, fetchResp2, fetchResp3}) cfg := config.WithDefaults() - rows, err := NewRows(context.Background(), "connId", "corrId", nil, client, cfg, nil) + rows, err := NewRows("connId", "corrId", nil, client, cfg, nil, nil) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) @@ -927,7 +927,7 @@ func TestGetArrowBatches(t *testing.T) { client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1}) cfg := config.WithDefaults() - rows, err := NewRows("connId", "corrId", nil, client, cfg, nil) + rows, err := NewRows("connId", "corrId", nil, client, cfg, nil, nil) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) @@ -951,7 +951,7 @@ func TestGetArrowBatches(t *testing.T) { client := getSimpleClient([]cli_service.TFetchResultsResp{}) cfg := config.WithDefaults() - rows, err := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults) + rows, err := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults, nil) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) diff --git a/logger/logger.go b/logger/logger.go index 6bbbd9b..f2edfe3 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -124,9 +124,9 @@ func Err(err error) *zerolog.Event { return Logger.Err(err) } -// Ctx returns a DBSQLLogger from the provided context. If no logger is found, -// the default logger is returned. -func Ctx(ctx context.Context) *DBSQLLogger { +// FromContext returns a DBSQLLogger from the provided context. If no logger is +// found, the default logger is returned. +func FromContext(ctx context.Context) *DBSQLLogger { l := zerolog.Ctx(ctx) if l == zerolog.DefaultContextLogger { return Logger @@ -136,12 +136,15 @@ func Ctx(ctx context.Context) *DBSQLLogger { // AddContext sets connectionId, correlationId, and queryId as fields on the provided logger. func AddContext(l *DBSQLLogger, connectionId string, correlationId string, queryId string) *DBSQLLogger { + if l == nil { + l = Logger + } return &DBSQLLogger{l.With().Str("connId", connectionId).Str("corrId", correlationId).Str("queryId", queryId).Logger()} } // WithContext sets connectionId, correlationId, and queryId to be used as fields. func WithContext(connectionId string, correlationId string, queryId string) *DBSQLLogger { - return AddContext(Logger, connectionId, correlationId, queryId) + return AddContext(nil, connectionId, correlationId, queryId) } // Track is a convenience function to track time spent