diff --git a/dial.go b/dial.go index c10aef61b..b2ac8338c 100644 --- a/dial.go +++ b/dial.go @@ -21,20 +21,15 @@ type DialConfig struct { // often around 3 minutes. DialTimeout time.Duration - // KeepAlive specifies the interval between keep-alive - // probes for an active network connection. - // If zero, keep-alive probes are sent with a default value - // (currently 15 seconds), if supported by the protocol and operating - // system. Network protocols or operating systems that do - // not support keep-alives ignore this field. - // If negative, keep-alive probes are disabled. - KeepAlive time.Duration + // KeepAlive enables TCP keep-alive probes for an active network connection. + // The keep-alive probes are sent with OS specific intervals. + KeepAlive bool } func DefaultDialConfig() *DialConfig { return &DialConfig{ DialTimeout: 10 * time.Second, - KeepAlive: 30 * time.Second, + KeepAlive: true, } } @@ -43,15 +38,21 @@ type Dialer struct { } func NewDialer(cfg *DialConfig) *Dialer { - return &Dialer{ - net.Dialer{ - Timeout: cfg.DialTimeout, - KeepAlive: cfg.KeepAlive, - Resolver: &net.Resolver{ - PreferGo: true, - }, + nd := net.Dialer{ + Timeout: cfg.DialTimeout, + KeepAlive: -1, + Resolver: &net.Resolver{ + PreferGo: true, }, } + + if cfg.KeepAlive { + enableTCPKeepAlive(&nd) + } + + return &Dialer{ + nd: nd, + } } func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { diff --git a/dial_unix.go b/dial_unix.go new file mode 100644 index 000000000..f99cf38ba --- /dev/null +++ b/dial_unix.go @@ -0,0 +1,27 @@ +// Copyright 2023 Sauce Labs Inc., all rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +//go:build unix + +package forwarder + +import ( + "fmt" + "net" + "os" + "syscall" +) + +func enableTCPKeepAlive(d *net.Dialer) { + d.KeepAlive = -1 + d.Control = func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, 1); err != nil { + fmt.Fprintf(os.Stderr, "failed to set SO_KEEPALIVE: %v\n", err) + } + }) + } +} diff --git a/dial_windows.go b/dial_windows.go new file mode 100644 index 000000000..b39b55eb0 --- /dev/null +++ b/dial_windows.go @@ -0,0 +1,31 @@ +// Copyright 2023 Sauce Labs Inc., all rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +//go:build windows + +package forwarder + +import ( + "fmt" + "net" + "os" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +func enableTCPKeepAlive(d *net.Dialer) { + d.KeepAlive = -1 + d.Control = func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + optval := (*byte)(unsafe.Pointer(&[4]byte{1})) + if err := windows.Setsockopt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_KEEPALIVE, optval, 1); err != nil { + fmt.Fprintf(os.Stderr, "failed to set SO_KEEPALIVE: %v\n", err) + } + }) + } +}