diff --git a/pkg/ccl/sqlproxyccl/connector.go b/pkg/ccl/sqlproxyccl/connector.go index 263195fb9c40..8329efecc8ae 100644 --- a/pkg/ccl/sqlproxyccl/connector.go +++ b/pkg/ccl/sqlproxyccl/connector.go @@ -251,18 +251,20 @@ func (c *connector) dialTenantCluster( // Report the failure to the directory cache so that it can // refresh any stale information that may have caused the // problem. - if err = reportFailureToDirectoryCache( + if reportErr := reportFailureToDirectoryCache( ctx, c.TenantID, serverAssignment.Addr(), c.DirectoryCache, - ); err != nil { + ); reportErr != nil { reportFailureErrs++ if reportFailureErr.ShouldLog() { log.Ops.Errorf(ctx, "report failure (%d errors skipped): %v", reportFailureErrs, - err, + reportErr, ) reportFailureErrs = 0 } + // nolint:errwrap + err = errors.Wrapf(err, "reporting failure: %s", reportErr.Error()) } continue } @@ -275,7 +277,14 @@ func (c *connector) dialTenantCluster( // a bounded number of times. In our case, since we retry infinitely, the // only possibility is when ctx's Done channel is closed (which implies that // ctx.Err() != nil). - // + if err == nil || ctx.Err() == nil { + // nolint:errwrap + return nil, errors.AssertionFailedf( + "unexpected retry loop exit, err=%v, ctxErr=%v", + err, + ctx.Err(), + ) + } // If the error is already marked, just return that. if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) { return nil, err diff --git a/pkg/ccl/sqlproxyccl/connector_test.go b/pkg/ccl/sqlproxyccl/connector_test.go index a3684cf719bf..8a28955b71a6 100644 --- a/pkg/ccl/sqlproxyccl/connector_test.go +++ b/pkg/ccl/sqlproxyccl/connector_test.go @@ -370,6 +370,90 @@ func TestConnector_dialTenantCluster(t *testing.T) { require.Nil(t, conn) }) + t.Run("context canceled after dial fails", func(t *testing.T) { + // This is a short test, and is expected to finish within ms. + ctx, cancel := context.WithTimeout(bgCtx, 2*time.Second) + defer cancel() + + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + + c := &connector{ + TenantID: roachpb.MustMakeTenantID(42), + DialTenantLatency: metric.NewHistogram( + metaDialTenantLatency, time.Millisecond, metric.NetworkLatencyBuckets, + ), + DialTenantRetries: metric.NewCounter(metaDialTenantRetries), + } + dc := &testTenantDirectoryCache{} + c.DirectoryCache = dc + b, err := balancer.NewBalancer( + ctx, + stopper, + balancer.NewMetrics(), + c.DirectoryCache, + balancer.NoRebalanceLoop(), + ) + require.NoError(t, err) + c.Balancer = b + + var dialSQLServerCount int + c.testingKnobs.lookupAddr = func(ctx context.Context) (string, error) { + return "127.0.0.10:42", nil + } + c.testingKnobs.dialSQLServer = func(serverAssignment *balancer.ServerAssignment) (net.Conn, error) { + require.Equal(t, serverAssignment.Addr(), "127.0.0.10:42") + dialSQLServerCount++ + + // Cancel context to trigger loop exit on next retry. + cancel() + return nil, markAsRetriableConnectorError(errors.New("bar")) + } + + var reportFailureFnCount int + + // Invoke dial tenant with a success to ReportFailure. + // --------------------------------------------------- + dc.reportFailureFn = func(fnCtx context.Context, tenantID roachpb.TenantID, addr string) error { + reportFailureFnCount++ + require.Equal(t, ctx, fnCtx) + require.Equal(t, c.TenantID, tenantID) + require.Equal(t, "127.0.0.10:42", addr) + return nil + } + conn, err := c.dialTenantCluster(ctx, nil /* requester */) + require.EqualError(t, err, "bar") + require.True(t, errors.Is(err, context.Canceled)) + require.Nil(t, conn) + + // Assert existing calls. + require.Equal(t, 1, dialSQLServerCount) + require.Equal(t, 1, reportFailureFnCount) + require.Equal(t, c.DialTenantLatency.TotalCount(), int64(1)) + require.Equal(t, c.DialTenantRetries.Count(), int64(0)) + + // Invoke dial tenant with a failure to ReportFailure. Final error + // should include the secondary failure. + // --------------------------------------------------------------- + dc.reportFailureFn = func(fnCtx context.Context, tenantID roachpb.TenantID, addr string) error { + reportFailureFnCount++ + require.Equal(t, ctx, fnCtx) + require.Equal(t, c.TenantID, tenantID) + require.Equal(t, "127.0.0.10:42", addr) + return errors.New("failure to report") + } + conn, err = c.dialTenantCluster(ctx, nil /* requester */) + require.EqualError(t, err, "reporting failure: failure to report: bar") + require.True(t, errors.Is(err, context.Canceled)) + require.Nil(t, conn) + + // Assert existing calls. + require.Equal(t, 2, dialSQLServerCount) + require.Equal(t, 2, reportFailureFnCount) + require.Equal(t, c.DialTenantLatency.TotalCount(), int64(2)) + require.Equal(t, c.DialTenantRetries.Count(), int64(0)) + }) + t.Run("non-transient error", func(t *testing.T) { // This is a short test, and is expected to finish within ms. ctx, cancel := context.WithTimeout(bgCtx, 2*time.Second)