Skip to content

Commit

Permalink
Merge pull request cockroachdb#58380 from knz/backport20.2-58379
Browse files Browse the repository at this point in the history
  • Loading branch information
knz authored Jan 8, 2021
2 parents 38fe95d + c49afe4 commit bd0ed16
Show file tree
Hide file tree
Showing 11 changed files with 380 additions and 52 deletions.
5 changes: 2 additions & 3 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/tracing"
"github.com/cockroachdb/cockroach/pkg/util/uuid"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/logtags"
"github.com/cockroachdb/redact"
"github.com/cockroachdb/sentry-go"
gwruntime "github.com/grpc-ecosystem/grpc-gateway/runtime"
Expand Down Expand Up @@ -1892,7 +1891,7 @@ func (s *sqlServer) startServeSQL(

stopper.RunWorker(pgCtx, func(pgCtx context.Context) {
netutil.FatalIfUnexpected(connManager.ServeWith(pgCtx, stopper, pgL, func(conn net.Conn) {
connCtx := logtags.AddTag(pgCtx, "client", conn.RemoteAddr().String())
connCtx := s.pgServer.AnnotateCtxForIncomingConn(pgCtx, conn)
tcpKeepAlive.configure(connCtx, conn)

if err := s.pgServer.ServeConn(connCtx, conn, pgwire.SocketTCP); err != nil {
Expand Down Expand Up @@ -1920,7 +1919,7 @@ func (s *sqlServer) startServeSQL(

stopper.RunWorker(pgCtx, func(pgCtx context.Context) {
netutil.FatalIfUnexpected(connManager.ServeWith(pgCtx, stopper, unixLn, func(conn net.Conn) {
connCtx := logtags.AddTag(pgCtx, "client", conn.RemoteAddr().String())
connCtx := s.pgServer.AnnotateCtxForIncomingConn(pgCtx, conn)
if err := s.pgServer.ServeConn(connCtx, conn, pgwire.SocketUnix); err != nil {
log.Errorf(connCtx, "%v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/pgwire/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ func (c *conn) lookupAuthenticationMethodUsingRules(
var ip net.IP
if connType != hba.ConnLocal {
// Extract the IP address of the client.
tcpAddr, ok := c.conn.RemoteAddr().(*net.TCPAddr)
tcpAddr, ok := c.sessionArgs.RemoteAddr.(*net.TCPAddr)
if !ok {
err = errors.AssertionFailedf("client address type %T unsupported", c.conn.RemoteAddr())
err = errors.AssertionFailedf("client address type %T unsupported", c.sessionArgs.RemoteAddr)
return
}
ip = tcpAddr.IP
Expand Down
169 changes: 167 additions & 2 deletions pkg/sql/pgwire/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"strconv"
"strings"
"testing"
"time"

"github.com/cockroachdb/cockroach/pkg/base"
"github.com/cockroachdb/cockroach/pkg/security"
Expand All @@ -37,6 +38,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/cockroachdb/datadriven"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/errors/stdstrings"
Expand Down Expand Up @@ -160,9 +162,9 @@ func hbaRunTest(t *testing.T, insecure bool) {
// We can't use the cluster settings to do this, because
// cluster settings propagate asynchronously.
testServer := s.(*server.TestServer)
testServer.PGServer().TestingEnableConnAuthLogging()

pgServer := s.(*server.TestServer).PGServer()
pgServer.TestingEnableConnLogging()
pgServer.TestingEnableAuthLogging()

httpClient, err := s.GetAdminAuthenticatedHTTPClient()
if err != nil {
Expand Down Expand Up @@ -284,6 +286,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
// The tag part is going to contain a client address, with a random port number.
// To make the test deterministic, erase the random part.
tags := addrRe.ReplaceAllString(entry.Tags, ",client=XXX")
tags = peerRe.ReplaceAllString(tags, ",peer=XXX")
var maybeTags string
if len(tags) > 0 {
maybeTags = "[" + tags + "] "
Expand Down Expand Up @@ -410,6 +413,7 @@ func hbaRunTest(t *testing.T, insecure bool) {

var authLogFileRe = regexp.MustCompile(`pgwire/(auth|conn|server)\.go`)
var addrRe = regexp.MustCompile(`,client(=[^\],]*)?`)
var peerRe = regexp.MustCompile(`,peer(=[^\],]*)?`)
var durationRe = regexp.MustCompile(`duration: \d.*s`)

// fmtErr formats an error into an expected output.
Expand All @@ -435,3 +439,164 @@ func fmtErr(err error) string {
}
return "ok"
}

// TestClientAddrOverride checks that the crdb:remote_addr parameter
// can override the client address.
func TestClientAddrOverride(t *testing.T) {
defer leaktest.AfterTest(t)()
sc := log.ScopeWithoutShowLogs(t)
defer sc.Close(t)

// Start a server.
s, db, _ := serverutils.StartServer(t, base.TestServerArgs{})
ctx := context.Background()
defer s.Stopper().Stop(ctx)

pgURL, cleanupFunc := sqlutils.PGUrl(
t, s.ServingSQLAddr(), "testClientAddrOverride" /* prefix */, url.User(server.TestUser),
)
defer cleanupFunc()

// Ensure the test user exists.
if _, err := db.Exec(`CREATE USER $1`, server.TestUser); err != nil {
t.Fatal(err)
}

// Enable conn/auth logging.
// We can't use the cluster settings to do this, because
// cluster settings for booleans propagate asynchronously.
testServer := s.(*server.TestServer)
pgServer := testServer.PGServer()
pgServer.TestingEnableAuthLogging()

testCases := []struct {
specialAddr string
specialPort string
}{
{"11.22.33.44", "5566"}, // IPv4
{"[11:22:33::44]", "5566"}, // IPv6
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%s:%s", tc.specialAddr, tc.specialPort), func(t *testing.T) {
// Create a custom HBA rule to refuse connections by the testuser
// when coming from the special address.
addr := tc.specialAddr
mask := "32"
if addr[0] == '[' {
// An IPv6 address. The CIDR format in HBA rules does not
// require the square brackets.
addr = addr[1 : len(addr)-1]
mask = "128"
}
hbaConf := "host all " + server.TestUser + " " + addr + "/" + mask + " reject\n" +
"host all all all cert-password\n"
if _, err := db.Exec(
`SET CLUSTER SETTING server.host_based_authentication.configuration = $1`,
hbaConf,
); err != nil {
t.Fatal(err)
}

// Wait until the configuration has propagated back to the
// test client. We need to wait because the cluster setting
// change propagates asynchronously.
expConf, err := pgwire.ParseAndNormalize(hbaConf)
if err != nil {
// The SET above succeeded so we don't expect a problem here.
t.Fatal(err)
}
testutils.SucceedsSoon(t, func() error {
curConf := pgServer.GetAuthenticationConfiguration()
if expConf.String() != curConf.String() {
return errors.Newf(
"HBA config not yet loaded\ngot:\n%s\nexpected:\n%s",
curConf, expConf)
}
return nil
})

// Inject the custom client address.
options, _ := url.ParseQuery(pgURL.RawQuery)
options["crdb:remote_addr"] = []string{tc.specialAddr + ":" + tc.specialPort}
pgURL.RawQuery = options.Encode()

t.Run("check-server-reject-override", func(t *testing.T) {
// Connect a first time, with trust override disabled. In that case,
// the server will complain that the remote override is not supported.
_ = pgServer.TestingSetTrustClientProvidedRemoteAddr(false)

testDB, err := gosql.Open("postgres", pgURL.String())
if err != nil {
t.Fatal(err)
}
defer testDB.Close()
if err := testDB.Ping(); !testutils.IsError(err, "server not configured to accept remote address override") {
t.Error(err)
}
})

// Wait two full microseconds: we're parsing the log output below, and
// the logging format has a microsecond precision on timestamps. We need to ensure that this check will not pick up log entries
// from a previous test.
time.Sleep(2 * time.Microsecond)
testStartTime := timeutil.Now()

t.Run("check-server-hba-uses-override", func(t *testing.T) {
// Now recognize the override. Now we're expecting the connection
// to hit the HBA rule and fail with an authentication error.
_ = pgServer.TestingSetTrustClientProvidedRemoteAddr(true)

testDB, err := gosql.Open("postgres", pgURL.String())
if err != nil {
t.Fatal(err)
}
defer testDB.Close()
if err := testDB.Ping(); !testutils.IsError(err, "authentication rejected") {
t.Error(err)
}
})

t.Run("check-server-log-uses-override", func(t *testing.T) {
// Wait for the disconnection event in logs.
testutils.SucceedsSoon(t, func() error {
log.Flush()
entries, err := log.FetchEntriesFromFiles(testStartTime.UnixNano(), math.MaxInt64, 10000, sessionTerminatedRe,
log.WithFlattenedSensitiveData)
if err != nil {
t.Fatal(err)
}
if len(entries) == 0 {
return errors.New("entry not found")
}
return nil
})

// Now we want to check that the logging tags are also updated.
log.Flush()
entries, err := log.FetchEntriesFromFiles(testStartTime.UnixNano(), math.MaxInt64, 10000, authLogFileRe,
log.WithMarkedSensitiveData)
if err != nil {
t.Fatal(err)
}
if len(entries) == 0 {
t.Fatal("no entries")
}
seenClient := false
for _, e := range entries {
t.Log(e.Tags)
if strings.Contains(e.Tags, "client=") {
seenClient = true
if !strings.Contains(e.Tags, "client="+tc.specialAddr+":"+tc.specialPort) {
t.Fatalf("expected override addr in log tags, got %+v", e)
}
}
}
if !seenClient {
t.Fatal("no log entry found with the 'client' tag set")
}
})
})
}
}

var sessionTerminatedRe = regexp.MustCompile("session terminated")
20 changes: 12 additions & 8 deletions pkg/sql/pgwire/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/sessiondata"
"github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/envutil"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/mon"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
Expand Down Expand Up @@ -90,10 +91,8 @@ type conn struct {

sv *settings.Values

// testingLogEnabled is used in unit tests in this package to
// force-enable auth logging without dancing around the
// asynchronicity of cluster settings.
testingLogEnabled bool
// alwaysLogAuthActivity is used force-enables logging of authn events.
alwaysLogAuthActivity bool
}

// serveConn creates a conn that will serve the netConn. It returns once the
Expand Down Expand Up @@ -141,19 +140,24 @@ func (s *Server) serveConn(
reserved mon.BoundAccount,
authOpt authOptions,
) {
sArgs.RemoteAddr = netConn.RemoteAddr()

if log.V(2) {
log.Infof(ctx, "new connection with options: %+v", sArgs)
}

c := newConn(netConn, sArgs, &s.metrics, &s.execCfg.Settings.SV)
c.testingLogEnabled = atomic.LoadInt32(&s.testingLogEnabled) > 0
c.alwaysLogAuthActivity = alwaysLogAuthActivity || atomic.LoadInt32(&s.testingAuthLogEnabled) > 0

// Do the reading of commands from the network.
c.serveImpl(ctx, s.IsDraining, s.SQLServer, reserved, authOpt)
}

// alwaysLogAuthActivity makes it possible to unconditionally enable
// authentication logging when cluster settings do not work reliably,
// e.g. in multi-tenant setups in v20.2. This override mechanism
// can be removed after all of CC is moved to use v21.1 or a version
// which supports cluster settings.
var alwaysLogAuthActivity = envutil.EnvOrDefaultBool("COCKROACH_ALWAYS_LOG_AUTHN_EVENTS", false)

func newConn(
netConn net.Conn, sArgs sql.SessionArgs, metrics *ServerMetrics, sv *settings.Values,
) *conn {
Expand Down Expand Up @@ -188,7 +192,7 @@ func (c *conn) GetErr() error {
}

func (c *conn) authLogEnabled() bool {
return c.testingLogEnabled || logSessionAuth.Get(c.sv)
return c.alwaysLogAuthActivity || logSessionAuth.Get(c.sv)
}

// serveImpl continuously reads from the network connection and pushes execution
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/pgwire/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ func waitForClientConn(ln net.Listener) (*conn, error) {
}

// Consume the connection options.
if _, err := parseClientProvidedSessionParameters(context.Background(), nil, &buf); err != nil {
if _, err := parseClientProvidedSessionParameters(context.Background(), nil, &buf, conn.RemoteAddr(), false /* trustRemoteAddr */); err != nil {
return nil, err
}

Expand Down
Loading

0 comments on commit bd0ed16

Please sign in to comment.