Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the context logger, if available. #180

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
stagingErr := c.execStagingOperation(exStmtResp, ctx)

if exStmtResp != nil && exStmtResp.OperationHandle != nil {
// we have an operation id so update the logger
log, _ := client.LoggerAndContext(ctx, exStmtResp)

// since we have an operation handle we can close the operation if necessary
alreadyClosed := exStmtResp.DirectResults != nil && exStmtResp.DirectResults.CloseOperation != nil
newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId)
Expand Down Expand Up @@ -167,7 +170,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
}

corrId := driverctx.CorrelationIdFromContext(ctx)
rows, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
rows, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, logger.FromContext(ctx))

return rows, err

Expand Down Expand Up @@ -340,7 +343,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver

func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) {
corrId := driverctx.CorrelationIdFromContext(ctx)
log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID))
log := logger.AddContext(logger.FromContext(ctx), c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID))
var statusResp *cli_service.TGetOperationStatusResp
ctx = driverctx.NewContextWithConnId(ctx, c.id)
newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId)
Expand Down Expand Up @@ -559,7 +562,7 @@ func (c *conn) execStagingOperation(
}

if len(driverctx.StagingPathsFromContext(ctx)) != 0 {
row, err = rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
row, err = rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, logger.FromContext(ctx))
if err != nil {
return dbsqlerrint.NewDriverError(ctx, "error reading row.", err)
}
Expand Down
2 changes: 1 addition & 1 deletion connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
client: tclient,
session: session,
}
log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "")
log := logger.AddContext(logger.FromContext(ctx), conn.id, driverctx.CorrelationIdFromContext(ctx), "")

log.Info().Msgf("connect: host=%s port=%d httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath)

Expand Down
12 changes: 12 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ The result log may look like this:

{"level":"debug","connId":"01ed6545-5669-1ec7-8c7e-6d8a1ea0ab16","corrId":"workflow-example","queryId":"01ed6545-57cc-188a-bfc5-d9c0eaf8e189","time":1668558402,"message":"Run Main elapsed time: 1.298712292s"}

You may customize the log by passing it using Zerolog's context support. This allows customziation of the output, as well as inclusion of additionl metadata.

For example,

log := zerolog.New(DefaultLogOutput).With("service_id", "workflow-example")).Logger()
ctx = log.WithContext(context.Background())
...
db, err := sql.Open("databricks", "<dsn_string>")
...
rows, err := db.QueryContext(ctx, `select * from sometable`)
...

# Programmatically Retrieving Connection and Query Id

Use the driverctx package under driverctx/ctx.go to add callbacks to the query context to receive the connection id and query id.
Expand Down
2 changes: 1 addition & 1 deletion internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ func LoggerAndContext(ctx context.Context, c any) (*logger.DBSQLLogger, context.
queryId = guidFromHasOpHandle(c)
ctx = driverctx.NewContextWithQueryId(ctx, queryId)
}
log := logger.WithContext(connId, corrId, queryId)
log := logger.AddContext(logger.FromContext(ctx), connId, corrId, queryId)

return log, ctx
}
Expand Down
6 changes: 3 additions & 3 deletions internal/rows/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ func NewRows(
client cli_service.TCLIService,
config *config.Config,
directResults *cli_service.TSparkDirectResults,
logger *dbsqllog.DBSQLLogger,
) (driver.Rows, dbsqlerr.DBError) {

var logger *dbsqllog.DBSQLLogger
var ctx context.Context
if opHandle != nil {
logger = dbsqllog.WithContext(connId, correlationId, dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
logger = dbsqllog.AddContext(logger, connId, correlationId, dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
ctx = driverctx.NewContextWithQueryId(driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId), dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
} else {
logger = dbsqllog.WithContext(connId, correlationId, "")
logger = dbsqllog.AddContext(logger, connId, correlationId, "")
ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId)
}

Expand Down
16 changes: 8 additions & 8 deletions internal/rows/rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ func TestColumnsWithDirectResults(t *testing.T) {

client := getRowsTestSimpleClient(&getMetadataCount, &fetchResultsCount)

d, err := NewRows("", "", nil, client, nil, nil)
d, err := NewRows("", "", nil, client, nil, nil, nil)
assert.Nil(t, err)

rowSet := d.(*rows)
Expand Down Expand Up @@ -708,7 +708,7 @@ func TestRowsCloseOptimization(t *testing.T) {
}

opHandle := &cli_service.TOperationHandle{OperationId: &cli_service.THandleIdentifier{GUID: []byte{'f', 'o'}}}
rowSet, _ := NewRows("", "", opHandle, client, nil, nil)
rowSet, _ := NewRows("", "", opHandle, client, nil, nil, nil)

// rowSet has no direct results calling Close should result in call to client to close operation
err := rowSet.Close()
Expand All @@ -721,7 +721,7 @@ func TestRowsCloseOptimization(t *testing.T) {
ResultSet: &cli_service.TFetchResultsResp{Results: &cli_service.TRowSet{Columns: []*cli_service.TColumn{}}},
}
closeCount = 0
rowSet, _ = NewRows("", "", opHandle, client, nil, directResults)
rowSet, _ = NewRows("", "", opHandle, client, nil, directResults, nil)
err = rowSet.Close()
assert.Nil(t, err, "rows.Close should not throw an error")
assert.Equal(t, 1, closeCount)
Expand All @@ -734,7 +734,7 @@ func TestRowsCloseOptimization(t *testing.T) {
ResultSetMetadata: &cli_service.TGetResultSetMetadataResp{Schema: &cli_service.TTableSchema{}},
ResultSet: &cli_service.TFetchResultsResp{Results: &cli_service.TRowSet{Columns: []*cli_service.TColumn{}}},
}
rowSet, _ = NewRows("", "", opHandle, client, nil, directResults)
rowSet, _ = NewRows("", "", opHandle, client, nil, directResults, nil)
err = rowSet.Close()
assert.Nil(t, err, "rows.Close should not throw an error")
assert.Equal(t, 0, closeCount)
Expand Down Expand Up @@ -799,7 +799,7 @@ func TestGetArrowBatches(t *testing.T) {

client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1, fetchResp2})
cfg := config.WithDefaults()
rows, err := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults)
rows, err := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults, nil)
assert.Nil(t, err)

rows2, ok := rows.(dbsqlrows.Rows)
Expand Down Expand Up @@ -869,7 +869,7 @@ func TestGetArrowBatches(t *testing.T) {

client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1, fetchResp2, fetchResp3})
cfg := config.WithDefaults()
rows, err := NewRows("connId", "corrId", nil, client, cfg, nil)
rows, err := NewRows("connId", "corrId", nil, client, cfg, nil, nil)
assert.Nil(t, err)

rows2, ok := rows.(dbsqlrows.Rows)
Expand Down Expand Up @@ -927,7 +927,7 @@ func TestGetArrowBatches(t *testing.T) {

client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1})
cfg := config.WithDefaults()
rows, err := NewRows("connId", "corrId", nil, client, cfg, nil)
rows, err := NewRows("connId", "corrId", nil, client, cfg, nil, nil)
assert.Nil(t, err)

rows2, ok := rows.(dbsqlrows.Rows)
Expand All @@ -951,7 +951,7 @@ func TestGetArrowBatches(t *testing.T) {

client := getSimpleClient([]cli_service.TFetchResultsResp{})
cfg := config.WithDefaults()
rows, err := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults)
rows, err := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults, nil)
assert.Nil(t, err)

rows2, ok := rows.(dbsqlrows.Rows)
Expand Down
21 changes: 20 additions & 1 deletion logger/logger.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package logger

import (
"context"
"io"
"os"
"runtime"
Expand Down Expand Up @@ -123,9 +124,27 @@ func Err(err error) *zerolog.Event {
return Logger.Err(err)
}

// FromContext returns a DBSQLLogger from the provided context. If no logger is
// found, the default logger is returned.
func FromContext(ctx context.Context) *DBSQLLogger {
l := zerolog.Ctx(ctx)
if l == zerolog.DefaultContextLogger {
return Logger
}
return &DBSQLLogger{*l}
}

// AddContext sets connectionId, correlationId, and queryId as fields on the provided logger.
func AddContext(l *DBSQLLogger, connectionId string, correlationId string, queryId string) *DBSQLLogger {
if l == nil {
l = Logger
}
return &DBSQLLogger{l.With().Str("connId", connectionId).Str("corrId", correlationId).Str("queryId", queryId).Logger()}
}

// WithContext sets connectionId, correlationId, and queryId to be used as fields.
func WithContext(connectionId string, correlationId string, queryId string) *DBSQLLogger {
return &DBSQLLogger{Logger.With().Str("connId", connectionId).Str("corrId", correlationId).Str("queryId", queryId).Logger()}
return AddContext(nil, connectionId, correlationId, queryId)
}

// Track is a convenience function to track time spent
Expand Down
Loading