From b2860af23da3adc4d042d60ef977bfc004fe6dfc Mon Sep 17 00:00:00 2001 From: Yang Keao Date: Tue, 19 Mar 2024 20:25:58 +0800 Subject: [PATCH 1/2] add timeout for LDAP dial and requests Signed-off-by: Yang Keao --- pkg/privilege/privileges/ldap/ldap_common.go | 318 ++++++++++-------- .../privileges/ldap/ldap_common_test.go | 23 +- pkg/privilege/privileges/ldap/sasl.go | 41 +-- pkg/privilege/privileges/ldap/simple.go | 15 +- 4 files changed, 205 insertions(+), 192 deletions(-) diff --git a/pkg/privilege/privileges/ldap/ldap_common.go b/pkg/privilege/privileges/ldap/ldap_common.go index a91fc69202abb..0fc0ae8a47d05 100644 --- a/pkg/privilege/privileges/ldap/ldap_common.go +++ b/pkg/privilege/privileges/ldap/ldap_common.go @@ -32,6 +32,9 @@ import ( "go.uber.org/zap" ) +const getConnectionMaxRetry = 10 +const getConnectionRetryInterval = 500 * time.Millisecond + // ldapTimeout is set to 10s. It works on both the TCP/TLS dialing timeout, and the LDAP request timeout. For connection with TLS, the // user may find that it fails after 2*ldapTimeout, because TiDB will try to connect through both `StartTLS` (from a normal TCP connection) // and `TLS`, therefore the total time is 2*ldapTimeout. @@ -41,10 +44,8 @@ var ldapTimeout = 10 * time.Second // use `StartTLS` var skipTLSForTest = false -// ldapAuthImpl gives the internal utilities of authentication with LDAP. -// The getter and setter methods will lock the mutex inside, while all other methods don't, so all other method call -// should be protected by `impl.Lock()`. -type ldapAuthImpl struct { +// ldapAuthImplBuilder builds a new `*ldapAuthImpl` from the current configuration. +type ldapAuthImplBuilder struct { sync.RWMutex // the following attributes are used to search the users bindBaseDN string @@ -65,49 +66,20 @@ type ldapAuthImpl struct { ldapConnectionPool *pools.ResourcePool } -func (impl *ldapAuthImpl) searchUser(userName string) (dn string, err error) { - var l *ldap.Conn - - l, err = impl.getConnection() - if err != nil { - return "", err - } - defer impl.putConnection(l) - - err = l.Bind(impl.bindRootDN, impl.bindRootPWD) - if err != nil { - return "", errors.Wrap(err, "bind root dn to search user") - } - - result, err := l.Search(&ldap.SearchRequest{ - BaseDN: impl.bindBaseDN, - Scope: ldap.ScopeWholeSubtree, - Filter: fmt.Sprintf("(%s=%s)", impl.searchAttr, userName), - }) - if err != nil { - return - } +func (impl *ldapAuthImplBuilder) build() *ldapAuthImpl { + impl.RLock() + defer impl.RUnlock() - if len(result.Entries) == 0 { - return "", errors.New("LDAP user not found") + return &ldapAuthImpl{ + bindBaseDN: impl.bindBaseDN, + bindRootDN: impl.bindRootDN, + bindRootPWD: impl.bindRootPWD, + searchAttr: impl.searchAttr, + ldapConnectionPool: impl.ldapConnectionPool, } - - dn = result.Entries[0].DN - return } -// canonicalizeDN turns the `dn` provided in database to the `dn` recognized by LDAP server -// If the first byte of `dn` is `+`, it'll be converted into "${searchAttr}=${username},..." -// both `userName` and `dn` should be non-empty -func (impl *ldapAuthImpl) canonicalizeDN(userName string, dn string) string { - if dn[0] == '+' { - return fmt.Sprintf("%s=%s,%s", impl.searchAttr, userName, dn[1:]) - } - - return dn -} - -func (impl *ldapAuthImpl) initializeCAPool() error { +func (impl *ldapAuthImplBuilder) initializeCAPool() error { if impl.caPath == "" { impl.caPool = nil return nil @@ -127,7 +99,7 @@ func (impl *ldapAuthImpl) initializeCAPool() error { return nil } -func (impl *ldapAuthImpl) tryConnectLDAPThroughStartTLS(address string) (*ldap.Conn, error) { +func tryConnectLDAPThroughStartTLS(address string, tlsConfig *tls.Config) (*ldap.Conn, error) { ldapConnection, err := ldap.DialURL("ldap://"+address, ldap.DialWithDialer(&net.Dialer{ Timeout: ldapTimeout, })) @@ -136,26 +108,16 @@ func (impl *ldapAuthImpl) tryConnectLDAPThroughStartTLS(address string) (*ldap.C } ldapConnection.SetTimeout(ldapTimeout) - err = ldapConnection.StartTLS(&tls.Config{ - RootCAs: impl.caPool, - ServerName: impl.ldapServerHost, - MinVersion: tls.VersionTLS12, - }) + err = ldapConnection.StartTLS(tlsConfig) if err != nil { ldapConnection.Close() return nil, err } - return ldapConnection, nil } -func (impl *ldapAuthImpl) tryConnectLDAPThroughTLS(address string) (*ldap.Conn, error) { - tlsConfig := &tls.Config{ - RootCAs: impl.caPool, - ServerName: impl.ldapServerHost, - MinVersion: tls.VersionTLS12, - } +func tryConnectLDAPThroughTLS(address string, tlsConfig *tls.Config) (*ldap.Conn, error) { ldapConnection, err := ldap.DialURL("ldaps://"+address, ldap.DialWithTLSDialer(tlsConfig, &net.Dialer{ Timeout: ldapTimeout, })) @@ -167,96 +129,62 @@ func (impl *ldapAuthImpl) tryConnectLDAPThroughTLS(address string) (*ldap.Conn, return ldapConnection, nil } -func (impl *ldapAuthImpl) connectionFactory() (pools.Resource, error) { - address := net.JoinHostPort(impl.ldapServerHost, strconv.FormatUint(uint64(impl.ldapServerPort), 10)) - - // It's fine to load these two TLS configurations one-by-one (but not guarded by a single lock), because there isn't - // a way to set two variables atomically. - if impl.enableTLS { - ldapConnection, err := impl.tryConnectLDAPThroughStartTLS(address) - if err != nil { - if intest.InTest && skipTLSForTest { - return nil, err - } - - ldapConnection, err = impl.tryConnectLDAPThroughTLS(address) +func ldapConnectionFactory(address string, tlsConfig *tls.Config) func() (pools.Resource, error) { + if tlsConfig != nil { + return func() (pools.Resource, error) { + ldapConnection, err := tryConnectLDAPThroughStartTLS(address, tlsConfig) if err != nil { - return nil, errors.Wrap(err, "create ldap connection") + if intest.InTest && skipTLSForTest { + return nil, errors.Wrap(err, "create ldap connection") + } + + ldapConnection, err = tryConnectLDAPThroughTLS(address, tlsConfig) + if err != nil { + return nil, errors.Wrap(err, "create ldap connection") + } } - } - return ldapConnection, nil - } - ldapConnection, err := ldap.DialURL("ldap://"+address, ldap.DialWithDialer(&net.Dialer{ - Timeout: ldapTimeout, - })) - if err != nil { - return nil, errors.Wrap(err, "create ldap connection") - } - ldapConnection.SetTimeout(ldapTimeout) - - return ldapConnection, nil -} - -const getConnectionMaxRetry = 10 -const getConnectionRetryInterval = 500 * time.Millisecond - -func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) { - retryCount := 0 - for { - conn, err := impl.ldapConnectionPool.Get() - if err != nil { - return nil, err + return ldapConnection, nil } + } - // try to bind root user. It has two meanings: - // 1. Clear the state of previous binding, to avoid security leaks. (Though it's not serious, because even the current - // connection has binded to other users, the following authentication will still fail. But the ACL for root - // user and a valid user could be different, so it's better to bind back to root user here. - // 2. Detect whether this connection is still valid to use, in case the server has closed this connection. - ldapConnection := conn.(*ldap.Conn) - _, err = ldapConnection.SimpleBind(&ldap.SimpleBindRequest{ - Username: impl.bindRootDN, - Password: impl.bindRootPWD, - }) + return func() (pools.Resource, error) { + ldapConnection, err := ldap.DialURL("ldap://"+address, ldap.DialWithDialer( + &net.Dialer{ + Timeout: ldapTimeout, + }, + )) if err != nil { - logutil.BgLogger().Warn("fail to use LDAP connection bind to anonymous user. Retrying", zap.Error(err), - zap.Duration("backoff", getConnectionRetryInterval)) - - // fail to bind to anonymous user, just release this connection and try to get a new one - impl.ldapConnectionPool.Put(nil) - - retryCount++ - if retryCount >= getConnectionMaxRetry { - return nil, errors.Wrap(err, "fail to bind to anonymous user") - } - // Be careful that it's still holding the lock of the system variables, so it's not good to sleep here. - // TODO: refactor the `RWLock` to avoid the problem of holding the lock. - time.Sleep(getConnectionRetryInterval) - continue + return nil, errors.Wrap(err, "create ldap connection") } - return conn.(*ldap.Conn), nil + return ldapConnection, nil } } -func (impl *ldapAuthImpl) putConnection(conn *ldap.Conn) { - impl.ldapConnectionPool.Put(conn) -} +func (impl *ldapAuthImplBuilder) initializePool() { + if impl.ldapConnectionPool != nil { + impl.ldapConnectionPool.Close() + } -func (impl *ldapAuthImpl) initializePool() { - // skip re-initialization when the variables are not correct + // skip initialization when the variables are not correct if impl.initCapacity > 0 && impl.maxCapacity >= impl.initCapacity { - if impl.ldapConnectionPool != nil { - impl.ldapConnectionPool.Close() + address := net.JoinHostPort(impl.ldapServerHost, strconv.FormatUint(uint64(impl.ldapServerPort), 10)) + var tlsConfig *tls.Config + if impl.enableTLS { + tlsConfig = &tls.Config{ + RootCAs: impl.caPool, + ServerName: impl.ldapServerHost, + MinVersion: tls.VersionTLS12, + } } - impl.ldapConnectionPool = pools.NewResourcePool(impl.connectionFactory, impl.initCapacity, impl.maxCapacity, 0) + impl.ldapConnectionPool = pools.NewResourcePool(ldapConnectionFactory(address, tlsConfig), impl.initCapacity, impl.maxCapacity, 0) } } // SetBindBaseDN updates the BaseDN used to search the user -func (impl *ldapAuthImpl) SetBindBaseDN(bindBaseDN string) { +func (impl *ldapAuthImplBuilder) SetBindBaseDN(bindBaseDN string) { impl.Lock() defer impl.Unlock() @@ -265,7 +193,7 @@ func (impl *ldapAuthImpl) SetBindBaseDN(bindBaseDN string) { // SetBindRootDN updates the RootDN. Before searching the users, the connection will bind // this root user. -func (impl *ldapAuthImpl) SetBindRootDN(bindRootDN string) { +func (impl *ldapAuthImplBuilder) SetBindRootDN(bindRootDN string) { impl.Lock() defer impl.Unlock() @@ -273,7 +201,7 @@ func (impl *ldapAuthImpl) SetBindRootDN(bindRootDN string) { } // SetBindRootPW updates the password of the user specified by `rootDN`. -func (impl *ldapAuthImpl) SetBindRootPW(bindRootPW string) { +func (impl *ldapAuthImplBuilder) SetBindRootPW(bindRootPW string) { impl.Lock() defer impl.Unlock() @@ -281,7 +209,7 @@ func (impl *ldapAuthImpl) SetBindRootPW(bindRootPW string) { } // SetSearchAttr updates the search attributes. -func (impl *ldapAuthImpl) SetSearchAttr(searchAttr string) { +func (impl *ldapAuthImplBuilder) SetSearchAttr(searchAttr string) { impl.Lock() defer impl.Unlock() @@ -289,7 +217,7 @@ func (impl *ldapAuthImpl) SetSearchAttr(searchAttr string) { } // SetLDAPServerHost updates the host of LDAP server -func (impl *ldapAuthImpl) SetLDAPServerHost(ldapServerHost string) { +func (impl *ldapAuthImplBuilder) SetLDAPServerHost(ldapServerHost string) { impl.Lock() defer impl.Unlock() @@ -300,7 +228,7 @@ func (impl *ldapAuthImpl) SetLDAPServerHost(ldapServerHost string) { } // SetLDAPServerPort updates the port of LDAP server -func (impl *ldapAuthImpl) SetLDAPServerPort(ldapServerPort int) { +func (impl *ldapAuthImplBuilder) SetLDAPServerPort(ldapServerPort int) { impl.Lock() defer impl.Unlock() @@ -311,7 +239,7 @@ func (impl *ldapAuthImpl) SetLDAPServerPort(ldapServerPort int) { } // SetEnableTLS sets whether to enable StartTLS for LDAP connection -func (impl *ldapAuthImpl) SetEnableTLS(enableTLS bool) { +func (impl *ldapAuthImplBuilder) SetEnableTLS(enableTLS bool) { impl.Lock() defer impl.Unlock() @@ -322,7 +250,7 @@ func (impl *ldapAuthImpl) SetEnableTLS(enableTLS bool) { } // SetCAPath sets the path of CA certificate used to connect to LDAP server -func (impl *ldapAuthImpl) SetCAPath(path string) error { +func (impl *ldapAuthImplBuilder) SetCAPath(path string) error { impl.Lock() defer impl.Unlock() @@ -332,12 +260,13 @@ func (impl *ldapAuthImpl) SetCAPath(path string) error { if err != nil { return err } + impl.initializePool() } return nil } -func (impl *ldapAuthImpl) SetInitCapacity(initCapacity int) { +func (impl *ldapAuthImplBuilder) SetInitCapacity(initCapacity int) { impl.Lock() defer impl.Unlock() @@ -347,7 +276,7 @@ func (impl *ldapAuthImpl) SetInitCapacity(initCapacity int) { } } -func (impl *ldapAuthImpl) SetMaxCapacity(maxCapacity int) { +func (impl *ldapAuthImplBuilder) SetMaxCapacity(maxCapacity int) { impl.Lock() defer impl.Unlock() @@ -358,7 +287,7 @@ func (impl *ldapAuthImpl) SetMaxCapacity(maxCapacity int) { } // GetBindBaseDN returns the BaseDN used to search the user -func (impl *ldapAuthImpl) GetBindBaseDN() string { +func (impl *ldapAuthImplBuilder) GetBindBaseDN() string { impl.RLock() defer impl.RUnlock() @@ -367,7 +296,7 @@ func (impl *ldapAuthImpl) GetBindBaseDN() string { // GetBindRootDN returns the RootDN. Before searching the users, the connection will bind // this root user. -func (impl *ldapAuthImpl) GetBindRootDN() string { +func (impl *ldapAuthImplBuilder) GetBindRootDN() string { impl.RLock() defer impl.RUnlock() @@ -375,7 +304,7 @@ func (impl *ldapAuthImpl) GetBindRootDN() string { } // GetBindRootPW returns the password of the user specified by `rootDN`. -func (impl *ldapAuthImpl) GetBindRootPW() string { +func (impl *ldapAuthImplBuilder) GetBindRootPW() string { impl.RLock() defer impl.RUnlock() @@ -383,7 +312,7 @@ func (impl *ldapAuthImpl) GetBindRootPW() string { } // GetSearchAttr returns the search attributes. -func (impl *ldapAuthImpl) GetSearchAttr() string { +func (impl *ldapAuthImplBuilder) GetSearchAttr() string { impl.RLock() defer impl.RUnlock() @@ -391,7 +320,7 @@ func (impl *ldapAuthImpl) GetSearchAttr() string { } // GetLDAPServerHost returns the host of LDAP server -func (impl *ldapAuthImpl) GetLDAPServerHost() string { +func (impl *ldapAuthImplBuilder) GetLDAPServerHost() string { impl.RLock() defer impl.RUnlock() @@ -399,7 +328,7 @@ func (impl *ldapAuthImpl) GetLDAPServerHost() string { } // GetLDAPServerPort returns the port of LDAP server -func (impl *ldapAuthImpl) GetLDAPServerPort() int { +func (impl *ldapAuthImplBuilder) GetLDAPServerPort() int { impl.RLock() defer impl.RUnlock() @@ -407,7 +336,7 @@ func (impl *ldapAuthImpl) GetLDAPServerPort() int { } // GetEnableTLS sets whether to enable StartTLS for LDAP connection -func (impl *ldapAuthImpl) GetEnableTLS() bool { +func (impl *ldapAuthImplBuilder) GetEnableTLS() bool { impl.RLock() defer impl.RUnlock() @@ -415,23 +344,120 @@ func (impl *ldapAuthImpl) GetEnableTLS() bool { } // GetCAPath returns the path of CA certificate used to connect to LDAP server -func (impl *ldapAuthImpl) GetCAPath() string { +func (impl *ldapAuthImplBuilder) GetCAPath() string { impl.RLock() defer impl.RUnlock() return impl.caPath } -func (impl *ldapAuthImpl) GetInitCapacity() int { +func (impl *ldapAuthImplBuilder) GetInitCapacity() int { impl.RLock() defer impl.RUnlock() return impl.initCapacity } -func (impl *ldapAuthImpl) GetMaxCapacity() int { +func (impl *ldapAuthImplBuilder) GetMaxCapacity() int { impl.RLock() defer impl.RUnlock() return impl.maxCapacity } + +// ldapAuthImpl gives the internal utilities of authentication with LDAP. +type ldapAuthImpl struct { + bindBaseDN string + bindRootDN string + bindRootPWD string + searchAttr string + + ldapConnectionPool *pools.ResourcePool +} + +func (impl *ldapAuthImpl) searchUser(userName string) (dn string, err error) { + var l *ldap.Conn + + l, err = impl.getConnection() + if err != nil { + logutil.BgLogger().Error("fail to create ldap connection", zap.Error(err)) + return "", err + } + defer impl.putConnection(l) + + err = l.Bind(impl.bindRootDN, impl.bindRootPWD) + if err != nil { + return "", errors.Wrap(err, "bind root dn to search user") + } + + result, err := l.Search(&ldap.SearchRequest{ + BaseDN: impl.bindBaseDN, + Scope: ldap.ScopeWholeSubtree, + Filter: fmt.Sprintf("(%s=%s)", impl.searchAttr, userName), + }) + if err != nil { + return + } + + if len(result.Entries) == 0 { + return "", errors.New("LDAP user not found") + } + + dn = result.Entries[0].DN + return +} + +// canonicalizeDN turns the `dn` provided in database to the `dn` recognized by LDAP server +// If the first byte of `dn` is `+`, it'll be converted into "${searchAttr}=${username},..." +// both `userName` and `dn` should be non-empty +func (impl *ldapAuthImpl) canonicalizeDN(userName string, dn string) string { + if dn[0] == '+' { + return fmt.Sprintf("%s=%s,%s", impl.searchAttr, userName, dn[1:]) + } + + return dn +} + +func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) { + retryCount := 0 + for { + conn, err := impl.ldapConnectionPool.Get() + if err != nil { + return nil, err + } + + // try to bind root user. It has two meanings: + // 1. Clear the state of previous binding, to avoid security leaks. (Though it's not serious, because even the current + // connection has binded to other users, the following authentication will still fail. But the ACL for root + // user and a valid user could be different, so it's better to bind back to root user here. + // 2. Detect whether this connection is still valid to use, in case the server has closed this connection. + ldapConnection := conn.(*ldap.Conn) + _, err = ldapConnection.SimpleBind(&ldap.SimpleBindRequest{ + Username: impl.bindRootDN, + Password: impl.bindRootPWD, + }) + if err != nil { + logutil.BgLogger().Warn("fail to use LDAP connection bind to anonymous user. Retrying", zap.Error(err), + zap.Duration("backoff", getConnectionRetryInterval)) + ldapConnection.Close() + + // fail to bind to anonymous user, just release this connection and try to get a new one + impl.ldapConnectionPool.Put(nil) + + retryCount++ + if retryCount >= getConnectionMaxRetry { + return nil, errors.Wrap(err, "fail to bind to anonymous user") + } + // Be careful that it's still holding the lock of the system variables, so it's not good to sleep here. + // TODO: refactor the `RWLock` to avoid the problem of holding the lock. + time.Sleep(getConnectionRetryInterval) + continue + } + + return conn.(*ldap.Conn), nil + } +} + +func (impl *ldapAuthImpl) putConnection(conn *ldap.Conn) { + impl.ldapConnectionPool.Put(conn) +} diff --git a/pkg/privilege/privileges/ldap/ldap_common_test.go b/pkg/privilege/privileges/ldap/ldap_common_test.go index 2d84982d7ffa0..6c5554eb13778 100644 --- a/pkg/privilege/privileges/ldap/ldap_common_test.go +++ b/pkg/privilege/privileges/ldap/ldap_common_test.go @@ -39,9 +39,10 @@ var tlsCrtStr []byte var tlsKeyStr []byte func TestCanonicalizeDN(t *testing.T) { - impl := &ldapAuthImpl{ + implBuilder := &ldapAuthImplBuilder{ searchAttr: "cn", } + impl := implBuilder.build() require.Equal(t, impl.canonicalizeDN("yka", "cn=y,dc=ping,dc=cap"), "cn=y,dc=ping,dc=cap") require.Equal(t, impl.canonicalizeDN("yka", "+dc=ping,dc=cap"), "cn=yka,dc=ping,dc=cap") } @@ -97,15 +98,17 @@ func TestConnectThrough636(t *testing.T) { serverWg.Wait() }() - impl := &ldapAuthImpl{} + impl := &ldapAuthImplBuilder{} impl.SetEnableTLS(true) impl.SetLDAPServerHost("localhost") impl.SetLDAPServerPort(randomTLSServicePort) impl.caPool = x509.NewCertPool() require.True(t, impl.caPool.AppendCertsFromPEM(tlsCAStr)) + impl.SetInitCapacity(1) + impl.SetMaxCapacity(1) - conn, err := impl.connectionFactory() + conn, err := impl.ldapConnectionPool.Get() require.NoError(t, err) defer conn.Close() } @@ -162,15 +165,17 @@ func TestConnectWithTLS11(t *testing.T) { serverWg.Wait() }() - impl := &ldapAuthImpl{} + impl := &ldapAuthImplBuilder{} impl.SetEnableTLS(true) impl.SetLDAPServerHost("localhost") impl.SetLDAPServerPort(randomTLSServicePort) impl.caPool = x509.NewCertPool() require.True(t, impl.caPool.AppendCertsFromPEM(tlsCAStr)) + impl.SetInitCapacity(1) + impl.SetMaxCapacity(1) - _, err := impl.connectionFactory() + _, err := impl.ldapConnectionPool.Get() require.ErrorContains(t, err, "protocol version not supported") } @@ -216,7 +221,7 @@ func TestLDAPStartTLSTimeout(t *testing.T) { serverWg.Wait() }() - impl := &ldapAuthImpl{} + impl := &ldapAuthImplBuilder{} impl.SetEnableTLS(true) impl.SetLDAPServerHost("localhost") impl.SetLDAPServerPort(randomTLSServicePort) @@ -227,10 +232,8 @@ func TestLDAPStartTLSTimeout(t *testing.T) { impl.SetMaxCapacity(1) now := time.Now() - _, err := impl.connectionFactory() + _, err := impl.build().getConnection() afterTimeout <- struct{}{} - dur := time.Since(now) - require.Greater(t, dur, 2*time.Second) - require.Less(t, dur, 3*time.Second) + require.Greater(t, time.Since(now), 2*time.Second) require.ErrorContains(t, err, "connection timed out") } diff --git a/pkg/privilege/privileges/ldap/sasl.go b/pkg/privilege/privileges/ldap/sasl.go index 0456ef6849eac..eea762cbe96a1 100644 --- a/pkg/privilege/privileges/ldap/sasl.go +++ b/pkg/privilege/privileges/ldap/sasl.go @@ -17,51 +17,36 @@ package ldap import ( "context" - "github.com/go-ldap/ldap/v3" "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/privilege/conn" ) type ldapSASLAuthImpl struct { - *ldapAuthImpl + *ldapAuthImplBuilder saslAuthMethod string } // AuthLDAPSASL authenticates the user through LDAP SASL func (impl *ldapSASLAuthImpl) AuthLDAPSASL(userName string, dn string, clientCred []byte, authConn conn.AuthConn) error { - dn, ldapConn, err := func() (string, *ldap.Conn, error) { - // It's fine to just `RLock` here, even we're fetching resources from the pool, because the resource pool can be - // accessed concurrently. The `RLock` here is to protect the configurations. - // - // It's a bad idea to lock through the whole function, as this function will write/read to interact with the client. - // If the client somehow died and don't send responds, this function will not end for a long time (until connection - // timeout) and this lock will not be released. Therefore, we only `RLock` the configurations in the - impl.RLock() - defer impl.RUnlock() - - var err error - if len(dn) == 0 { - dn, err = impl.searchUser(userName) - if err != nil { - return "", nil, err - } - } else { - dn = impl.canonicalizeDN(userName, dn) - } + var err error + ldapImpl := impl.ldapAuthImplBuilder.build() - ldapConn, err := impl.getConnection() + if len(dn) == 0 { + dn, err = ldapImpl.searchUser(userName) if err != nil { - return "", nil, errors.Wrap(err, "create LDAP connection") + return err } + } else { + dn = ldapImpl.canonicalizeDN(userName, dn) + } - return dn, ldapConn, nil - }() + ldapConn, err := ldapImpl.getConnection() if err != nil { - return err + return errors.Wrap(err, "create LDAP connection") } + defer ldapImpl.putConnection(ldapConn) - defer impl.putConnection(ldapConn) for { resultCode, serverCred, err := ldapConn.ServerBindStep(clientCred, dn, impl.saslAuthMethod, nil) if err != nil { @@ -114,6 +99,6 @@ func (impl *ldapSASLAuthImpl) GetSASLAuthMethod() string { // LDAPSASLAuthImpl is the implementation of authentication with LDAP SASL var LDAPSASLAuthImpl = &ldapSASLAuthImpl{ - &ldapAuthImpl{}, + &ldapAuthImplBuilder{}, "", } diff --git a/pkg/privilege/privileges/ldap/simple.go b/pkg/privilege/privileges/ldap/simple.go index d6dabc64eebda..a76d6bb8cbcad 100644 --- a/pkg/privilege/privileges/ldap/simple.go +++ b/pkg/privilege/privileges/ldap/simple.go @@ -19,14 +19,13 @@ import ( ) type ldapSimplAuthImpl struct { - *ldapAuthImpl + *ldapAuthImplBuilder } // AuthLDAPSimple authenticates the user through LDAP Simple Bind // password is expected to be a nul-terminated string func (impl *ldapSimplAuthImpl) AuthLDAPSimple(userName string, dn string, password []byte) error { - impl.RLock() - defer impl.RUnlock() + ldapImpl := impl.ldapAuthImplBuilder.build() if len(password) == 0 { return errors.New("invalid password") @@ -38,19 +37,19 @@ func (impl *ldapSimplAuthImpl) AuthLDAPSimple(userName string, dn string, passwo var err error if len(dn) == 0 { - dn, err = impl.searchUser(userName) + dn, err = ldapImpl.searchUser(userName) if err != nil { return err } } else { - dn = impl.canonicalizeDN(userName, dn) + dn = ldapImpl.canonicalizeDN(userName, dn) } - ldapConn, err := impl.getConnection() + ldapConn, err := ldapImpl.getConnection() if err != nil { return errors.Wrap(err, "create LDAP connection") } - defer impl.putConnection(ldapConn) + defer ldapImpl.putConnection(ldapConn) err = ldapConn.Bind(dn, passwordStr) if err != nil { return errors.Wrap(err, "bind LDAP") @@ -61,5 +60,5 @@ func (impl *ldapSimplAuthImpl) AuthLDAPSimple(userName string, dn string, passwo // LDAPSimpleAuthImpl is the implementation of authentication with LDAP clear text password var LDAPSimpleAuthImpl = &ldapSimplAuthImpl{ - &ldapAuthImpl{}, + &ldapAuthImplBuilder{}, } From ffa01ca557eb698ed3aff498007270134c85b75f Mon Sep 17 00:00:00 2001 From: Yang Keao Date: Tue, 2 Apr 2024 19:20:39 +0800 Subject: [PATCH 2/2] rename impl to builder Signed-off-by: Yang Keao --- pkg/privilege/privileges/ldap/ldap_common.go | 239 +++++++++---------- 1 file changed, 119 insertions(+), 120 deletions(-) diff --git a/pkg/privilege/privileges/ldap/ldap_common.go b/pkg/privilege/privileges/ldap/ldap_common.go index 0fc0ae8a47d05..c5907ca6964ff 100644 --- a/pkg/privilege/privileges/ldap/ldap_common.go +++ b/pkg/privilege/privileges/ldap/ldap_common.go @@ -66,32 +66,32 @@ type ldapAuthImplBuilder struct { ldapConnectionPool *pools.ResourcePool } -func (impl *ldapAuthImplBuilder) build() *ldapAuthImpl { - impl.RLock() - defer impl.RUnlock() +func (builder *ldapAuthImplBuilder) build() *ldapAuthImpl { + builder.RLock() + defer builder.RUnlock() return &ldapAuthImpl{ - bindBaseDN: impl.bindBaseDN, - bindRootDN: impl.bindRootDN, - bindRootPWD: impl.bindRootPWD, - searchAttr: impl.searchAttr, - ldapConnectionPool: impl.ldapConnectionPool, + bindBaseDN: builder.bindBaseDN, + bindRootDN: builder.bindRootDN, + bindRootPWD: builder.bindRootPWD, + searchAttr: builder.searchAttr, + ldapConnectionPool: builder.ldapConnectionPool, } } -func (impl *ldapAuthImplBuilder) initializeCAPool() error { - if impl.caPath == "" { - impl.caPool = nil +func (builder *ldapAuthImplBuilder) initializeCAPool() error { + if builder.caPath == "" { + builder.caPool = nil return nil } - impl.caPool = x509.NewCertPool() - caCert, err := os.ReadFile(impl.caPath) + builder.caPool = x509.NewCertPool() + caCert, err := os.ReadFile(builder.caPath) if err != nil { return errors.Wrapf(err, "read ca certificate at %s", caCert) } - ok := impl.caPool.AppendCertsFromPEM(caCert) + ok := builder.caPool.AppendCertsFromPEM(caCert) if !ok { return errors.New("fail to parse ca certificate") } @@ -162,207 +162,207 @@ func ldapConnectionFactory(address string, tlsConfig *tls.Config) func() (pools. } } -func (impl *ldapAuthImplBuilder) initializePool() { - if impl.ldapConnectionPool != nil { - impl.ldapConnectionPool.Close() - } +func (builder *ldapAuthImplBuilder) initializePool() { + // skip re-initialization when the variables are not correct + if builder.initCapacity > 0 && builder.maxCapacity >= builder.initCapacity { + if builder.ldapConnectionPool != nil { + builder.ldapConnectionPool.Close() + } - // skip initialization when the variables are not correct - if impl.initCapacity > 0 && impl.maxCapacity >= impl.initCapacity { - address := net.JoinHostPort(impl.ldapServerHost, strconv.FormatUint(uint64(impl.ldapServerPort), 10)) + address := net.JoinHostPort(builder.ldapServerHost, strconv.FormatUint(uint64(builder.ldapServerPort), 10)) var tlsConfig *tls.Config - if impl.enableTLS { + if builder.enableTLS { tlsConfig = &tls.Config{ - RootCAs: impl.caPool, - ServerName: impl.ldapServerHost, + RootCAs: builder.caPool, + ServerName: builder.ldapServerHost, MinVersion: tls.VersionTLS12, } } - impl.ldapConnectionPool = pools.NewResourcePool(ldapConnectionFactory(address, tlsConfig), impl.initCapacity, impl.maxCapacity, 0) + builder.ldapConnectionPool = pools.NewResourcePool(ldapConnectionFactory(address, tlsConfig), builder.initCapacity, builder.maxCapacity, 0) } } // SetBindBaseDN updates the BaseDN used to search the user -func (impl *ldapAuthImplBuilder) SetBindBaseDN(bindBaseDN string) { - impl.Lock() - defer impl.Unlock() +func (builder *ldapAuthImplBuilder) SetBindBaseDN(bindBaseDN string) { + builder.Lock() + defer builder.Unlock() - impl.bindBaseDN = bindBaseDN + builder.bindBaseDN = bindBaseDN } // SetBindRootDN updates the RootDN. Before searching the users, the connection will bind // this root user. -func (impl *ldapAuthImplBuilder) SetBindRootDN(bindRootDN string) { - impl.Lock() - defer impl.Unlock() +func (builder *ldapAuthImplBuilder) SetBindRootDN(bindRootDN string) { + builder.Lock() + defer builder.Unlock() - impl.bindRootDN = bindRootDN + builder.bindRootDN = bindRootDN } // SetBindRootPW updates the password of the user specified by `rootDN`. -func (impl *ldapAuthImplBuilder) SetBindRootPW(bindRootPW string) { - impl.Lock() - defer impl.Unlock() +func (builder *ldapAuthImplBuilder) SetBindRootPW(bindRootPW string) { + builder.Lock() + defer builder.Unlock() - impl.bindRootPWD = bindRootPW + builder.bindRootPWD = bindRootPW } // SetSearchAttr updates the search attributes. -func (impl *ldapAuthImplBuilder) SetSearchAttr(searchAttr string) { - impl.Lock() - defer impl.Unlock() +func (builder *ldapAuthImplBuilder) SetSearchAttr(searchAttr string) { + builder.Lock() + defer builder.Unlock() - impl.searchAttr = searchAttr + builder.searchAttr = searchAttr } // SetLDAPServerHost updates the host of LDAP server -func (impl *ldapAuthImplBuilder) SetLDAPServerHost(ldapServerHost string) { - impl.Lock() - defer impl.Unlock() +func (builder *ldapAuthImplBuilder) SetLDAPServerHost(ldapServerHost string) { + builder.Lock() + defer builder.Unlock() - if ldapServerHost != impl.ldapServerHost { - impl.ldapServerHost = ldapServerHost - impl.initializePool() + if ldapServerHost != builder.ldapServerHost { + builder.ldapServerHost = ldapServerHost + builder.initializePool() } } // SetLDAPServerPort updates the port of LDAP server -func (impl *ldapAuthImplBuilder) SetLDAPServerPort(ldapServerPort int) { - impl.Lock() - defer impl.Unlock() +func (builder *ldapAuthImplBuilder) SetLDAPServerPort(ldapServerPort int) { + builder.Lock() + defer builder.Unlock() - if ldapServerPort != impl.ldapServerPort { - impl.ldapServerPort = ldapServerPort - impl.initializePool() + if ldapServerPort != builder.ldapServerPort { + builder.ldapServerPort = ldapServerPort + builder.initializePool() } } // SetEnableTLS sets whether to enable StartTLS for LDAP connection -func (impl *ldapAuthImplBuilder) SetEnableTLS(enableTLS bool) { - impl.Lock() - defer impl.Unlock() +func (builder *ldapAuthImplBuilder) SetEnableTLS(enableTLS bool) { + builder.Lock() + defer builder.Unlock() - if enableTLS != impl.enableTLS { - impl.enableTLS = enableTLS - impl.initializePool() + if enableTLS != builder.enableTLS { + builder.enableTLS = enableTLS + builder.initializePool() } } // SetCAPath sets the path of CA certificate used to connect to LDAP server -func (impl *ldapAuthImplBuilder) SetCAPath(path string) error { - impl.Lock() - defer impl.Unlock() +func (builder *ldapAuthImplBuilder) SetCAPath(path string) error { + builder.Lock() + defer builder.Unlock() - if path != impl.caPath { - impl.caPath = path - err := impl.initializeCAPool() + if path != builder.caPath { + builder.caPath = path + err := builder.initializeCAPool() if err != nil { return err } - impl.initializePool() + builder.initializePool() } return nil } -func (impl *ldapAuthImplBuilder) SetInitCapacity(initCapacity int) { - impl.Lock() - defer impl.Unlock() +func (builder *ldapAuthImplBuilder) SetInitCapacity(initCapacity int) { + builder.Lock() + defer builder.Unlock() - if initCapacity != impl.initCapacity { - impl.initCapacity = initCapacity - impl.initializePool() + if initCapacity != builder.initCapacity { + builder.initCapacity = initCapacity + builder.initializePool() } } -func (impl *ldapAuthImplBuilder) SetMaxCapacity(maxCapacity int) { - impl.Lock() - defer impl.Unlock() +func (builder *ldapAuthImplBuilder) SetMaxCapacity(maxCapacity int) { + builder.Lock() + defer builder.Unlock() - if maxCapacity != impl.maxCapacity { - impl.maxCapacity = maxCapacity - impl.initializePool() + if maxCapacity != builder.maxCapacity { + builder.maxCapacity = maxCapacity + builder.initializePool() } } // GetBindBaseDN returns the BaseDN used to search the user -func (impl *ldapAuthImplBuilder) GetBindBaseDN() string { - impl.RLock() - defer impl.RUnlock() +func (builder *ldapAuthImplBuilder) GetBindBaseDN() string { + builder.RLock() + defer builder.RUnlock() - return impl.bindBaseDN + return builder.bindBaseDN } // GetBindRootDN returns the RootDN. Before searching the users, the connection will bind // this root user. -func (impl *ldapAuthImplBuilder) GetBindRootDN() string { - impl.RLock() - defer impl.RUnlock() +func (builder *ldapAuthImplBuilder) GetBindRootDN() string { + builder.RLock() + defer builder.RUnlock() - return impl.bindRootDN + return builder.bindRootDN } // GetBindRootPW returns the password of the user specified by `rootDN`. -func (impl *ldapAuthImplBuilder) GetBindRootPW() string { - impl.RLock() - defer impl.RUnlock() +func (builder *ldapAuthImplBuilder) GetBindRootPW() string { + builder.RLock() + defer builder.RUnlock() - return impl.bindRootPWD + return builder.bindRootPWD } // GetSearchAttr returns the search attributes. -func (impl *ldapAuthImplBuilder) GetSearchAttr() string { - impl.RLock() - defer impl.RUnlock() +func (builder *ldapAuthImplBuilder) GetSearchAttr() string { + builder.RLock() + defer builder.RUnlock() - return impl.searchAttr + return builder.searchAttr } // GetLDAPServerHost returns the host of LDAP server -func (impl *ldapAuthImplBuilder) GetLDAPServerHost() string { - impl.RLock() - defer impl.RUnlock() +func (builder *ldapAuthImplBuilder) GetLDAPServerHost() string { + builder.RLock() + defer builder.RUnlock() - return impl.ldapServerHost + return builder.ldapServerHost } // GetLDAPServerPort returns the port of LDAP server -func (impl *ldapAuthImplBuilder) GetLDAPServerPort() int { - impl.RLock() - defer impl.RUnlock() +func (builder *ldapAuthImplBuilder) GetLDAPServerPort() int { + builder.RLock() + defer builder.RUnlock() - return impl.ldapServerPort + return builder.ldapServerPort } // GetEnableTLS sets whether to enable StartTLS for LDAP connection -func (impl *ldapAuthImplBuilder) GetEnableTLS() bool { - impl.RLock() - defer impl.RUnlock() +func (builder *ldapAuthImplBuilder) GetEnableTLS() bool { + builder.RLock() + defer builder.RUnlock() - return impl.enableTLS + return builder.enableTLS } // GetCAPath returns the path of CA certificate used to connect to LDAP server -func (impl *ldapAuthImplBuilder) GetCAPath() string { - impl.RLock() - defer impl.RUnlock() +func (builder *ldapAuthImplBuilder) GetCAPath() string { + builder.RLock() + defer builder.RUnlock() - return impl.caPath + return builder.caPath } -func (impl *ldapAuthImplBuilder) GetInitCapacity() int { - impl.RLock() - defer impl.RUnlock() +func (builder *ldapAuthImplBuilder) GetInitCapacity() int { + builder.RLock() + defer builder.RUnlock() - return impl.initCapacity + return builder.initCapacity } -func (impl *ldapAuthImplBuilder) GetMaxCapacity() int { - impl.RLock() - defer impl.RUnlock() +func (builder *ldapAuthImplBuilder) GetMaxCapacity() int { + builder.RLock() + defer builder.RUnlock() - return impl.maxCapacity + return builder.maxCapacity } // ldapAuthImpl gives the internal utilities of authentication with LDAP. @@ -448,8 +448,7 @@ func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) { if retryCount >= getConnectionMaxRetry { return nil, errors.Wrap(err, "fail to bind to anonymous user") } - // Be careful that it's still holding the lock of the system variables, so it's not good to sleep here. - // TODO: refactor the `RWLock` to avoid the problem of holding the lock. + time.Sleep(getConnectionRetryInterval) continue }