Skip to content

Commit

Permalink
ldap: add timeout and retry-backoff for ldap (#51927)
Browse files Browse the repository at this point in the history
close #51883
  • Loading branch information
YangKeao authored Mar 20, 2024
1 parent 968f4f2 commit d940619
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 10 deletions.
5 changes: 4 additions & 1 deletion pkg/privilege/privileges/ldap/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/privilege/conn",
"//pkg/util/intest",
"//pkg/util/logutil",
"@com_github_go_ldap_ldap_v3//:ldap",
"@com_github_ngaut_pools//:pools",
"@com_github_pingcap_errors//:errors",
"@org_uber_go_zap//:zap",
],
)

Expand All @@ -29,6 +32,6 @@ go_test(
"test/ldap.key",
],
flaky = True,
shard_count = 3,
shard_count = 4,
deps = ["@com_github_stretchr_testify//require"],
)
56 changes: 47 additions & 9 deletions pkg/privilege/privileges/ldap/ldap_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,25 @@ import (
"os"
"strconv"
"sync"
"time"

"github.com/go-ldap/ldap/v3"
"github.com/ngaut/pools"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/util/intest"
"github.com/pingcap/tidb/pkg/util/logutil"
"go.uber.org/zap"
)

// ldapTimeout is set to 10s. It works on both the TCP/TLS dialing timeout, and the LDAP request timeout. For connection with TLS, the
// user may find that it fails after 2*ldapTimeout, because TiDB will try to connect through both `StartTLS` (from a normal TCP connection)
// and `TLS`, therefore the total time is 2*ldapTimeout.
var ldapTimeout = 10 * time.Second

// skipTLSForTest is used to skip trying to connect with TLS directly in tests. If it's set to false, connection will only try to
// use `StartTLS`
var skipTLSForTest = false

// ldapAuthImpl gives the internal utilities of authentication with LDAP.
// The getter and setter methods will lock the mutex inside, while all other methods don't, so all other method call
// should be protected by `impl.Lock()`.
Expand Down Expand Up @@ -115,10 +128,13 @@ func (impl *ldapAuthImpl) initializeCAPool() error {
}

func (impl *ldapAuthImpl) tryConnectLDAPThroughStartTLS(address string) (*ldap.Conn, error) {
ldapConnection, err := ldap.Dial("tcp", address)
ldapConnection, err := ldap.DialURL("ldap://"+address, ldap.DialWithDialer(&net.Dialer{
Timeout: ldapTimeout,
}))
if err != nil {
return nil, err
}
ldapConnection.SetTimeout(ldapTimeout)

err = ldapConnection.StartTLS(&tls.Config{
RootCAs: impl.caPool,
Expand All @@ -130,18 +146,23 @@ func (impl *ldapAuthImpl) tryConnectLDAPThroughStartTLS(address string) (*ldap.C

return nil, err
}

return ldapConnection, nil
}

func (impl *ldapAuthImpl) tryConnectLDAPThroughTLS(address string) (*ldap.Conn, error) {
ldapConnection, err := ldap.DialTLS("tcp", address, &tls.Config{
tlsConfig := &tls.Config{
RootCAs: impl.caPool,
ServerName: impl.ldapServerHost,
MinVersion: tls.VersionTLS12,
})
}
ldapConnection, err := ldap.DialURL("ldaps://"+address, ldap.DialWithTLSDialer(tlsConfig, &net.Dialer{
Timeout: ldapTimeout,
}))
if err != nil {
return nil, err
}
ldapConnection.SetTimeout(ldapTimeout)

return ldapConnection, nil
}
Expand All @@ -154,6 +175,10 @@ func (impl *ldapAuthImpl) connectionFactory() (pools.Resource, error) {
if impl.enableTLS {
ldapConnection, err := impl.tryConnectLDAPThroughStartTLS(address)
if err != nil {
if intest.InTest && skipTLSForTest {
return nil, err
}

ldapConnection, err = impl.tryConnectLDAPThroughTLS(address)
if err != nil {
return nil, errors.Wrap(err, "create ldap connection")
Expand All @@ -162,15 +187,19 @@ func (impl *ldapAuthImpl) connectionFactory() (pools.Resource, error) {

return ldapConnection, nil
}
ldapConnection, err := ldap.Dial("tcp", address)
ldapConnection, err := ldap.DialURL("ldap://"+address, ldap.DialWithDialer(&net.Dialer{
Timeout: ldapTimeout,
}))
if err != nil {
return nil, errors.Wrap(err, "create ldap connection")
}
ldapConnection.SetTimeout(ldapTimeout)

return ldapConnection, nil
}

const getConnectionMaxRetry = 10
const getConnectionRetryInterval = 500 * time.Millisecond

func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) {
retryCount := 0
Expand All @@ -191,13 +220,19 @@ func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) {
Password: impl.bindRootPWD,
})
if err != nil {
logutil.BgLogger().Warn("fail to use LDAP connection bind to anonymous user. Retrying", zap.Error(err),
zap.Duration("backoff", getConnectionRetryInterval))

// fail to bind to anonymous user, just release this connection and try to get a new one
impl.ldapConnectionPool.Put(nil)

retryCount++
if retryCount >= getConnectionMaxRetry {
return nil, errors.Wrap(err, "fail to bind to anonymous user")
}
// Be careful that it's still holding the lock of the system variables, so it's not good to sleep here.
// TODO: refactor the `RWLock` to avoid the problem of holding the lock.
time.Sleep(getConnectionRetryInterval)
continue
}

Expand All @@ -210,12 +245,12 @@ func (impl *ldapAuthImpl) putConnection(conn *ldap.Conn) {
}

func (impl *ldapAuthImpl) initializePool() {
if impl.ldapConnectionPool != nil {
impl.ldapConnectionPool.Close()
}

// skip initialization when the variables are not correct
// skip re-initialization when the variables are not correct
if impl.initCapacity > 0 && impl.maxCapacity >= impl.initCapacity {
if impl.ldapConnectionPool != nil {
impl.ldapConnectionPool.Close()
}

impl.ldapConnectionPool = pools.NewResourcePool(impl.connectionFactory, impl.initCapacity, impl.maxCapacity, 0)
}
}
Expand Down Expand Up @@ -260,6 +295,7 @@ func (impl *ldapAuthImpl) SetLDAPServerHost(ldapServerHost string) {

if ldapServerHost != impl.ldapServerHost {
impl.ldapServerHost = ldapServerHost
impl.initializePool()
}
}

Expand All @@ -270,6 +306,7 @@ func (impl *ldapAuthImpl) SetLDAPServerPort(ldapServerPort int) {

if ldapServerPort != impl.ldapServerPort {
impl.ldapServerPort = ldapServerPort
impl.initializePool()
}
}

Expand All @@ -280,6 +317,7 @@ func (impl *ldapAuthImpl) SetEnableTLS(enableTLS bool) {

if enableTLS != impl.enableTLS {
impl.enableTLS = enableTLS
impl.initializePool()
}
}

Expand Down
62 changes: 62 additions & 0 deletions pkg/privilege/privileges/ldap/ldap_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"net"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -172,3 +173,64 @@ func TestConnectWithTLS11(t *testing.T) {
_, err := impl.connectionFactory()
require.ErrorContains(t, err, "protocol version not supported")
}

func TestLDAPStartTLSTimeout(t *testing.T) {
originalTimeout := ldapTimeout
ldapTimeout = time.Second * 2
skipTLSForTest = true
defer func() {
ldapTimeout = originalTimeout
skipTLSForTest = false
}()

var ln net.Listener
startListen := make(chan struct{})
afterTimeout := make(chan struct{})
defer close(afterTimeout)

// this test only tests whether the LDAP with LTS enabled will fallback from StartTLS
randomTLSServicePort := rand.Int()%10000 + 10000
serverWg := &sync.WaitGroup{}
serverWg.Add(1)
go func() {
var err error
defer close(startListen)
defer serverWg.Done()

ln, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", randomTLSServicePort))
require.NoError(t, err)
startListen <- struct{}{}

conn, err := ln.Accept()
require.NoError(t, err)

<-afterTimeout
require.NoError(t, conn.Close())

// close the server
require.NoError(t, ln.Close())
}()

<-startListen
defer func() {
serverWg.Wait()
}()

impl := &ldapAuthImpl{}
impl.SetEnableTLS(true)
impl.SetLDAPServerHost("localhost")
impl.SetLDAPServerPort(randomTLSServicePort)

impl.caPool = x509.NewCertPool()
require.True(t, impl.caPool.AppendCertsFromPEM(tlsCAStr))
impl.SetInitCapacity(1)
impl.SetMaxCapacity(1)

now := time.Now()
_, err := impl.connectionFactory()
afterTimeout <- struct{}{}
dur := time.Since(now)
require.Greater(t, dur, 2*time.Second)
require.Less(t, dur, 3*time.Second)
require.ErrorContains(t, err, "connection timed out")
}

0 comments on commit d940619

Please sign in to comment.