Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: extract listener config #948

Merged
merged 2 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions bind/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,6 @@ func HTTPProxyConfig(fs *pflag.FlagSet, cfg *forwarder.HTTPProxyConfig, lcfg *lo
"<name>"+
"If the header is present in the request, "+
"the proxy will associate the value with the request in the logs. ")

fs.Var(&cfg.ReadLimit, "read-limit", "<bandwidth>"+
"Global read rate limit in bytes per second i.e. how many bytes per second you can receive from a proxy. "+
"Accepts binary format (e.g. 1.5Ki, 1Mi, 3.6Gi). ")

fs.Var(&cfg.WriteLimit, "write-limit", "<bandwidth>"+
"Global write rate limit in bytes per second i.e. how many bytes per second you can send to proxy. "+
"Accepts binary format (e.g. 1.5Ki, 1Mi, 3.6Gi). ")
}

func DenyDomains(fs *pflag.FlagSet, cfg *[]ruleset.RegexpListItem) {
Expand Down Expand Up @@ -283,16 +275,13 @@ func TLSClientConfig(fs *pflag.FlagSet, cfg *forwarder.TLSClientConfig) {
}

func HTTPServerConfig(fs *pflag.FlagSet, cfg *forwarder.HTTPServerConfig, prefix string, schemes ...forwarder.Scheme) {
ListenerConfig(fs, &cfg.ListenerConfig, prefix)

namePrefix := prefix
if namePrefix != "" {
namePrefix += "-"
}

fs.StringVarP(&cfg.Addr,
namePrefix+"address", "", cfg.Addr, "<host:port>"+
"The server address to listen on. "+
"If the host is empty, the server will listen on all available interfaces. ")

if schemes == nil {
schemes = []forwarder.Scheme{
forwarder.HTTPScheme,
Expand Down Expand Up @@ -335,6 +324,26 @@ func HTTPServerConfig(fs *pflag.FlagSet, cfg *forwarder.HTTPServerConfig, prefix
"Basic authentication credentials to protect the server. ")
}

func ListenerConfig(fs *pflag.FlagSet, cfg *forwarder.ListenerConfig, prefix string) {
namePrefix := prefix
if namePrefix != "" {
namePrefix += "-"
}

fs.StringVarP(&cfg.Address,
namePrefix+"address", "", cfg.Address, "<host:port>"+
"The server address to listen on. "+
"If the host is empty, the server will listen on all available interfaces. ")

fs.Var(&cfg.ReadLimit, namePrefix+"read-limit", "<bandwidth>"+
"Global read rate limit in bytes per second i.e. how many bytes per second you can receive from a proxy. "+
"Accepts binary format (e.g. 1.5Ki, 1Mi, 3.6Gi). ")

fs.Var(&cfg.WriteLimit, namePrefix+"write-limit", "<bandwidth>"+
"Global write rate limit in bytes per second i.e. how many bytes per second you can send to proxy. "+
"Accepts binary format (e.g. 1.5Ki, 1Mi, 3.6Gi). ")
}

func HTTPLogConfig(fs *pflag.FlagSet, cfg []NamedParam[httplog.Mode]) {
for _, p := range cfg {
if p.Param == nil {
Expand Down
4 changes: 2 additions & 2 deletions command/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func (c *command) runE(cmd *cobra.Command, _ []string) (cmdErr error) {
})
}

if c.apiServerConfig.Addr != "" {
if c.apiServerConfig.Address != "" {
a, err := forwarder.NewHTTPServer(c.apiServerConfig, h, logger.Named("api"))
if err != nil {
return err
Expand Down Expand Up @@ -492,7 +492,7 @@ func makeCommand() command {
c.httpTransportConfig.PromNamespace = promNs
c.httpProxyConfig.PromRegistry = c.promReg
c.httpProxyConfig.PromNamespace = promNs
c.apiServerConfig.Addr = "localhost:10000"
c.apiServerConfig.Address = "localhost:10000"

return c
}
Expand Down
41 changes: 17 additions & 24 deletions http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,23 +80,20 @@ var ErrConnectFallback = martian.ErrConnectFallback

type HTTPProxyConfig struct {
HTTPServerConfig
Name string
MITM *MITMConfig
MITMDomains Matcher
ProxyLocalhost ProxyLocalhostMode
UpstreamProxy *url.URL
UpstreamProxyFunc ProxyFunc
DenyDomains Matcher
DirectDomains Matcher
RequestIDHeader string
RequestModifiers []RequestModifier
ResponseModifiers []ResponseModifier
ConnectFunc ConnectFunc
ConnectTimeout time.Duration
ProxyProtocolConfig *ProxyProtocolConfig
ReadLimit SizeSuffix
WriteLimit SizeSuffix
PromHTTPOpts []middleware.PrometheusOpt
Name string
MITM *MITMConfig
MITMDomains Matcher
ProxyLocalhost ProxyLocalhostMode
UpstreamProxy *url.URL
UpstreamProxyFunc ProxyFunc
DenyDomains Matcher
DirectDomains Matcher
RequestIDHeader string
RequestModifiers []RequestModifier
ResponseModifiers []ResponseModifier
ConnectFunc ConnectFunc
ConnectTimeout time.Duration
PromHTTPOpts []middleware.PrometheusOpt

// TestingHTTPHandler uses Martian's [http.Handler] implementation
// over [http.Server] instead of the default TCP server.
Expand All @@ -106,8 +103,8 @@ type HTTPProxyConfig struct {
func DefaultHTTPProxyConfig() *HTTPProxyConfig {
return &HTTPProxyConfig{
HTTPServerConfig: HTTPServerConfig{
ListenerConfig: *DefaultListenerConfig(":3128"),
Protocol: HTTPScheme,
Addr: ":3128",
IdleTimeout: 1 * time.Hour,
ReadHeaderTimeout: 1 * time.Minute,
TLSServerConfig: TLSServerConfig{
Expand Down Expand Up @@ -554,12 +551,8 @@ func (hp *HTTPProxy) listen() (net.Listener, error) {
}

l := Listener{
Address: hp.config.Addr,
Log: hp.log,
ProxyProtocolConfig: hp.config.ProxyProtocolConfig,
TLSConfig: hp.tlsConfig,
ReadLimit: int64(hp.config.ReadLimit),
WriteLimit: int64(hp.config.WriteLimit),
ListenerConfig: hp.config.ListenerConfig,
TLSConfig: hp.tlsConfig,
PromConfig: PromConfig{
PromNamespace: hp.config.PromNamespace,
PromRegistry: hp.config.PromRegistry,
Expand Down
14 changes: 8 additions & 6 deletions http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ func h2TLSConfigTemplate() *tls.Config {
}

type HTTPServerConfig struct {
ListenerConfig
Protocol Scheme
Addr string
TLSServerConfig
IdleTimeout time.Duration
ReadTimeout time.Duration
Expand All @@ -87,8 +87,8 @@ type HTTPServerConfig struct {

func DefaultHTTPServerConfig() *HTTPServerConfig {
return &HTTPServerConfig{
ListenerConfig: *DefaultListenerConfig(":8080"),
Protocol: HTTPScheme,
Addr: ":8080",
IdleTimeout: 1 * time.Hour,
ReadHeaderTimeout: 1 * time.Minute,
}
Expand Down Expand Up @@ -119,7 +119,6 @@ func NewHTTPServer(cfg *HTTPServerConfig, h http.Handler, log log.Logger) (*HTTP
config: *cfg,
log: log,
srv: &http.Server{
Addr: cfg.Addr,
Handler: withMiddleware(cfg, log, h),
IdleTimeout: cfg.IdleTimeout,
ReadTimeout: cfg.ReadTimeout,
Expand Down Expand Up @@ -232,11 +231,14 @@ func (hs *HTTPServer) Run(ctx context.Context) error {
func (hs *HTTPServer) listen() (net.Listener, error) {
switch hs.config.Protocol {
case HTTPScheme, HTTPSScheme, HTTP2Scheme:
listener, err := Listen("tcp", hs.srv.Addr)
if err != nil {
l := Listener{
ListenerConfig: hs.config.ListenerConfig,
PromConfig: hs.config.PromConfig,
}
if err := l.Listen(); err != nil {
return nil, fmt.Errorf("failed to open listener on address %s: %w", hs.srv.Addr, err)
}
return listener, nil
return &l, nil
default:
return nil, fmt.Errorf("invalid protocol %q", hs.config.Protocol)
}
Expand Down
22 changes: 15 additions & 7 deletions net.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"time"

"github.com/saucelabs/forwarder/conntrack"
"github.com/saucelabs/forwarder/log"
"github.com/saucelabs/forwarder/proxyproto"
"github.com/saucelabs/forwarder/ratelimit"
)
Expand Down Expand Up @@ -178,14 +177,23 @@ func DefaultProxyProtocolConfig() *ProxyProtocolConfig {
}
}

type Listener struct {
type ListenerConfig struct {
Address string
Log log.Logger
TLSConfig *tls.Config
ProxyProtocolConfig *ProxyProtocolConfig
ReadLimit int64
WriteLimit int64
ReadLimit SizeSuffix
WriteLimit SizeSuffix
TrackTraffic bool
}

func DefaultListenerConfig(addr string) *ListenerConfig {
return &ListenerConfig{
Address: addr,
}
}

type Listener struct {
ListenerConfig
TLSConfig *tls.Config
PromConfig

listener net.Listener
Expand All @@ -210,7 +218,7 @@ func (l *Listener) Listen() error {
}

if rl, wl := l.ReadLimit, l.WriteLimit; rl > 0 || wl > 0 {
ll = ratelimit.NewListener(ll, rl, wl)
ll = ratelimit.NewListener(ll, int64(rl), int64(wl))
}

l.listener = ll
Expand Down
28 changes: 12 additions & 16 deletions net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@ import (

"github.com/prometheus/client_golang/prometheus"
"github.com/saucelabs/forwarder/conntrack"
"github.com/saucelabs/forwarder/log"
"github.com/saucelabs/forwarder/utils/certutil"
"github.com/saucelabs/forwarder/utils/golden"
)

var testListenerConfig = ListenerConfig{
Address: "localhost:0",
}

func TestDialRedirectFromHostPortPairs(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -91,8 +94,7 @@ func TestDialRedirectFromHostPortPairs(t *testing.T) {

func TestDialerRedirect(t *testing.T) {
l := Listener{
Address: "localhost:0",
Log: log.NopLogger,
ListenerConfig: testListenerConfig,
}
defer l.Close()

Expand Down Expand Up @@ -145,8 +147,7 @@ func TestDialerMetrics(t *testing.T) {
}

l := Listener{
Address: "localhost:0",
Log: log.NopLogger,
ListenerConfig: testListenerConfig,
}
defer l.Close()

Expand Down Expand Up @@ -248,8 +249,7 @@ func (l *Listener) acceptAndCopy() {

func TestListenerListenOnce(t *testing.T) {
l := Listener{
Address: "localhost:0",
Log: log.NopLogger,
ListenerConfig: testListenerConfig,
}
defer l.Close()

Expand All @@ -263,8 +263,7 @@ func TestListenerListenOnce(t *testing.T) {
func TestListenerMetricsAccepted(t *testing.T) {
r := prometheus.NewRegistry()
l := Listener{
Address: "localhost:0",
Log: log.NopLogger,
ListenerConfig: testListenerConfig,
PromConfig: PromConfig{
PromNamespace: "test",
PromRegistry: r,
Expand Down Expand Up @@ -293,9 +292,8 @@ func TestListenerMetricsAccepted(t *testing.T) {
func TestListenerMetricsAcceptedWithTLS(t *testing.T) {
r := prometheus.NewRegistry()
l := Listener{
Address: "localhost:0",
Log: log.NopLogger,
TLSConfig: selfSingedCert(),
ListenerConfig: testListenerConfig,
TLSConfig: selfSingedCert(),
PromConfig: PromConfig{
PromNamespace: "test",
PromRegistry: r,
Expand Down Expand Up @@ -325,8 +323,7 @@ func TestListenerMetricsAcceptedWithTLS(t *testing.T) {
func TestListenerMetricsClosed(t *testing.T) {
r := prometheus.NewRegistry()
l := Listener{
Address: "localhost:0",
Log: log.NopLogger,
ListenerConfig: testListenerConfig,
PromConfig: PromConfig{
PromNamespace: "test",
PromRegistry: r,
Expand Down Expand Up @@ -364,8 +361,7 @@ func (l errListener) Accept() (net.Conn, error) {
func TestListenerMetricsErrors(t *testing.T) {
r := prometheus.NewRegistry()
l := Listener{
Address: "localhost:0",
Log: log.NopLogger,
ListenerConfig: testListenerConfig,
PromConfig: PromConfig{
PromNamespace: "test",
PromRegistry: r,
Expand Down
4 changes: 0 additions & 4 deletions proxyproto/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ type Conn struct {
headerErr error
}

func (c *Conn) NetConn() net.Conn {
return c.Conn
}

func (c *Conn) LocalAddr() net.Addr {
if err := c.readHeader(); err != nil {
return c.Conn.LocalAddr()
Expand Down