diff --git a/connection.go b/connection.go index d495939..95727ee 100644 --- a/connection.go +++ b/connection.go @@ -62,7 +62,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e // Ping attempts to verify that the server is accessible. // Returns ErrBadConn if ping fails and consequently DB.Ping will remove the conn from the pool. func (c *conn) Ping(ctx context.Context) error { - log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "") + log := logger.AddContext(logger.Ctx(ctx), c.id, driverctx.CorrelationIdFromContext(ctx), "") ctx = driverctx.NewContextWithConnId(ctx, c.id) ctx1, cancel := context.WithTimeout(ctx, c.cfg.PingTimeout) defer cancel() @@ -92,7 +92,7 @@ func (c *conn) IsValid() bool { // Statement ExecContext is the same as connection ExecContext func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { corrId := driverctx.CorrelationIdFromContext(ctx) - log := logger.WithContext(c.id, corrId, "") + log := logger.AddContext(logger.Ctx(ctx), c.id, corrId, "") msg, start := logger.Track("ExecContext") defer log.Duration(msg, start) @@ -104,7 +104,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name if exStmtResp != nil && exStmtResp.OperationHandle != nil { // we have an operation id so update the logger - log = logger.WithContext(c.id, corrId, client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)) + log = logger.AddContext(logger.Ctx(ctx), c.id, corrId, client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)) // since we have an operation handle we can close the operation if necessary alreadyClosed := exStmtResp.DirectResults != nil && exStmtResp.DirectResults.CloseOperation != nil @@ -135,7 +135,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name // Statement QueryContext is the same as connection QueryContext func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { corrId := driverctx.CorrelationIdFromContext(ctx) - log := logger.WithContext(c.id, corrId, "") + log := logger.AddContext(logger.Ctx(ctx), c.id, corrId, "") msg, start := log.Track("QueryContext") ctx = driverctx.NewContextWithConnId(ctx, c.id) @@ -147,7 +147,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam exStmtResp, _, err := c.runQuery(ctx, query, args) if exStmtResp != nil && exStmtResp.OperationHandle != nil { - log = logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)) + log = logger.AddContext(logger.Ctx(ctx), c.id, driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)) } defer log.Duration(msg, start) @@ -165,7 +165,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam } func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, *cli_service.TGetOperationStatusResp, error) { - log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "") + log := logger.AddContext(logger.Ctx(ctx), c.id, driverctx.CorrelationIdFromContext(ctx), "") // first we try to get the results synchronously. // at any point in time that the context is done we must cancel and return exStmtResp, err := c.executeStatement(ctx, query, args) @@ -175,7 +175,7 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa } opHandle := exStmtResp.OperationHandle if opHandle != nil && opHandle.OperationId != nil { - log = logger.WithContext( + log = logger.AddContext(logger.Ctx(ctx), c.id, driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(opHandle.OperationId.GUID), ) @@ -259,7 +259,7 @@ func logBadQueryState(log *logger.DBSQLLogger, opStatus *cli_service.TGetOperati func (c *conn) executeStatement(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, error) { corrId := driverctx.CorrelationIdFromContext(ctx) - log := logger.WithContext(c.id, corrId, "") + log := logger.AddContext(logger.Ctx(ctx), c.id, corrId, "") req := cli_service.TExecuteStatementReq{ SessionHandle: c.session.SessionHandle, @@ -311,7 +311,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) { corrId := driverctx.CorrelationIdFromContext(ctx) - log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID)) + log := logger.AddContext(logger.Ctx(ctx), c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID)) var statusResp *cli_service.TGetOperationStatusResp ctx = driverctx.NewContextWithConnId(ctx, c.id) newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId) diff --git a/logger/logger.go b/logger/logger.go index e048eaa..c2a7711 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -1,6 +1,7 @@ package logger import ( + "context" "io" "os" "runtime" @@ -128,9 +129,24 @@ func Err(err error) *zerolog.Event { return Logger.Err(err) } +// Ctx returns a DBSQLLogger from the provided context. If no logger is found, +// the default logger is returned. +func Ctx(ctx context.Context) *DBSQLLogger { + l := zerolog.Ctx(ctx) + if l == zerolog.DefaultContextLogger { + return Logger + } + return &DBSQLLogger{*l} +} + +// AddContext sets connectionId, correlationId, and queryId as fields on the provided logger. +func AddContext(l *DBSQLLogger, connectionId string, correlationId string, queryId string) *DBSQLLogger { + return &DBSQLLogger{l.With().Str("connId", connectionId).Str("corrId", correlationId).Str("queryId", queryId).Logger()} +} + // WithContext sets connectionId, correlationId, and queryId to be used as fields. func WithContext(connectionId string, correlationId string, queryId string) *DBSQLLogger { - return &DBSQLLogger{Logger.With().Str("connId", connectionId).Str("corrId", correlationId).Str("queryId", queryId).Logger()} + return AddContext(Logger, connectionId, correlationId, queryId) } // Track is a convenience function to track time spent