diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index d6b7c8819a71..0f19bea9122b 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -34,6 +34,7 @@ go_library( "//pkg/security/certmgr", "//pkg/sql/pgwire/pgcode", "//pkg/sql/pgwire/pgwirebase", + "//pkg/util/cache", "//pkg/util/grpcutil", "//pkg/util/httputil", "//pkg/util/log", diff --git a/pkg/ccl/sqlproxyccl/acl/cidr_ranges.go b/pkg/ccl/sqlproxyccl/acl/cidr_ranges.go index 8dda78501ab8..ddeb400563f4 100644 --- a/pkg/ccl/sqlproxyccl/acl/cidr_ranges.go +++ b/pkg/ccl/sqlproxyccl/acl/cidr_ranges.go @@ -50,11 +50,6 @@ func (p *CIDRRanges) CheckConnection(ctx context.Context, conn ConnectionTags) e } } - // By default, connections are rejected if no ranges match the connection's - // IP. - return errors.Newf( - "connection to '%s' denied: cluster does not allow public connections from IP %s", - conn.TenantID.String(), - conn.IP, - ) + // By default, connections are rejected if no ranges match the connection's IP. 
+ return errors.Newf("cluster does not allow public connections from IP %s", conn.IP) } diff --git a/pkg/ccl/sqlproxyccl/acl/cidr_ranges_test.go b/pkg/ccl/sqlproxyccl/acl/cidr_ranges_test.go index be620dd1c382..b50a3265d088 100644 --- a/pkg/ccl/sqlproxyccl/acl/cidr_ranges_test.go +++ b/pkg/ccl/sqlproxyccl/acl/cidr_ranges_test.go @@ -63,7 +63,7 @@ func TestCIDRRanges(t *testing.T) { }, } err := p.CheckConnection(ctx, makeConn("")) - require.EqualError(t, err, "connection to '42' denied: cluster does not allow public connections from IP 127.0.0.1") + require.EqualError(t, err, "cluster does not allow public connections from IP 127.0.0.1") }) t.Run("default behavior if no entries", func(t *testing.T) { @@ -75,7 +75,7 @@ func TestCIDRRanges(t *testing.T) { }, } err := p.CheckConnection(ctx, makeConn("")) - require.EqualError(t, err, "connection to '42' denied: cluster does not allow public connections from IP 127.0.0.1") + require.EqualError(t, err, "cluster does not allow public connections from IP 127.0.0.1") }) t.Run("good public connection", func(t *testing.T) { diff --git a/pkg/ccl/sqlproxyccl/acl/private_endpoints.go b/pkg/ccl/sqlproxyccl/acl/private_endpoints.go index d79d986d357a..724518871e6b 100644 --- a/pkg/ccl/sqlproxyccl/acl/private_endpoints.go +++ b/pkg/ccl/sqlproxyccl/acl/private_endpoints.go @@ -50,11 +50,7 @@ func (p *PrivateEndpoints) CheckConnection(ctx context.Context, conn ConnectionT // By default, connections are rejected if no endpoints match the // connection's endpoint ID. 
- return errors.Newf( - "connection to '%s' denied: cluster does not allow private connections from endpoint '%s'", - conn.TenantID.String(), - conn.EndpointID, - ) + return errors.Newf("cluster does not allow private connections from endpoint '%s'", conn.EndpointID) } // FindPrivateEndpointID looks for the endpoint identifier within the connection diff --git a/pkg/ccl/sqlproxyccl/acl/private_endpoints_test.go b/pkg/ccl/sqlproxyccl/acl/private_endpoints_test.go index 0b252ccb4bc3..e410863871bd 100644 --- a/pkg/ccl/sqlproxyccl/acl/private_endpoints_test.go +++ b/pkg/ccl/sqlproxyccl/acl/private_endpoints_test.go @@ -66,7 +66,7 @@ func TestPrivateEndpoints(t *testing.T) { }, } err := p.CheckConnection(ctx, makeConn("bar")) - require.EqualError(t, err, "connection to '42' denied: cluster does not allow private connections from endpoint 'bar'") + require.EqualError(t, err, "cluster does not allow private connections from endpoint 'bar'") }) t.Run("default behavior if no entries", func(t *testing.T) { @@ -78,7 +78,7 @@ func TestPrivateEndpoints(t *testing.T) { }, } err := p.CheckConnection(ctx, makeConn("bar")) - require.EqualError(t, err, "connection to '42' denied: cluster does not allow private connections from endpoint 'bar'") + require.EqualError(t, err, "cluster does not allow private connections from endpoint 'bar'") }) t.Run("good private connection", func(t *testing.T) { diff --git a/pkg/ccl/sqlproxyccl/acl/watcher_test.go b/pkg/ccl/sqlproxyccl/acl/watcher_test.go index b491ca19bec5..634a062de1c9 100644 --- a/pkg/ccl/sqlproxyccl/acl/watcher_test.go +++ b/pkg/ccl/sqlproxyccl/acl/watcher_test.go @@ -156,7 +156,7 @@ func TestACLWatcher(t *testing.T) { } remove, err := watcher.ListenForDenied(ctx, connection, noError(t)) - require.EqualError(t, err, "connection to '10' denied: cluster does not allow private connections from endpoint 'random-connection'") + require.EqualError(t, err, "cluster does not allow private connections from endpoint 'random-connection'") 
require.Nil(t, remove) }) @@ -171,7 +171,7 @@ func TestACLWatcher(t *testing.T) { } remove, err := watcher.ListenForDenied(ctx, connection, noError(t)) - require.EqualError(t, err, "connection to '10' denied: cluster does not allow public connections from IP 127.0.0.1") + require.EqualError(t, err, "cluster does not allow public connections from IP 127.0.0.1") require.Nil(t, remove) }) @@ -283,7 +283,7 @@ func TestACLWatcher(t *testing.T) { // Emit the same item. c <- pe - require.EqualError(t, <-errorChan, "connection to '10' denied: cluster does not allow private connections from endpoint 'cockroachdb'") + require.EqualError(t, <-errorChan, "cluster does not allow private connections from endpoint 'cockroachdb'") }) t.Run("connection is denied by update to public cidr ranges", func(t *testing.T) { @@ -334,7 +334,7 @@ func TestACLWatcher(t *testing.T) { // Emit the same item. c <- pcr - require.EqualError(t, <-errorChan, "connection to '10' denied: cluster does not allow public connections from IP 1.1.2.2") + require.EqualError(t, <-errorChan, "cluster does not allow public connections from IP 1.1.2.2") }) t.Run("unregister removes listeners", func(t *testing.T) { diff --git a/pkg/ccl/sqlproxyccl/authentication.go b/pkg/ccl/sqlproxyccl/authentication.go index 839da7bca925..1307fe0725a6 100644 --- a/pkg/ccl/sqlproxyccl/authentication.go +++ b/pkg/ccl/sqlproxyccl/authentication.go @@ -104,7 +104,8 @@ var authenticate = func( case *pgproto3.AuthenticationOk: throttleError := throttleHook(throttler.AttemptOK) if throttleError != nil { - if err = feSend(toPgError(throttleError)); err != nil { + // Send a user-facing error. + if err = feSend(toPgError(authThrottledError)); err != nil { return nil, err } return nil, throttleError @@ -125,7 +126,8 @@ var authenticate = func( throttleError = throttleHook(throttler.AttemptInvalidCredentials) } if throttleError != nil { - if err = feSend(toPgError(throttleError)); err != nil { + // Send a user-facing error. 
+ if err = feSend(toPgError(authThrottledError)); err != nil { return nil, err } return nil, throttleError diff --git a/pkg/ccl/sqlproxyccl/authentication_test.go b/pkg/ccl/sqlproxyccl/authentication_test.go index e1314b2e7c7f..1676b272cc45 100644 --- a/pkg/ccl/sqlproxyccl/authentication_test.go +++ b/pkg/ccl/sqlproxyccl/authentication_test.go @@ -13,6 +13,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/ccl/testutilsccl" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/errors" "github.com/jackc/pgproto3/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -140,13 +141,19 @@ func TestAuthenticateThrottled(t *testing.T) { go server(t, sqlServer) go client(t, sqlClient) - _, err := authenticate(proxyToClient, proxyToServer, nil, /* proxyBackendKeyData */ + // The error returned from authenticate should be different from the error + // received at the client. + _, err := authenticate( + proxyToClient, + proxyToServer, + nil, /* proxyBackendKeyData */ func(status throttler.AttemptStatus) error { require.Equal(t, throttler.AttemptInvalidCredentials, status) - return authThrottledError - }) + return errors.New("request denied") + }, + ) require.Error(t, err) - require.Contains(t, err.Error(), "too many failed authentication attempts") + require.Contains(t, err.Error(), "request denied") proxyToServer.Close() proxyToClient.Close() diff --git a/pkg/ccl/sqlproxyccl/error.go b/pkg/ccl/sqlproxyccl/error.go index 7e451e7bcbce..a3153974cd8a 100644 --- a/pkg/ccl/sqlproxyccl/error.go +++ b/pkg/ccl/sqlproxyccl/error.go @@ -43,8 +43,8 @@ const ( // received from the client after TLS negotiation. codeUnexpectedStartupMessage - // codeParamsRoutingFailed indicates an error choosing a backend address based - // on the client's session parameters. + // codeParamsRoutingFailed indicates an error choosing a backend address + // based on the client's session parameters. 
codeParamsRoutingFailed // codeBackendDialFailed indicates an error establishing a connection @@ -59,8 +59,10 @@ const ( // (with a connection error) while in a session with backend SQL server. codeClientDisconnected - // codeProxyRefusedConnection indicates that the proxy refused the connection - // request due to high load or too many connection attempts. + // codeProxyRefusedConnection indicates that the proxy refused the + // connection request due to too many invalid connection attempts, or + // because the incoming connection session does not match the ACL rules + // for the cluster. codeProxyRefusedConnection // codeExpiredClientConnection indicates that proxy connection to the client @@ -68,15 +70,18 @@ const ( codeExpiredClientConnection // codeUnavailable indicates that the backend SQL server exists but is not - // accepting connections. For example, a tenant cluster that has maxPods set to 0. + // accepting connections. For example, a tenant cluster that has maxPods + // set to 0. codeUnavailable ) // errWithCode combines an error with one of the above codes to ease // the processing of the errors. +// // This follows the same pattern used by cockroachdb/errors that allows // decorating errors with additional information. Check WithStack, WithHint, // WithDetail etc. +// // By using the pattern, the decorations are chained and allow searching and // extracting later on. See getErrorCode bellow. type errWithCode struct { diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index f5a2a9171515..e40e23af1d66 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -53,6 +53,11 @@ var ( // smaller than 100 characters. We don't perform an exact match here for // more flexibility. 
clusterNameRegex = regexp.MustCompile("^[a-z0-9][a-z0-9-]{4,98}[a-z0-9]$") + + // highFreqErrorMarker is a marker that indicates that a particular error + // is of high-frequency and should be throttled accordingly. Use with + // errors.Mark and errors.Is. + highFreqErrorMarker = errors.New("high-frequency error") ) const ( @@ -380,6 +385,19 @@ func (handler *proxyHandler) handle( return errors.Wrap(err, "extracting cluster identifier") } + // Add optional request tags so that callers can provide a better context + // for errors. This is best-effort only, and will be applied if the incoming + // context includes a requestTags object. + reqTags := requestTagsFromContext(ctx) + + // Tenant ID is always valid in the request. We will defer the addition of + // cluster name into the logtags after validating the connection since there + // is a possibility where the cluster name won't match. + ctx = logtags.AddTag(ctx, "tenant", tenID.String()) + if reqTags != nil { + reqTags["tenant"] = tenID.String() + } + // Validate the incoming connection and ensure that the cluster name // matches the tenant's. This avoids malicious actors from attempting to // connect to the cluster using just the tenant ID. @@ -389,16 +407,13 @@ func (handler *proxyHandler) handle( return err } - // Only add logtags after validating the connection. If the connection isn't + // Add cluster name after validating the connection. If the connection isn't // validated, clusterName may not match the tenant ID, and this could cause // confusion when analyzing logs. ctx = logtags.AddTag(ctx, "cluster", clusterName) - ctx = logtags.AddTag(ctx, "tenant", tenID) - - // Add request tags so that callers can provide a better context for errors. 
- reqTags := requestTagsFromContext(ctx) - reqTags["cluster"] = clusterName - reqTags["tenant"] = tenID + if reqTags != nil { + reqTags["cluster"] = clusterName + } // Use an empty string as the default port as we only care about the // correctly parsing the IP address here. @@ -418,8 +433,15 @@ func (handler *proxyHandler) handle( EndpointID: endpointID, }, func(err error) { + // Whenever a connection expires, it will be closed by the proxy + // once the ACL watcher fires. The client will observe that the + // connection has been closed abruptly. One UX enhancement that + // could be made here is to send an ErrorResponse on the next pgwire + // message boundary to the client so that users know why their + // connections were closed. Implementing this will require + // communicating with the forwarder (e.g. "sendErrorSafelyAndClose"). err = withCode( - errors.Wrap(err, "connection blocked by access control list"), + errors.Wrap(err, "connection no longer allowed by access control list"), codeExpiredClientConnection, ) select { @@ -434,11 +456,20 @@ func (handler *proxyHandler) handle( // "connection refused" error. The next time they connect, they will // get a "not found" error. // - // TODO(jaylim-crl): We can enrich this error with the proper reason on - // why they were refused (e.g. IP allowlist, or private endpoints). + // We will return a generic "connection refused" error in this case. + // From a security perspective, it's preferable not to disclose specific + // details -- similar to how we use "invalid username or password" + // rather than just "invalid password". This does come with its own + // drawbacks (e.g. users may not know why), but that's fine, since they + // should have all the necessary information needed to debug this issue. 
clientErr := withCode(errors.New("connection refused"), codeProxyRefusedConnection) updateMetricsAndSendErrToClient(clientErr, fe.Conn, handler.metrics) - return errors.Wrap(err, "connection blocked by access control list") + + // Users often misconfigure their allowed CIDR ranges or private + // endpoints, leading to excessive retry traffic. We will mark it as + // a high-frequency error so that callers that log errors could + // implement rate limiting accordingly. + return errors.Mark(errors.Wrap(err, "connection blocked by access control list"), highFreqErrorMarker) } defer removeListener() @@ -447,13 +478,12 @@ func (handler *proxyHandler) handle( if err != nil { clientErr := authThrottledError updateMetricsAndSendErrToClient(clientErr, fe.Conn, handler.metrics) + // The throttle service is used to rate limit invalid login attempts // from IP addresses, and it is commonly prone to generating excessive - // traffic in practice. Due to that, we'll return a nil here to prevent - // callers from logging this request. However, LoginCheck itself - // periodically logs an error when such requests are rate limited, so - // we won't miss any signals by doing this. - return nil //nolint:returnerrcheck + // traffic in practice. Due to that, we'll mark it as a high-frequency + // error to rate-limit logs. + return errors.Mark(err, highFreqErrorMarker) } connector := &connector{ @@ -482,13 +512,9 @@ func (handler *proxyHandler) handle( crdbConn, sentToClient, err := connector.OpenTenantConnWithAuth(ctx, f, fe.Conn, func(status throttler.AttemptStatus) error { - if err := handler.throttleService.ReportAttempt( - ctx, throttleTags, throttleTime, status, - ); err != nil { - // We have to log here because errors returned by this closure - // will be sent to the client. 
- log.Errorf(ctx, "throttler refused connection after authentication: %v", err.Error()) - return authThrottledError + err := handler.throttleService.ReportAttempt(ctx, throttleTags, throttleTime, status) + if err != nil { + return errors.Mark(err, highFreqErrorMarker) } return nil }, @@ -579,6 +605,7 @@ func (handler *proxyHandler) validateConnection( if err != nil && status.Code(err) != codes.NotFound { return err } + if err == nil { if tenant.ClusterName == "" || tenant.ClusterName == clusterName { return nil @@ -590,10 +617,13 @@ func (handler *proxyHandler) validateConnection( tenant.ClusterName, ) } - return withCode( - errors.Newf("cluster %s-%d not found", clusterName, tenantID.ToUint64()), - codeParamsRoutingFailed, - ) + + // The codes.NotFound error is a common occurrence, as we have often + // observed situations where a user deletes their tenant, but their + // application continues running. + clientErr := errors.Newf("cluster %s-%d not found", clusterName, tenantID.ToUint64()) + clientErr = errors.Mark(clientErr, highFreqErrorMarker) + return withCode(clientErr, codeParamsRoutingFailed) } // handleCancelRequest handles a pgwire query cancel request by either diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go index 18dfc1f8248e..736fd0602286 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -100,10 +100,12 @@ func TestProxyHandler_ValidateConnection(t *testing.T) { t.Run("not found/no cluster name", func(t *testing.T) { err := s.handler.validateConnection(ctx, invalidTenantID, "") require.Regexp(t, "codeParamsRoutingFailed: cluster -99 not found", err.Error()) + require.True(t, errors.Is(err, highFreqErrorMarker)) }) t.Run("not found", func(t *testing.T) { err := s.handler.validateConnection(ctx, invalidTenantID, "foo-bar") require.Regexp(t, "codeParamsRoutingFailed: cluster foo-bar-99 not found", err.Error()) + require.True(t, errors.Is(err, 
highFreqErrorMarker)) }) t.Run("found/tenant without name", func(t *testing.T) { err := s.handler.validateConnection(ctx, tenantWithoutNameID, "foo-bar") @@ -116,10 +118,12 @@ func TestProxyHandler_ValidateConnection(t *testing.T) { t.Run("found/connection without name", func(t *testing.T) { err := s.handler.validateConnection(ctx, tenantID, "") require.Regexp(t, "codeParamsRoutingFailed: cluster -10 not found", err.Error()) + require.True(t, errors.Is(err, highFreqErrorMarker)) }) t.Run("found/tenant name mismatch", func(t *testing.T) { err := s.handler.validateConnection(ctx, tenantID, "foo-bar") require.Regexp(t, "codeParamsRoutingFailed: cluster foo-bar-10 not found", err.Error()) + require.True(t, errors.Is(err, highFreqErrorMarker)) }) // Stop the directory server. @@ -130,6 +134,7 @@ func TestProxyHandler_ValidateConnection(t *testing.T) { // Use a new tenant ID here to force GetTenant. err := s.handler.validateConnection(ctx, roachpb.MustMakeTenantID(100), "") require.Regexp(t, "directory server has not been started", err.Error()) + require.False(t, errors.Is(err, highFreqErrorMarker)) }) } diff --git a/pkg/ccl/sqlproxyccl/server.go b/pkg/ccl/sqlproxyccl/server.go index bdf0b863f5e7..06e4ac7dd300 100644 --- a/pkg/ccl/sqlproxyccl/server.go +++ b/pkg/ccl/sqlproxyccl/server.go @@ -13,11 +13,14 @@ import ( "net/http/pprof" "time" + "github.com/cockroachdb/cockroach/pkg/util/cache" "github.com/cockroachdb/cockroach/pkg/util/grpcutil" "github.com/cockroachdb/cockroach/pkg/util/httputil" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/metric" + "github.com/cockroachdb/cockroach/pkg/util/netutil/addr" "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" "github.com/cockroachdb/logtags" @@ -25,9 +28,21 @@ import ( "github.com/prometheus/common/expfmt" ) -var ( - awaitNoConnectionsInterval = 
time.Minute -) +// Variable to be swapped in tests. +var awaitNoConnectionsInterval = time.Minute + +// maxErrorLogLimiterCacheSize defines the maximum cache size for the error log +// limiter. We set it to 1M to align with the cache size of the connection +// throttler (see pkg/ccl/sqlproxyccl/throttler). Based on testing, this +// corresponds to roughly 200-300MB of memory usage when the cache is fully +// utilized. However, in practice, since the limiter is only applied to +// high-frequency errors, the cache will typically store only a small number of +// entries (e.g., around 2-3MB for 10K entries). +const maxErrorLogLimiterCacheSize = 1e6 + +// errorLogLimiterDuration indicates how frequent an error should be logged +// for a given (ip, tenant ID) pair. +const errorLogLimiterDuration = 5 * time.Minute // Server is a TCP server that proxies SQL connections to a configurable // backend. It may also run an HTTP server to expose a health check and @@ -40,6 +55,23 @@ type Server struct { metricsRegistry *metric.Registry prometheusExporter metric.PrometheusExporter + + // errorLogLimiter is used to rate-limit the frequency at which *common* + // errors are logged. It ensures that repeated error messages for the same + // connection tag are logged at most once every 5 minutes (see + // errorLogLimiterDuration). + mu struct { + syncutil.Mutex + errorLogLimiter *cache.UnorderedCache // map[connTag]*log.EveryN + } +} + +// connTag represents the key for errorLogLimiter. +type connTag struct { + // ip is the IP address of the client. + ip string + // tenantID is the ID of the tenant database the client is connecting to. + tenantID string } // NewServer constructs a new proxy server and provisions metrics and health @@ -66,6 +98,15 @@ func NewServer(ctx context.Context, stopper *stop.Stopper, options ProxyOptions) prometheusExporter: metric.MakePrometheusExporter(), } + // Configure the error log limiter cache. 
+ cacheConfig := cache.Config{ + Policy: cache.CacheLRU, + ShouldEvict: func(size int, key, value interface{}) bool { + return size > maxErrorLogLimiterCacheSize + }, + } + s.mu.errorLogLimiter = cache.NewUnorderedCache(cacheConfig) + // /_status/{healthz,vars} matches CRDB's healthcheck and metrics // endpoints. mux.HandleFunc("/_status/vars/", s.handleVars) @@ -287,9 +328,12 @@ func (s *Server) serve(ctx context.Context, ln net.Listener, requireProxyProtoco err := s.handler.handle(ctx, conn, requireProxyProtocol) if err != nil && !errors.Is(err, context.Canceled) { + // Ensure that context is tagged with request tags which are + // populated at higher layers of the stack. for key, value := range reqTags { ctx = logtags.AddTag(ctx, key, value) } + // log.Infof automatically prints hints (one per line) that are // associated with the input error object. This causes // unnecessary log spam, especially when proxy hints are meant @@ -298,8 +342,10 @@ func (s *Server) serve(ctx context.Context, ln net.Listener, requireProxyProtoco // // TODO(jaylim-crl): Ensure that handle does not return user // facing errors (i.e. one that contains hints). - errWithoutHints := errors.Newf("%s", err.Error()) // nolint:errwrap - log.Infof(ctx, "connection closed: %v", errWithoutHints) + if s.shouldLogError(ctx, err, conn, reqTags) { + errWithoutHints := errors.Newf("%s", err.Error()) // nolint:errwrap + log.Infof(ctx, "connection closed: %v", errWithoutHints) + } } }) if err != nil { @@ -349,6 +395,64 @@ func (s *Server) AwaitNoConnections(ctx context.Context) <-chan struct{} { return c } +// shouldLogError returns true if an error should be logged, or false otherwise. +// The goal is to only throttle high-frequency errors for every (IP, tenantID) +// pair. +func (s *Server) shouldLogError( + ctx context.Context, err error, conn net.Conn, reqTags map[string]interface{}, +) bool { + // Always log non high-frequency errors. 
+ if !errors.Is(err, highFreqErrorMarker) { + return true + } + + // Request hasn't been populated with a tenant ID yet, so log them. + tenantID, hasTenantID := reqTags["tenant"] + if !hasTenantID { + return true + } + + // Tenant ID must be a string, or else there has to be a bug in the code. + // Instead of panicking, we'll skip throttling. + tenantIDStr, ok := tenantID.(string) + if !ok { + log.Errorf( + ctx, + "unexpected error: cannot extract tenant ID from request tags; found: %v", + tenantID, + ) + return true + } + + // Extract just the IP from the remote address. We'll log anyway if we get + // an error. This case cannot happen in practice since conn.RemoteAddr() + // will always return a valid IP. + ipAddr, _, err := addr.SplitHostPort(conn.RemoteAddr().String(), "") + if err != nil { + log.Errorf( + ctx, + "unexpected error: cannot extract remote IP from connection; found: %v", + conn.RemoteAddr().String(), + ) + return true + } + + s.mu.Lock() + defer s.mu.Unlock() + + var limiter *log.EveryN + key := connTag{ip: ipAddr, tenantID: tenantIDStr} + c, ok := s.mu.errorLogLimiter.Get(key) + if ok && c != nil { + limiter = c.(*log.EveryN) + } else { + e := log.Every(errorLogLimiterDuration) + limiter = &e + s.mu.errorLogLimiter.Add(key, limiter) + } + return limiter.ShouldLog() +} + // requestTagsContextKey is the type of a context.Value key used to carry the // request tags map in a context.Context object. 
type requestTagsContextKey struct{} diff --git a/pkg/ccl/sqlproxyccl/server_test.go b/pkg/ccl/sqlproxyccl/server_test.go index abc344677a08..05bbd9ad2072 100644 --- a/pkg/ccl/sqlproxyccl/server_test.go +++ b/pkg/ccl/sqlproxyccl/server_test.go @@ -8,6 +8,7 @@ package sqlproxyccl import ( "context" "io" + "net" "net/http" "net/http/httptest" "testing" @@ -17,6 +18,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/require" ) @@ -97,3 +99,58 @@ func TestAwaitNoConnections(t *testing.T) { // Make sure we waited for the connection to be dropped. require.GreaterOrEqual(t, timeutil.Since(begin), waitTime) } + +func TestShouldLogError(t *testing.T) { + defer leaktest.AfterTest(t)() + testutilsccl.ServerlessOnly(t) + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + + s, err := NewServer(ctx, stopper, ProxyOptions{}) + require.NoError(t, err) + + t.Run("no error", func(t *testing.T) { + require.True(t, s.shouldLogError(ctx, nil, nil, nil)) + }) + + t.Run("normal error", func(t *testing.T) { + require.True(t, s.shouldLogError(ctx, errors.New("foo"), nil, nil)) + }) + + highFreqErr := errors.Mark(errors.New("test error"), highFreqErrorMarker) + reqTags := make(map[string]interface{}) + + t.Run("no tenant ID", func(t *testing.T) { + require.True(t, s.shouldLogError(ctx, highFreqErr, nil, reqTags)) + }) + + reqTags["tenant"] = 42 + + t.Run("tenant ID not string", func(t *testing.T) { + require.True(t, s.shouldLogError(ctx, highFreqErr, nil, reqTags)) + }) + + reqTags["tenant"] = "42" + + t.Run("invalid connection", func(t *testing.T) { + require.True(t, s.shouldLogError(ctx, highFreqErr, &fakeTCPConn{}, reqTags)) + }) + + conn := &fakeTCPConn{remoteAddr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}}} + + t.Run("successful - tenant 42", func(t *testing.T) { + 
require.True(t, s.shouldLogError(ctx, highFreqErr, conn, reqTags)) + require.False(t, s.shouldLogError(ctx, highFreqErr, conn, reqTags)) + require.False(t, s.shouldLogError(ctx, highFreqErr, conn, reqTags)) + }) + + // Using a different tenant should work independently. + reqTags["tenant"] = "100" + + t.Run("successful - tenant 100", func(t *testing.T) { + require.True(t, s.shouldLogError(ctx, highFreqErr, conn, reqTags)) + require.False(t, s.shouldLogError(ctx, highFreqErr, conn, reqTags)) + require.False(t, s.shouldLogError(ctx, highFreqErr, conn, reqTags)) + }) +} diff --git a/pkg/ccl/sqlproxyccl/throttler/local.go b/pkg/ccl/sqlproxyccl/throttler/local.go index b4fb0bb026cd..c0a8622c26f0 100644 --- a/pkg/ccl/sqlproxyccl/throttler/local.go +++ b/pkg/ccl/sqlproxyccl/throttler/local.go @@ -15,7 +15,10 @@ import ( "github.com/cockroachdb/errors" ) -var errRequestDenied = errors.New("request denied") +var ( + errRequestDeniedAtLogin = errors.New("throttler refused connection due to too many failed authentication attempts") + errRequestDeniedAfterAuth = errors.New("throttler refused connection after authentication") +) type timeNow func() time.Time @@ -98,11 +101,7 @@ func (s *localService) LoginCheck( now := s.clock() throttle := s.lockedGetThrottle(connection) if throttle != nil && throttle.isThrottled(now) { - if throttle.everyLog.ShouldLog() { - // ctx should include logtags about the connection. 
- log.Error(ctx, "throttler refused connection due to too many failed authentication attempts") - } - return now, errRequestDenied + return now, errRequestDeniedAtLogin } return now, nil } @@ -120,7 +119,7 @@ func (s *localService) ReportAttempt( } if throttle.isThrottled(throttleTime) { - return errRequestDenied + return errRequestDeniedAfterAuth } switch { diff --git a/pkg/ccl/sqlproxyccl/throttler/local_test.go b/pkg/ccl/sqlproxyccl/throttler/local_test.go index 4925111a2348..6e133c17f4ec 100644 --- a/pkg/ccl/sqlproxyccl/throttler/local_test.go +++ b/pkg/ccl/sqlproxyccl/throttler/local_test.go @@ -137,7 +137,9 @@ func TestRacingRequests(t *testing.T) { nextTime := l.nextTime for _, status := range []AttemptStatus{AttemptOK, AttemptInvalidCredentials} { - require.Error(t, throttle.ReportAttempt(ctx, connection, throttleTime, status)) + err := throttle.ReportAttempt(ctx, connection, throttleTime, status) + require.Error(t, err) + require.Regexp(t, "throttler refused connection", err.Error()) // Verify the throttled report has no affect on limiter state. l := throttle.lockedGetThrottle(connection) diff --git a/pkg/ccl/sqlproxyccl/throttler/throttle.go b/pkg/ccl/sqlproxyccl/throttler/throttle.go index 16766ed9ad40..bb9324fae302 100644 --- a/pkg/ccl/sqlproxyccl/throttler/throttle.go +++ b/pkg/ccl/sqlproxyccl/throttler/throttle.go @@ -5,19 +5,10 @@ package throttler -import ( - "time" +import "time" - "github.com/cockroachdb/cockroach/pkg/util/log" -) - -const ( - // throttleDisabled is a sentinel value used to disable the throttle. - throttleDisabled = time.Duration(0) - // throttleLogErrorDuration indicates how frequent the throttle error should - // be logged for a given throttle instance. - throttleLogErrorDuration = 5 * time.Minute -) +// throttleDisabled is a sentinel value used to disable the throttle. +const throttleDisabled = time.Duration(0) type throttle struct { // The next time an operation blocked by this throttle can proceed. 
@@ -25,15 +16,12 @@ type throttle struct { // The amount of backoff to introduce the next time the throttle // is triggered. Setting nextBackoff to zero disables the throttle. nextBackoff time.Duration - // everyLog controls how frequent the throttle error should be logged. - everyLog log.EveryN } func newThrottle(initialBackoff time.Duration) *throttle { return &throttle{ nextTime: time.Time{}, nextBackoff: initialBackoff, - everyLog: log.Every(throttleLogErrorDuration), } }