diff --git a/backend/pool.go b/backend/pool.go index 758e33b..dcb4285 100644 --- a/backend/pool.go +++ b/backend/pool.go @@ -1,6 +1,7 @@ package backend import ( + "errors" "net" "time" @@ -8,21 +9,32 @@ import ( ) const maxConn int = 10 +const maxOverflow int = 10 +const maxConnWait time.Duration = 10 * time.Millisecond + +// Errors +var ErrTimeout = errors.New("timeout waiting to build connection") type Pool struct { host string connections chan (net.Conn) createsem chan (bool) + mkConn func(host string) (net.Conn, error) } func NewPool(host string) *Pool { return &Pool{ host: host, connections: make(chan (net.Conn), maxConn), - createsem: make(chan (bool), 1), + createsem: make(chan (bool), maxConn+maxOverflow), + mkConn: defaultMkConn, } } +func defaultMkConn(host string) (net.Conn, error) { + return net.Dial("tcp", host) +} + func prepareConnection(conn net.Conn) (net.Conn, error) { if err := conn.SetWriteDeadline(time.Now().Add(60 * time.Second)); err != nil { return nil, err @@ -48,28 +60,42 @@ func (cp *Pool) Get() (net.Conn, error) { case cp.createsem <- true: // Room to make a connection log.Debugf("About to connect") - conn, err := net.Dial("tcp", cp.host) + conn, err := cp.mkConn(cp.host) if err != nil { // On error, release our create hold - <-cp.createsem + cp.release(conn) return nil, err } - return prepareConnection(conn) + conn, err = prepareConnection(conn) + if err != nil { + // On error, release our create hold + cp.release(conn) + return nil, err + } + return conn, err + case <-time.After(maxConnWait): + log.Debugf("Max connection exceeded") + return nil, ErrTimeout } } } -func (cp *Pool) Return(c net.Conn, failed bool) { +func (cp *Pool) release(conn net.Conn) { + <-cp.createsem + if conn != nil { + conn.Close() + } +} +func (cp *Pool) Return(conn net.Conn, failed bool) { if failed { - <-cp.createsem + cp.release(conn) return } select { - case cp.connections <- c: + case cp.connections <- conn: default: // Overflow connection. - <-cp.createsem - c.Close() + cp.release(conn) } } diff --git a/backend/pool_test.go b/backend/pool_test.go new file mode 100644 index 0000000..a6e0c21 --- /dev/null +++ b/backend/pool_test.go @@ -0,0 +1,154 @@ +package backend + +import ( + "errors" + "net" + "testing" + "time" +) + +// Errors +var ErrTestConnectionCreation = errors.New("connection creation error") +var ErrTestClose = errors.New("close error") +var ErrTestSetWriteDeadline = errors.New("set write deadline error") + +type TestConn struct { + failOnSetWriteDeadline bool + failOnClose bool +} + +func (t TestConn) Read(b []byte) (n int, err error) { + return 0, nil +} + +func (t TestConn) Write(b []byte) (n int, err error) { + return 0, nil +} + +func (t TestConn) Close() error { + if t.failOnClose { + return ErrTestClose + } + return nil +} + +func (t TestConn) LocalAddr() net.Addr { + return nil +} + +func (t TestConn) RemoteAddr() net.Addr { + return nil +} + +func (t TestConn) SetDeadline(ti time.Time) error { + return nil +} + +func (t TestConn) SetReadDeadline(ti time.Time) error { + return nil +} + +func (t TestConn) SetWriteDeadline(ti time.Time) error { + if t.failOnSetWriteDeadline { + return ErrTestSetWriteDeadline + } + return nil +} + +func testMkGoodConn(host string) (net.Conn, error) { + return &TestConn{failOnSetWriteDeadline: false, failOnClose: false}, nil +} + +func testMkConnSetDeadlineFailure(host string) (net.Conn, error) { + return &TestConn{failOnSetWriteDeadline: true, failOnClose: false}, nil +} + +func testMkConnCloseFailure(host string) (net.Conn, error) { + return &TestConn{failOnSetWriteDeadline: false, failOnClose: true}, nil +} + +func testMkConnFailure(host string) (net.Conn, error) { + return nil, ErrTestConnectionCreation +} +func TestConnPool(t *testing.T) { + cp := NewPool("somehost") + cp.mkConn = testMkGoodConn + seenConns := map[net.Conn]bool{} + + // able to get upto maxconn+maxoverflow + for i := 0; i < maxConn+maxOverflow; i++ { + sc, err := cp.Get() + if err != nil { + t.Fatalf("Error getting connection from pool: %v", err) + } + seenConns[sc] = true + } + // connection pool should be empty now and overflow should be maxxed out + assertConnPoolState(cp, t, 0, maxConn+maxOverflow) + + // trying to get more connection should fail + _, err := cp.Get() + if ErrTimeout != err { + t.Errorf("Expected %v but got %v", ErrTimeout, err) + } + assertConnPoolState(cp, t, 0, maxConn+maxOverflow) + + // releasing all acquired connections should fill up the connection pool + for k := range seenConns { + cp.Return(k, false) + } + assertConnPoolState(cp, t, maxConn, maxOverflow) + + // connections should now be reused + reusedConn, err := cp.Get() + if err != nil { + t.Fatalf("Error getting connection from pool: %v", err) + } + if _, exists := seenConns[reusedConn]; !exists { + t.Fatalf("Was expecting connection reuse") + } + assertConnPoolState(cp, t, maxConn-1, maxOverflow) +} + +func assertConnPoolState(cp *Pool, t *testing.T, expectedPoolCount int, expectedSemCount int) { + if (len(cp.connections) != expectedPoolCount) || (len(cp.createsem) != expectedSemCount) { + t.Fatalf("expected %v connections in the pool and %v as the semaphoreCount, but got %v and %v respectively", + expectedPoolCount, expectedSemCount, len(cp.connections), len(cp.createsem)) + } +} + +func assertErrorType(t *testing.T, expectedError error, gotError error) { + if expectedError != gotError { + t.Fatalf("was expecting %v but got %v", expectedError, gotError) + } +} + +func TestConnPoolFailures(t *testing.T) { + cp := NewPool("somehost") + + cp.mkConn = testMkConnFailure + _, err := cp.Get() + assertErrorType(t, ErrTestConnectionCreation, err) + assertConnPoolState(cp, t, 0, 0) + + cp.mkConn = testMkConnSetDeadlineFailure + _, err = cp.Get() + assertErrorType(t, ErrTestSetWriteDeadline, err) + assertConnPoolState(cp, t, 0, 0) + + cp.mkConn = testMkConnCloseFailure + conn, err := cp.Get() + assertErrorType(t, nil, err) + assertConnPoolState(cp, t, 0, 1) + cp.Return(conn, false) + assertConnPoolState(cp, t, 1, 1) + + // + cp.mkConn = testMkGoodConn + conn, err = cp.Get() + assertErrorType(t, nil, nil) + assertConnPoolState(cp, t, 0, 1) + cp.Return(conn, true) + assertConnPoolState(cp, t, 0, 0) + +}