diff --git a/pkg/server/server.go b/pkg/server/server.go index 58e8e660b12d..3bf0543347f1 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -81,7 +81,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/cockroachdb/errors" - "github.com/cockroachdb/logtags" "github.com/cockroachdb/redact" "github.com/cockroachdb/sentry-go" gwruntime "github.com/grpc-ecosystem/grpc-gateway/runtime" @@ -1892,7 +1891,7 @@ func (s *sqlServer) startServeSQL( stopper.RunWorker(pgCtx, func(pgCtx context.Context) { netutil.FatalIfUnexpected(connManager.ServeWith(pgCtx, stopper, pgL, func(conn net.Conn) { - connCtx := logtags.AddTag(pgCtx, "client", conn.RemoteAddr().String()) + connCtx := s.pgServer.AnnotateCtxForIncomingConn(pgCtx, conn) tcpKeepAlive.configure(connCtx, conn) if err := s.pgServer.ServeConn(connCtx, conn, pgwire.SocketTCP); err != nil { @@ -1920,7 +1919,7 @@ func (s *sqlServer) startServeSQL( stopper.RunWorker(pgCtx, func(pgCtx context.Context) { netutil.FatalIfUnexpected(connManager.ServeWith(pgCtx, stopper, unixLn, func(conn net.Conn) { - connCtx := logtags.AddTag(pgCtx, "client", conn.RemoteAddr().String()) + connCtx := s.pgServer.AnnotateCtxForIncomingConn(pgCtx, conn) if err := s.pgServer.ServeConn(connCtx, conn, pgwire.SocketUnix); err != nil { log.Errorf(connCtx, "%v", err) } diff --git a/pkg/sql/pgwire/auth.go b/pkg/sql/pgwire/auth.go index 31c1b37bd0c9..1cfe5a1e5444 100644 --- a/pkg/sql/pgwire/auth.go +++ b/pkg/sql/pgwire/auth.go @@ -176,9 +176,9 @@ func (c *conn) lookupAuthenticationMethodUsingRules( var ip net.IP if connType != hba.ConnLocal { // Extract the IP address of the client. 
- tcpAddr, ok := c.conn.RemoteAddr().(*net.TCPAddr) + tcpAddr, ok := c.sessionArgs.RemoteAddr.(*net.TCPAddr) if !ok { - err = errors.AssertionFailedf("client address type %T unsupported", c.conn.RemoteAddr()) + err = errors.AssertionFailedf("client address type %T unsupported", c.sessionArgs.RemoteAddr) return } ip = tcpAddr.IP diff --git a/pkg/sql/pgwire/auth_test.go b/pkg/sql/pgwire/auth_test.go index c530f8b79e1d..ebf5af95e1b9 100644 --- a/pkg/sql/pgwire/auth_test.go +++ b/pkg/sql/pgwire/auth_test.go @@ -25,6 +25,7 @@ import ( "strconv" "strings" "testing" + "time" "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/security" @@ -37,6 +38,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/datadriven" "github.com/cockroachdb/errors" "github.com/cockroachdb/errors/stdstrings" @@ -160,9 +162,9 @@ func hbaRunTest(t *testing.T, insecure bool) { // We can't use the cluster settings to do this, because // cluster settings propagate asynchronously. testServer := s.(*server.TestServer) - testServer.PGServer().TestingEnableConnAuthLogging() - pgServer := s.(*server.TestServer).PGServer() + pgServer.TestingEnableConnLogging() + pgServer.TestingEnableAuthLogging() httpClient, err := s.GetAdminAuthenticatedHTTPClient() if err != nil { @@ -284,6 +286,7 @@ func hbaRunTest(t *testing.T, insecure bool) { // The tag part is going to contain a client address, with a random port number. // To make the test deterministic, erase the random part. 
tags := addrRe.ReplaceAllString(entry.Tags, ",client=XXX") + tags = peerRe.ReplaceAllString(tags, ",peer=XXX") var maybeTags string if len(tags) > 0 { maybeTags = "[" + tags + "] " @@ -410,6 +413,7 @@ func hbaRunTest(t *testing.T, insecure bool) { var authLogFileRe = regexp.MustCompile(`pgwire/(auth|conn|server)\.go`) var addrRe = regexp.MustCompile(`,client(=[^\],]*)?`) +var peerRe = regexp.MustCompile(`,peer(=[^\],]*)?`) var durationRe = regexp.MustCompile(`duration: \d.*s`) // fmtErr formats an error into an expected output. @@ -435,3 +439,164 @@ func fmtErr(err error) string { } return "ok" } + +// TestClientAddrOverride checks that the crdb:remote_addr parameter +// can override the client address. +func TestClientAddrOverride(t *testing.T) { + defer leaktest.AfterTest(t)() + sc := log.ScopeWithoutShowLogs(t) + defer sc.Close(t) + + // Start a server. + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) + ctx := context.Background() + defer s.Stopper().Stop(ctx) + + pgURL, cleanupFunc := sqlutils.PGUrl( + t, s.ServingSQLAddr(), "testClientAddrOverride" /* prefix */, url.User(server.TestUser), + ) + defer cleanupFunc() + + // Ensure the test user exists. + if _, err := db.Exec(`CREATE USER $1`, server.TestUser); err != nil { + t.Fatal(err) + } + + // Enable conn/auth logging. + // We can't use the cluster settings to do this, because + // cluster settings for booleans propagate asynchronously. + testServer := s.(*server.TestServer) + pgServer := testServer.PGServer() + pgServer.TestingEnableAuthLogging() + + testCases := []struct { + specialAddr string + specialPort string + }{ + {"11.22.33.44", "5566"}, // IPv4 + {"[11:22:33::44]", "5566"}, // IPv6 + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("%s:%s", tc.specialAddr, tc.specialPort), func(t *testing.T) { + // Create a custom HBA rule to refuse connections by the testuser + // when coming from the special address. 
+ addr := tc.specialAddr + mask := "32" + if addr[0] == '[' { + // An IPv6 address. The CIDR format in HBA rules does not + // require the square brackets. + addr = addr[1 : len(addr)-1] + mask = "128" + } + hbaConf := "host all " + server.TestUser + " " + addr + "/" + mask + " reject\n" + + "host all all all cert-password\n" + if _, err := db.Exec( + `SET CLUSTER SETTING server.host_based_authentication.configuration = $1`, + hbaConf, + ); err != nil { + t.Fatal(err) + } + + // Wait until the configuration has propagated back to the + // test client. We need to wait because the cluster setting + // change propagates asynchronously. + expConf, err := pgwire.ParseAndNormalize(hbaConf) + if err != nil { + // The SET above succeeded so we don't expect a problem here. + t.Fatal(err) + } + testutils.SucceedsSoon(t, func() error { + curConf := pgServer.GetAuthenticationConfiguration() + if expConf.String() != curConf.String() { + return errors.Newf( + "HBA config not yet loaded\ngot:\n%s\nexpected:\n%s", + curConf, expConf) + } + return nil + }) + + // Inject the custom client address. + options, _ := url.ParseQuery(pgURL.RawQuery) + options["crdb:remote_addr"] = []string{tc.specialAddr + ":" + tc.specialPort} + pgURL.RawQuery = options.Encode() + + t.Run("check-server-reject-override", func(t *testing.T) { + // Connect a first time, with trust override disabled. In that case, + // the server will complain that the remote override is not supported. + _ = pgServer.TestingSetTrustClientProvidedRemoteAddr(false) + + testDB, err := gosql.Open("postgres", pgURL.String()) + if err != nil { + t.Fatal(err) + } + defer testDB.Close() + if err := testDB.Ping(); !testutils.IsError(err, "server not configured to accept remote address override") { + t.Error(err) + } + }) + + // Wait two full microseconds: we're parsing the log output below, and + // the logging format has a microsecond precision on timestamps. 
We need to ensure that this check will not pick up log entries + // from a previous test. + time.Sleep(2 * time.Microsecond) + testStartTime := timeutil.Now() + + t.Run("check-server-hba-uses-override", func(t *testing.T) { + // Now recognize the override. Now we're expecting the connection + // to hit the HBA rule and fail with an authentication error. + _ = pgServer.TestingSetTrustClientProvidedRemoteAddr(true) + + testDB, err := gosql.Open("postgres", pgURL.String()) + if err != nil { + t.Fatal(err) + } + defer testDB.Close() + if err := testDB.Ping(); !testutils.IsError(err, "authentication rejected") { + t.Error(err) + } + }) + + t.Run("check-server-log-uses-override", func(t *testing.T) { + // Wait for the disconnection event in logs. + testutils.SucceedsSoon(t, func() error { + log.Flush() + entries, err := log.FetchEntriesFromFiles(testStartTime.UnixNano(), math.MaxInt64, 10000, sessionTerminatedRe, + log.WithFlattenedSensitiveData) + if err != nil { + t.Fatal(err) + } + if len(entries) == 0 { + return errors.New("entry not found") + } + return nil + }) + + // Now we want to check that the logging tags are also updated. 
+ log.Flush() + entries, err := log.FetchEntriesFromFiles(testStartTime.UnixNano(), math.MaxInt64, 10000, authLogFileRe, + log.WithMarkedSensitiveData) + if err != nil { + t.Fatal(err) + } + if len(entries) == 0 { + t.Fatal("no entries") + } + seenClient := false + for _, e := range entries { + t.Log(e.Tags) + if strings.Contains(e.Tags, "client=") { + seenClient = true + if !strings.Contains(e.Tags, "client="+tc.specialAddr+":"+tc.specialPort) { + t.Fatalf("expected override addr in log tags, got %+v", e) + } + } + } + if !seenClient { + t.Fatal("no log entry found with the 'client' tag set") + } + }) + }) + } +} + +var sessionTerminatedRe = regexp.MustCompile("session terminated") diff --git a/pkg/sql/pgwire/conn.go b/pkg/sql/pgwire/conn.go index 0e5db622ee95..c57e3e2ddd17 100644 --- a/pkg/sql/pgwire/conn.go +++ b/pkg/sql/pgwire/conn.go @@ -140,14 +140,12 @@ func (s *Server) serveConn( reserved mon.BoundAccount, authOpt authOptions, ) { - sArgs.RemoteAddr = netConn.RemoteAddr() - if log.V(2) { log.Infof(ctx, "new connection with options: %+v", sArgs) } c := newConn(netConn, sArgs, &s.metrics, &s.execCfg.Settings.SV) - c.alwaysLogAuthActivity = alwaysLogAuthActivity || atomic.LoadInt32(&s.testingLogEnabled) > 0 + c.alwaysLogAuthActivity = alwaysLogAuthActivity || atomic.LoadInt32(&s.testingAuthLogEnabled) > 0 // Do the reading of commands from the network. c.serveImpl(ctx, s.IsDraining, s.SQLServer, reserved, authOpt) diff --git a/pkg/sql/pgwire/conn_test.go b/pkg/sql/pgwire/conn_test.go index 50b532a538b4..8d7ec0451784 100644 --- a/pkg/sql/pgwire/conn_test.go +++ b/pkg/sql/pgwire/conn_test.go @@ -538,7 +538,7 @@ func waitForClientConn(ln net.Listener) (*conn, error) { } // Consume the connection options. 
- if _, err := parseClientProvidedSessionParameters(context.Background(), nil, &buf); err != nil { + if _, err := parseClientProvidedSessionParameters(context.Background(), nil, &buf, conn.RemoteAddr(), false /* trustRemoteAddr */); err != nil { return nil, err } diff --git a/pkg/sql/pgwire/server.go b/pkg/sql/pgwire/server.go index ea9690f10b1d..e12107c637d8 100644 --- a/pkg/sql/pgwire/server.go +++ b/pkg/sql/pgwire/server.go @@ -15,6 +15,7 @@ import ( "crypto/tls" "io" "net" + "strconv" "strings" "sync/atomic" "time" @@ -181,10 +182,30 @@ type Server struct { sqlMemoryPool *mon.BytesMonitor connMonitor *mon.BytesMonitor - // testingLogEnabled is used in unit tests in this package to - // force-enable conn/auth logging without dancing around the - // asynchronicity of cluster settings. - testingLogEnabled int32 + // testing{Conn,Auth}LogEnabled is used in unit tests in this + // package to force-enable conn/auth logging without dancing around + // the asynchronicity of cluster settings. + testingConnLogEnabled int32 + testingAuthLogEnabled int32 + + // trustClientProvidedRemoteAddr indicates whether the server should honor + // a `crdb:remote_addr` status parameter provided by the client during + // session authentication. This status parameter can be set by SQL proxies + // to feed the "real" client address, where otherwise the CockroachDB SQL + // server would only see the address of the proxy. + // + // This setting is security-sensitive and should not be enabled + // without a SQL proxy that carefully scrubs any client-provided + // `crdb:remote_addr` field. In particular, this setting should never + // be set when there is no SQL proxy at all. Otherwise, a malicious + // client could use this field to pretend to be from another address + // than its own and defeat the HBA rules. + // + // TODO(knz,ben): It would be good to have something more specific + // than a boolean, i.e. 
to accept the provided address only from + // certain peer IPs, or with certain certificates. (could it be a + // special hba.conf directive?) + trustClientProvidedRemoteAddr syncutil.AtomicBool } // ServerMetrics is the set of metrics for the pgwire server. @@ -252,6 +273,9 @@ func MakeServer( server.sqlMemoryPool.Start(context.Background(), parentMemoryMonitor, mon.BoundAccount{}) server.SQLServer = sql.NewServer(executorConfig, server.sqlMemoryPool) + // TODO(knz,ben): Use a cluster setting for this. + server.trustClientProvidedRemoteAddr.Set(trustClientProvidedRemoteAddrOverride) + server.connMonitor = mon.NewMonitor("conn", mon.MemoryResource, server.metrics.ConnMemMetrics.CurBytesCount, @@ -272,6 +296,20 @@ func MakeServer( return server } +// AnnotateCtxForIncomingConn annotates the provided context with a +// tag that reports the peer's address. In the common case, the +// context is annotated with a "client" tag. When the server is +// configured to recognize client-specified remote addresses, it is +// annotated with a "peer" tag and the "client" tag is added later +// when the session is set up. +func (s *Server) AnnotateCtxForIncomingConn(ctx context.Context, conn net.Conn) context.Context { + tag := "client" + if s.trustClientProvidedRemoteAddr.Get() { + tag = "peer" + } + return logtags.AddTag(ctx, tag, conn.RemoteAddr().String()) +} + // Match returns true if rd appears to be a Postgres connection. func Match(rd io.Reader) bool { buf := pgwirebase.MakeReadBuffer() @@ -467,12 +505,17 @@ func (s SocketType) asConnType() (hba.ConnType, error) { } func (s *Server) connLogEnabled() bool { - return atomic.LoadInt32(&s.testingLogEnabled) != 0 || logConnAuth.Get(&s.execCfg.Settings.SV) + return atomic.LoadInt32(&s.testingConnLogEnabled) != 0 || logConnAuth.Get(&s.execCfg.Settings.SV) } -// TestingEnableConnAuthLogging is exported for use in tests. 
-func (s *Server) TestingEnableConnAuthLogging() { - atomic.StoreInt32(&s.testingLogEnabled, 1) +// TestingEnableConnLogging is exported for use in tests. +func (s *Server) TestingEnableConnLogging() { + atomic.StoreInt32(&s.testingConnLogEnabled, 1) +} + +// TestingEnableAuthLogging is exported for use in tests. +func (s *Server) TestingEnableAuthLogging() { + atomic.StoreInt32(&s.testingAuthLogEnabled, 1) } // ServeConn serves a single connection, driving the handshake process and @@ -571,10 +614,16 @@ func (s *Server) ServeConn(ctx context.Context, conn net.Conn, socketType Socket // Load the client-provided session parameters. var sArgs sql.SessionArgs - if sArgs, err = parseClientProvidedSessionParameters(ctx, &s.execCfg.Settings.SV, &buf); err != nil { + if sArgs, err = parseClientProvidedSessionParameters(ctx, &s.execCfg.Settings.SV, &buf, + conn.RemoteAddr(), s.trustClientProvidedRemoteAddr.Get()); err != nil { return s.sendErr(ctx, conn, err) } + // Populate the client address field in the context tags. + // Only now do we know the remote client address for sure (it may have + // been overridden by a status parameter). + ctx = logtags.AddTag(ctx, "client", sArgs.RemoteAddr.String()) + // If a test is hooking in some authentication option, load it. var testingAuthHook func(context.Context) error if k := s.execCfg.PGWireTestingKnobs; k != nil { @@ -599,10 +648,15 @@ func (s *Server) ServeConn(ctx context.Context, conn net.Conn, socketType Socket // parseClientProvidedSessionParameters reads the incoming k/v pairs // in the startup message into a sql.SessionArgs struct. 
func parseClientProvidedSessionParameters( - ctx context.Context, sv *settings.Values, buf *pgwirebase.ReadBuffer, + ctx context.Context, + sv *settings.Values, + buf *pgwirebase.ReadBuffer, + origRemoteAddr net.Addr, + trustClientProvidedRemoteAddr bool, ) (sql.SessionArgs, error) { args := sql.SessionArgs{ SessionDefaults: make(map[string]string), + RemoteAddr: origRemoteAddr, } foundBufferSize := false @@ -647,6 +701,29 @@ func parseClientProvidedSessionParameters( } foundBufferSize = true + case "crdb:remote_addr": + if !trustClientProvidedRemoteAddr { + return sql.SessionArgs{}, pgerror.Newf(pgcode.ProtocolViolation, + "server not configured to accept remote address override (requested: %q)", value) + } + + hostS, portS, err := net.SplitHostPort(value) + if err != nil { + return sql.SessionArgs{}, pgerror.Newf(pgcode.ProtocolViolation, + "invalid address format: %v", err) + } + port, err := strconv.Atoi(portS) + if err != nil { + return sql.SessionArgs{}, pgerror.Newf(pgcode.ProtocolViolation, + "remote port is not numeric: %v", err) + } + ip := net.ParseIP(hostS) + if ip == nil { + return sql.SessionArgs{}, pgerror.New(pgcode.ProtocolViolation, + "remote address is not numeric") + } + args.RemoteAddr = &net.TCPAddr{IP: ip, Port: port} + default: exists, configurable := sql.IsSessionVariableConfigurable(key) @@ -682,6 +759,20 @@ func parseClientProvidedSessionParameters( return args, nil } +// Note: Usage of an env var here makes it possible to unconditionally +// enable this feature when cluster settings do not work reliably, +// e.g. in multi-tenant setups in v20.2. This override mechanism can +// be removed after all of CC is moved to use v21.1 or a version which +// supports cluster settings. +var trustClientProvidedRemoteAddrOverride = envutil.EnvOrDefaultBool("COCKROACH_TRUST_CLIENT_PROVIDED_SQL_REMOTE_ADDR", false) + +// TestingSetTrustClientProvidedRemoteAddr is used in tests. 
+func (s *Server) TestingSetTrustClientProvidedRemoteAddr(b bool) func() { + prev := s.trustClientProvidedRemoteAddr.Get() + s.trustClientProvidedRemoteAddr.Set(b) + return func() { s.trustClientProvidedRemoteAddr.Set(prev) } +} + // maybeUpgradeToSecureConn upgrades the connection to TLS/SSL if // requested by the client, and available in the server configuration. func (s *Server) maybeUpgradeToSecureConn(