Skip to content

Commit

Permalink
[v8] Make TestLimiter less flaky (#13526) (#14372)
Browse files Browse the repository at this point in the history
Make TestLimiter less flaky
  • Loading branch information
EdwardDowling authored Jul 21, 2022
1 parent e75724b commit 62f9bef
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 11 deletions.
11 changes: 11 additions & 0 deletions lib/limiter/connlimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,14 @@ func (l *ConnectionsLimiter) ReleaseConnection(token string) {
}
}
}

// GetNumConnections returns the current number of connections for a token
func (l *ConnectionsLimiter) GetNumConnection(token string) (int64, error) {
l.Lock()
defer l.Unlock()
numberOfConnections, exists := l.connections[token]
if !exists {
return -1, trace.Errorf("Trying to get connections of a nonexistent token: %s", token)
}
return numberOfConnections, nil
}
59 changes: 48 additions & 11 deletions lib/srv/regular/sshserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import (
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/sshutils/x11"
"github.com/gravitational/teleport/lib/utils"
"github.com/mailgun/timetools"

"github.com/gravitational/trace"

Expand Down Expand Up @@ -1403,13 +1404,46 @@ func TestClientDisconnect(t *testing.T) {
require.NoError(t, clt.Close())
}

func getWaitForNumberOfConnsFunc(t *testing.T, limiter *limiter.Limiter, token string, num int64) func() bool {
return func() bool {
connNumber, err := limiter.GetNumConnection(token)
require.NoError(t, err)
return connNumber == num
}
}

// fakeClock is a wrapper around clockwork.FakeClock that satisfies the timetoools.TimeProvider interface.
// We are wrapping this so we can use the same mocked clock across the server and rate limiter.
type fakeClock struct {
clock clockwork.FakeClock
}

func (fc fakeClock) UtcNow() time.Time {
return fc.clock.Now().UTC()
}

func (fc fakeClock) Sleep(d time.Duration) {
fc.clock.Advance(d)
}

func (fc fakeClock) After(d time.Duration) <-chan time.Time {
return fc.clock.After(d)
}

var _ timetools.TimeProvider = (*fakeClock)(nil)

func TestLimiter(t *testing.T) {
t.Parallel()
f := newFixture(t)
ctx := context.Background()

fClock := &fakeClock{
clock: f.clock,
}

limiter, err := limiter.NewLimiter(
limiter.Config{
Clock: fClock,
MaxConnections: 2,
Rates: []limiter.Rate{
{
Expand Down Expand Up @@ -1454,8 +1488,6 @@ func TestLimiter(t *testing.T) {
require.NoError(t, auth.CreateUploaderDir(nodeStateDir))
defer srv.Close()

// maxConnection = 3
// current connections = 1 (one connection is opened from SetUpTest)
config := &ssh.ClientConfig{
User: f.user,
Auth: []ssh.AuthMethod{ssh.PublicKeys(f.up.certSigner)},
Expand All @@ -1470,39 +1502,44 @@ func TestLimiter(t *testing.T) {
require.NoError(t, err)
require.NoError(t, se0.Shell())

// current connections = 2
// current connections = 1
clt, err := ssh.Dial("tcp", srv.Addr(), config)
require.NotNil(t, clt)
require.NoError(t, err)
require.NotNil(t, clt)
se, err := clt.NewSession()
require.NoError(t, err)
require.NoError(t, se.Shell())

// current connections = 3
// current connections = 2
_, err = ssh.Dial("tcp", srv.Addr(), config)
require.Error(t, err)

require.NoError(t, se.Close())
se.Wait()
require.NoError(t, clt.Close())
time.Sleep(50 * time.Millisecond)
require.ErrorIs(t, clt.Wait(), net.ErrClosed)

// current connections = 2
require.Eventually(t, getWaitForNumberOfConnsFunc(t, limiter, "127.0.0.1", 1), time.Second*10, time.Millisecond*100)

// current connections = 1
clt, err = ssh.Dial("tcp", srv.Addr(), config)
require.NotNil(t, clt)
require.NoError(t, err)
se, err = clt.NewSession()
require.NoError(t, err)
require.NoError(t, se.Shell())

// current connections = 3
// current connections = 2
_, err = ssh.Dial("tcp", srv.Addr(), config)
require.Error(t, err)

require.NoError(t, se.Close())
se.Wait()
require.NoError(t, clt.Close())
time.Sleep(50 * time.Millisecond)
require.ErrorIs(t, clt.Wait(), net.ErrClosed)

// current connections = 2
require.Eventually(t, getWaitForNumberOfConnsFunc(t, limiter, "127.0.0.1", 1), time.Second*10, time.Millisecond*100)

// current connections = 1
// requests rate should exceed now
clt, err = ssh.Dial("tcp", srv.Addr(), config)
require.NotNil(t, clt)
Expand Down

0 comments on commit 62f9bef

Please sign in to comment.