diff --git a/pkg/privilege/privileges/ldap/BUILD.bazel b/pkg/privilege/privileges/ldap/BUILD.bazel index 0a6c2957fa3c2..99fe2046d6e70 100644 --- a/pkg/privilege/privileges/ldap/BUILD.bazel +++ b/pkg/privilege/privileges/ldap/BUILD.bazel @@ -12,9 +12,12 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/privilege/conn", + "//pkg/util/intest", + "//pkg/util/logutil", "@com_github_go_ldap_ldap_v3//:ldap", "@com_github_ngaut_pools//:pools", "@com_github_pingcap_errors//:errors", + "@org_uber_go_zap//:zap", ], ) @@ -29,6 +32,6 @@ go_test( "test/ldap.key", ], flaky = True, - shard_count = 3, + shard_count = 4, deps = ["@com_github_stretchr_testify//require"], ) diff --git a/pkg/privilege/privileges/ldap/ldap_common.go b/pkg/privilege/privileges/ldap/ldap_common.go index 6eefdb89ef454..b1fe2586481c8 100644 --- a/pkg/privilege/privileges/ldap/ldap_common.go +++ b/pkg/privilege/privileges/ldap/ldap_common.go @@ -22,12 +22,25 @@ import ( "os" "strconv" "sync" + "time" "github.com/go-ldap/ldap/v3" "github.com/ngaut/pools" "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" ) +// ldapTimeout is set to 15s. It works on both the TCP/TLS dialing timeout, and the LDAP request timeout. For connection with TLS, the +// user may find that it fails after 2*ldapTimeout, because TiDB will try to connect through both `StartTLS` (from a normal TCP connection) +// and `TLS`, therefore the total time is 2*ldapTimeout. +var ldapTimeout = 10 * time.Second + +// skipTLSForTest is used to skip trying to connect with TLS directly in tests. If it's set to false, connection will only try to +// use `StartTLS` +var skipTLSForTest = false + // ldapAuthImpl gives the internal utilities of authentication with LDAP. // The getter and setter methods will lock the mutex inside, while all other methods don't, so all other method call // should be protected by `impl.Lock()`. @@ -115,10 +128,13 @@ func (impl *ldapAuthImpl) initializeCAPool() error { } func (impl *ldapAuthImpl) tryConnectLDAPThroughStartTLS(address string) (*ldap.Conn, error) { - ldapConnection, err := ldap.Dial("tcp", address) + ldapConnection, err := ldap.DialURL("ldap://"+address, ldap.DialWithDialer(&net.Dialer{ + Timeout: ldapTimeout, + })) if err != nil { return nil, err } + ldapConnection.SetTimeout(ldapTimeout) err = ldapConnection.StartTLS(&tls.Config{ RootCAs: impl.caPool, @@ -130,18 +146,23 @@ func (impl *ldapAuthImpl) tryConnectLDAPThroughStartTLS(address string) (*ldap.C return nil, err } + return ldapConnection, nil } func (impl *ldapAuthImpl) tryConnectLDAPThroughTLS(address string) (*ldap.Conn, error) { - ldapConnection, err := ldap.DialTLS("tcp", address, &tls.Config{ + tlsConfig := &tls.Config{ RootCAs: impl.caPool, ServerName: impl.ldapServerHost, MinVersion: tls.VersionTLS12, - }) + } + ldapConnection, err := ldap.DialURL("ldaps://"+address, ldap.DialWithTLSDialer(tlsConfig, &net.Dialer{ + Timeout: ldapTimeout, + })) if err != nil { return nil, err } + ldapConnection.SetTimeout(ldapTimeout) return ldapConnection, nil } @@ -154,6 +175,10 @@ func (impl *ldapAuthImpl) connectionFactory() (pools.Resource, error) { if impl.enableTLS { ldapConnection, err := impl.tryConnectLDAPThroughStartTLS(address) if err != nil { + if intest.InTest && skipTLSForTest { + return nil, err + } + ldapConnection, err = impl.tryConnectLDAPThroughTLS(address) if err != nil { return nil, errors.Wrap(err, "create ldap connection") @@ -162,15 +187,19 @@ func (impl *ldapAuthImpl) connectionFactory() (pools.Resource, error) { return ldapConnection, nil } - ldapConnection, err := ldap.Dial("tcp", address) + ldapConnection, err := ldap.DialURL("ldap://"+address, ldap.DialWithDialer(&net.Dialer{ + Timeout: ldapTimeout, + })) if err != nil { return nil, errors.Wrap(err, "create ldap connection") } + ldapConnection.SetTimeout(ldapTimeout) return ldapConnection, nil } const getConnectionMaxRetry = 10 +const getConnectionRetryInterval = 500 * time.Millisecond func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) { retryCount := 0 @@ -191,6 +220,9 @@ func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) { Password: impl.bindRootPWD, }) if err != nil { + logutil.BgLogger().Warn("fail to use LDAP connection bind to anonymous user. Retrying", zap.Error(err), + zap.Duration("backoff", getConnectionRetryInterval)) + // fail to bind to anonymous user, just release this connection and try to get a new one impl.ldapConnectionPool.Put(nil) @@ -198,6 +230,9 @@ func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) { if retryCount >= getConnectionMaxRetry { return nil, errors.Wrap(err, "fail to bind to anonymous user") } + // Be careful that it's still holding the lock of the system variables, so it's not good to sleep here. + // TODO: refactor the `RWLock` to avoid the problem of holding the lock. + time.Sleep(getConnectionRetryInterval) continue } @@ -210,12 +245,12 @@ func (impl *ldapAuthImpl) putConnection(conn *ldap.Conn) { } func (impl *ldapAuthImpl) initializePool() { - if impl.ldapConnectionPool != nil { - impl.ldapConnectionPool.Close() - } - - // skip initialization when the variables are not correct + // skip re-initialization when the variables are not correct if impl.initCapacity > 0 && impl.maxCapacity >= impl.initCapacity { + if impl.ldapConnectionPool != nil { + impl.ldapConnectionPool.Close() + } + impl.ldapConnectionPool = pools.NewResourcePool(impl.connectionFactory, impl.initCapacity, impl.maxCapacity, 0) } } @@ -260,6 +295,7 @@ func (impl *ldapAuthImpl) SetLDAPServerHost(ldapServerHost string) { if ldapServerHost != impl.ldapServerHost { impl.ldapServerHost = ldapServerHost + impl.initializePool() } } @@ -270,6 +306,7 @@ func (impl *ldapAuthImpl) SetLDAPServerPort(ldapServerPort int) { if ldapServerPort != impl.ldapServerPort { impl.ldapServerPort = ldapServerPort + impl.initializePool() } } @@ -280,6 +317,7 @@ func (impl *ldapAuthImpl) SetEnableTLS(enableTLS bool) { if enableTLS != impl.enableTLS { impl.enableTLS = enableTLS + impl.initializePool() } } diff --git a/pkg/privilege/privileges/ldap/ldap_common_test.go b/pkg/privilege/privileges/ldap/ldap_common_test.go index fd4247cb09499..2d84982d7ffa0 100644 --- a/pkg/privilege/privileges/ldap/ldap_common_test.go +++ b/pkg/privilege/privileges/ldap/ldap_common_test.go @@ -24,6 +24,7 @@ import ( "net" "sync" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -172,3 +173,64 @@ func TestConnectWithTLS11(t *testing.T) { _, err := impl.connectionFactory() require.ErrorContains(t, err, "protocol version not supported") } + +func TestLDAPStartTLSTimeout(t *testing.T) { + originalTimeout := ldapTimeout + ldapTimeout = time.Second * 2 + skipTLSForTest = true + defer func() { + ldapTimeout = originalTimeout + skipTLSForTest = false + }() + + var ln net.Listener + startListen := make(chan struct{}) + afterTimeout := make(chan struct{}) + defer close(afterTimeout) + + // this test only tests whether the LDAP with LTS enabled will fallback from StartTLS + randomTLSServicePort := rand.Int()%10000 + 10000 + serverWg := &sync.WaitGroup{} + serverWg.Add(1) + go func() { + var err error + defer close(startListen) + defer serverWg.Done() + + ln, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", randomTLSServicePort)) + require.NoError(t, err) + startListen <- struct{}{} + + conn, err := ln.Accept() + require.NoError(t, err) + + <-afterTimeout + require.NoError(t, conn.Close()) + + // close the server + require.NoError(t, ln.Close()) + }() + + <-startListen + defer func() { + serverWg.Wait() + }() + + impl := &ldapAuthImpl{} + impl.SetEnableTLS(true) + impl.SetLDAPServerHost("localhost") + impl.SetLDAPServerPort(randomTLSServicePort) + + impl.caPool = x509.NewCertPool() + require.True(t, impl.caPool.AppendCertsFromPEM(tlsCAStr)) + impl.SetInitCapacity(1) + impl.SetMaxCapacity(1) + + now := time.Now() + _, err := impl.connectionFactory() + afterTimeout <- struct{}{} + dur := time.Since(now) + require.Greater(t, dur, 2*time.Second) + require.Less(t, dur, 3*time.Second) + require.ErrorContains(t, err, "connection timed out") +}