Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check connection unreachable host #142

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,50 @@ func TestReconnect(t *testing.T) {
}
}

func TestUnreachableHost(t *testing.T) {
hostList := poolAddress

timeoutConfig := PoolConfig{
TimeOut: 0 * time.Millisecond,
IdleTime: 0 * time.Millisecond,
MaxConnPoolSize: 6,
MinConnPoolSize: 0,
}

// Initialize connectin pool
pool, err := NewConnectionPool(hostList, timeoutConfig, nebulaLog)
if err != nil {
t.Fatalf("fail to initialize the connection pool, host: %s, port: %d, %s", address, port, err.Error())
}
defer pool.Close()

// simulate the host is unreachable
pool.addresses[0] = HostAddress{Host: "192.192.192.1", Port: 9669}

var sessionList []*Session
c := make(chan bool)
go func() {
// at least 6 seconds because of 2 * 3 timeout seconds.
for i := 0; i < 4; i++ {
session, err := pool.GetSession(username, password)
if err != nil {
t.Errorf("fail to create a new session from connection pool, %s", err.Error())
}
sessionList = append(sessionList, session)
}
for _, session := range sessionList {
session.Release()
}
c <- true
}()
select {
case <-c:
return
case <-time.After(15 * time.Second):
t.Fatal("could not catch the error")
}
}

func TestIpLookup(t *testing.T) {
hostAddress := HostAddress{Host: "192.168.10.105", Port: 3699}
hostList := []HostAddress{hostAddress}
Expand Down
24 changes: 23 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"crypto/tls"
"fmt"
"math"
"net"
"strconv"
"time"

"github.com/facebook/fbthrift/thrift/lib/go/thrift"
Expand All @@ -33,11 +35,31 @@ func newConnection(severAddress HostAddress) *connection {
}
}

func (cn *connection) check(hostAddress HostAddress, timeout time.Duration) error {
var defaultTimeout time.Duration = 3
if timeout == 0 || timeout > defaultTimeout {
timeout = defaultTimeout
}

host, port := hostAddress.Host, hostAddress.Port
addr := net.JoinHostPort(host, strconv.Itoa(port))
conn, err := net.DialTimeout("tcp", addr, timeout*time.Second)
defer func() {
if conn != nil {
conn.Close()
}
}()
return err
}

func (cn *connection) open(hostAddress HostAddress, timeout time.Duration) error {
return cn.openSSL(hostAddress, timeout, nil)
}
Comment on lines 55 to 57
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check open() as well as openSSL()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

close this pr, do not check it in open method, as this would open two times.
will check it in the initialization.


func (cn *connection) openSSL(hostAddress HostAddress, timeout time.Duration, sslConfig *tls.Config) error {
if err := cn.check(hostAddress, timeout); err != nil {
return fmt.Errorf("failed to open transport, error: %s", err.Error())
}
ip := hostAddress.Host
port := hostAddress.Port
newAdd := fmt.Sprintf("%s:%d", ip, port)
Expand Down Expand Up @@ -77,7 +99,7 @@ func (cn *connection) verifyClientVersion() error {
return fmt.Errorf("failed to verify client version: %s", err.Error())
}
if resp.GetErrorCode() != nebula.ErrorCode_SUCCEEDED {
return fmt.Errorf("incompatible version between client and server: %s.", string(resp.GetErrorMsg()))
return fmt.Errorf("incompatible version between client and server: %s", string(resp.GetErrorMsg()))
}
return nil
}
Expand Down