diff --git a/pkg/sql/pgwire/pgwire_test.go b/pkg/sql/pgwire/pgwire_test.go index 0d7080d74c42..2755a7811b28 100644 --- a/pkg/sql/pgwire/pgwire_test.go +++ b/pkg/sql/pgwire/pgwire_test.go @@ -1867,33 +1867,41 @@ var _ pgx.Logger = pgxTestLogger{} func TestCancelRequest(t *testing.T) { defer leaktest.AfterTest(t)() - params := base.TestServerArgs{Insecure: true} - s, _, _ := serverutils.StartServer(t, params) + testutils.RunTrueAndFalse(t, "insecure", func(t *testing.T, insecure bool) { + params := base.TestServerArgs{Insecure: insecure} + s, _, _ := serverutils.StartServer(t, params) - ctx := context.TODO() - defer s.Stopper().Stop(ctx) + ctx := context.TODO() + defer s.Stopper().Stop(ctx) - var d net.Dialer - conn, err := d.DialContext(ctx, "tcp", s.ServingSQLAddr()) - if err != nil { - t.Fatal(err) - } - defer conn.Close() + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", s.ServingSQLAddr()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() - fe, err := pgproto3.NewFrontend(conn, conn) - if err != nil { - t.Fatal(err) - } - const versionCancel = 80877102 - if err := fe.Send(&pgproto3.StartupMessage{ProtocolVersion: versionCancel}); err != nil { - t.Fatal(err) - } - if _, err := fe.Receive(); err != io.EOF { - t.Fatalf("unexpected: %v", err) - } - if count := telemetry.GetRawFeatureCounts()["pgwire.unimplemented.cancel_request"]; count != 1 { - t.Fatalf("expected 1 cancel request, got %d", count) - } + // Reset telemetry so we get a deterministic count below. + _ = telemetry.GetFeatureCounts(telemetry.Raw, telemetry.ResetCounts) + + fe, err := pgproto3.NewFrontend(conn, conn) + if err != nil { + t.Fatal(err) + } + // versionCancel is the special code sent as header for cancel requests. + // See: https://www.postgresql.org/docs/current/protocol-message-formats.html + // and the explanation in server.go. + const versionCancel = 80877102 + if err := fe.Send(&pgproto3.StartupMessage{ProtocolVersion: versionCancel}); err != nil { + t.Fatal(err) + } + if _, err := fe.Receive(); err != io.EOF { + t.Fatalf("unexpected: %v", err) + } + if count := telemetry.GetRawFeatureCounts()["pgwire.unimplemented.cancel_request"]; count != 1 { + t.Fatalf("expected 1 cancel request, got %d", count) + } + }) } func TestFailPrepareFailsTxn(t *testing.T) { diff --git a/pkg/sql/pgwire/server.go b/pkg/sql/pgwire/server.go index 414e7a785979..0d59b6668194 100644 --- a/pkg/sql/pgwire/server.go +++ b/pkg/sql/pgwire/server.go @@ -488,6 +488,17 @@ func (s *Server) ServeConn(ctx context.Context, conn net.Conn, socketType Socket return err } + if version == versionCancel { + // The cancel message is rather peculiar: it is sent without + // authentication, always over an unencrypted channel. + // + // Since we don't support this, close the door in the client's + // face. Make a note of that use in telemetry. + telemetry.Inc(sqltelemetry.CancelRequestCounter) + _ = conn.Close() + return nil + } + // If the server is shutting down, terminate the connection early. if draining { return s.sendErr(ctx, conn, newAdminShutdownErr(ErrDrainingNewConn)) @@ -512,14 +523,6 @@ func (s *Server) ServeConn(ctx context.Context, conn net.Conn, socketType Socket // What does the client want to do? switch version { - case versionCancel: - // If the client is really issuing a cancel request, close the door - // in their face (we don't support it yet). Make a note of that use - // in telemetry. - telemetry.Inc(sqltelemetry.CancelRequestCounter) - _ = conn.Close() - return nil - case version30: // Normal SQL connection. Proceed normally below.