diff --git a/driver.go b/driver.go index 6f59a70..f5f91ca 100644 --- a/driver.go +++ b/driver.go @@ -208,7 +208,7 @@ func (d otDriver) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - return makeConn(c, d.connConfig), nil + return wrapConn(c, d.connConfig), nil } func (d otDriver) Driver() driver.Driver { diff --git a/driver_test.go b/driver_test.go index ee86bc0..1dec8e5 100644 --- a/driver_test.go +++ b/driver_test.go @@ -162,6 +162,40 @@ func TestWrap_DriverContext_ConnectError(t *testing.T) { assert.Equal(t, expectedError, err) } +func TestWrap_DriverContext_ConnectNamedValueChecker(t *testing.T) { + t.Parallel() + + _, m, err := sqlmock.New() + require.NoError(t, err) + + parent := struct { + driver.Driver + driver.DriverContext + }{ + DriverContext: driverOpenConnectorFunc(func(name string) (driver.Connector, error) { + return struct { + driverDriverFunc + driverConnectFunc + driverNamedValueCheckerFunc + }{ + driverConnectFunc: func(ctx context.Context) (driver.Conn, error) { + return m.(driver.Conn), nil + }, + }, nil + }), + } + + drv := otelsql.Wrap(parent).(driver.DriverContext) // nolint: errcheck + + connector, err := drv.OpenConnector("") + require.NoError(t, err) + + conn, err := connector.Connect(context.Background()) + require.NoError(t, err) + + assert.Implements(t, (*driver.NamedValueChecker)(nil), conn) +} + func TestWrap_DriverContext_CloseBeforeOpenConnector(t *testing.T) { t.Parallel() @@ -2758,6 +2792,12 @@ func (f driverDriverFunc) Driver() driver.Driver { return f() } +type driverNamedValueCheckerFunc func(*driver.NamedValue) error + +func (f driverNamedValueCheckerFunc) CheckNamedValue(nv *driver.NamedValue) error { + return f(nv) +} + type testError string func (e testError) Error() string {