diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index a00894355d73..7b9b93b7d672 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -367,6 +367,7 @@ ALL_TESTS = [ "//pkg/util/metric:metric_test", "//pkg/util/mon:mon_test", "//pkg/util/netutil/addr:addr_test", + "//pkg/util/netutil:netutil_test", "//pkg/util/optional:optional_test", "//pkg/util/pretty:pretty_test", "//pkg/util/protoutil:protoutil_test", diff --git a/pkg/sql/pgwire/conn.go b/pkg/sql/pgwire/conn.go index 7ea89d543978..c4600b690417 100644 --- a/pkg/sql/pgwire/conn.go +++ b/pkg/sql/pgwire/conn.go @@ -341,8 +341,8 @@ func (c *conn) serveImpl( } var terminateSeen bool - var authDone, ignoreUntilSync bool + var repeatedErrorCount int for { breakLoop, err := func() (bool, error) { typ, n, err := c.readBuf.ReadTypedMsg(&c.rd) @@ -482,11 +482,20 @@ func (c *conn) serveImpl( if err != nil { log.VEventf(ctx, 1, "pgwire: error processing message: %s", err) ignoreUntilSync = true - // If we can't read data because the connection was closed or the context - // was canceled (e.g. during authentication), then we should break. - if netutil.IsClosedConnection(err) || errors.Is(err, context.Canceled) { + repeatedErrorCount++ + const maxRepeatedErrorCount = 1 << 15 + // If we can't read data because of any one of the following conditions, + // then we should break: + // 1. the connection was closed. + // 2. the context was canceled (e.g. during authentication). + // 3. we reached an arbitrary threshold of repeated errors. + if netutil.IsClosedConnection(err) || + errors.Is(err, context.Canceled) || + repeatedErrorCount > maxRepeatedErrorCount { break } + } else { + repeatedErrorCount = 0 } if breakLoop { break diff --git a/pkg/util/netutil/BUILD.bazel b/pkg/util/netutil/BUILD.bazel index 450929af3977..5f7333e95997 100644 --- a/pkg/util/netutil/BUILD.bazel +++ b/pkg/util/netutil/BUILD.bazel @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "netutil", @@ -16,3 +16,15 @@ go_library( "@org_golang_x_net//http2", ], ) + +go_test( + name = "netutil_test", + srcs = ["net_test.go"], + embed = [":netutil"], + deps = [ + "//pkg/util/contextutil", + "@com_github_cockroachdb_cmux//:cmux", + "@com_github_stretchr_testify//assert", + "@org_golang_google_grpc//:go_default_library", + ], +) diff --git a/pkg/util/netutil/net.go b/pkg/util/netutil/net.go index ffdfbb09323d..f729dee7d110 100644 --- a/pkg/util/netutil/net.go +++ b/pkg/util/netutil/net.go @@ -159,10 +159,13 @@ func (s *Server) ServeWith( } } -// IsClosedConnection returns true if err is cmux.ErrListenerClosed, -// grpc.ErrServerStopped, io.EOF, or the net package's errClosed. +// IsClosedConnection returns true if err is a non-temporary net.Error or is +// cmux.ErrListenerClosed, grpc.ErrServerStopped, io.EOF, or net.ErrClosed. func IsClosedConnection(err error) bool { - return errors.IsAny(err, cmux.ErrListenerClosed, grpc.ErrServerStopped, io.EOF) || + if netError := net.Error(nil); errors.As(err, &netError) { + return !netError.Temporary() + } + return errors.IsAny(err, cmux.ErrListenerClosed, grpc.ErrServerStopped, io.EOF, net.ErrClosed) || strings.Contains(err.Error(), "use of closed network connection") } diff --git a/pkg/util/netutil/net_test.go b/pkg/util/netutil/net_test.go new file mode 100644 index 000000000000..2cc492689313 --- /dev/null +++ b/pkg/util/netutil/net_test.go @@ -0,0 +1,93 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package netutil + +import ( + "fmt" + "io" + "net" + "syscall" + "testing" + + "github.com/cockroachdb/cmux" + "github.com/cockroachdb/cockroach/pkg/util/contextutil" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" +) + +func TestIsClosedConnection(t *testing.T) { + for _, tc := range []struct { + err error + isClosedError bool + }{ + { + fmt.Errorf("an error"), + false, + }, + { + net.ErrClosed, + true, + }, + { + cmux.ErrListenerClosed, + true, + }, + { + grpc.ErrServerStopped, + true, + }, + { + io.EOF, + true, + }, + { + // TODO(rafi): should this be treated the same as EOF? + io.ErrUnexpectedEOF, + false, + }, + { + &net.AddrError{Err: "addr", Addr: "err"}, + true, + }, + { + syscall.ECONNRESET, + true, + }, + { + syscall.EADDRINUSE, + true, + }, + { + syscall.ECONNABORTED, + true, + }, + { + syscall.ECONNREFUSED, + true, + }, + { + syscall.EBADMSG, + true, + }, + { + syscall.EINTR, + false, + }, + { + &contextutil.TimeoutError{}, + false, + }, + } { + assert.Equalf(t, tc.isClosedError, IsClosedConnection(tc.err), + "expected %q to be evaluated as %v", tc.err, tc.isClosedError, + ) + } +}