Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional bandwidth tracker #114

Merged
merged 5 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions bandwidth/nope_tracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package bandwidth

import (
"context"
"net"
)

type NopeTracker struct {
}

func (bt *NopeTracker) GetWriteBytes() int64 {
return 0
}

func (bt *NopeTracker) GetReadBytes() int64 {
return 0
}

func (bt *NopeTracker) GetTotalBandwidth() int64 {
return 0
}

func (bt *NopeTracker) TrackConnection(ctx context.Context, conn net.Conn) net.Conn {
return conn
}

func NewNopeTracker() *NopeTracker {
return &NopeTracker{}
}

var _ BandwidthTracker = (*NopeTracker)(nil)
29 changes: 29 additions & 0 deletions bandwidth/tracked_conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package bandwidth

import (
"net"
)

type BTConn struct {
net.Conn
tracker *Tracker
}

func (bt *BTConn) Read(p []byte) (n int, err error) {
n, err = bt.Conn.Read(p)
bt.tracker.addReadBytes(int64(n))
return n, err
}

func (bt *BTConn) Write(p []byte) (n int, err error) {
n, err = bt.Conn.Write(p)
bt.tracker.addWriteBytes(int64(n))
return n, err
}

func newTrackedConn(conn net.Conn, tracker *Tracker) *BTConn {
return &BTConn{
Conn: conn,
tracker: tracker,
}
}
49 changes: 49 additions & 0 deletions bandwidth/tracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package bandwidth

import (
"context"
"net"
"sync/atomic"
)

type BandwidthTracker interface {
GetTotalBandwidth() int64
GetWriteBytes() int64
GetReadBytes() int64
TrackConnection(ctx context.Context, conn net.Conn) net.Conn
}

type Tracker struct {
writeBytes atomic.Int64
readBytes atomic.Int64
}

func (bt *Tracker) GetWriteBytes() int64 {
return bt.writeBytes.Load()
}

func (bt *Tracker) GetReadBytes() int64 {
return bt.readBytes.Load()
}

func (bt *Tracker) GetTotalBandwidth() int64 {
return bt.readBytes.Load() + bt.writeBytes.Load()
}

func (bt *Tracker) TrackConnection(ctx context.Context, conn net.Conn) net.Conn {
return newTrackedConn(conn, bt)
}

func (bt *Tracker) addWriteBytes(n int64) {
bt.writeBytes.Add(n)
}

func (bt *Tracker) addReadBytes(n int64) {
bt.readBytes.Add(n)
}

func NewTracker() *Tracker {
return &Tracker{}
}

var _ BandwidthTracker = (*Tracker)(nil)
67 changes: 0 additions & 67 deletions bandwidth_tracker.go

This file was deleted.

19 changes: 12 additions & 7 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import (
"sync"
"time"

"github.com/bogdanfinn/tls-client/profiles"

http "github.com/bogdanfinn/fhttp"
"github.com/bogdanfinn/fhttp/httputil"
"github.com/bogdanfinn/tls-client/bandwidth"
"github.com/bogdanfinn/tls-client/profiles"
"golang.org/x/net/proxy"
)

Expand All @@ -36,7 +36,7 @@ type HttpClient interface {
Head(url string) (resp *http.Response, err error)
Post(url, contentType string, body io.Reader) (resp *http.Response, err error)

GetBandwidthTracker() BandwidthTracker
GetBandwidthTracker() bandwidth.BandwidthTracker
}

// Interface guards are a cheap way to make sure all methods are implemented, this is a static check and does not affect runtime performance.
Expand All @@ -48,7 +48,7 @@ type httpClient struct {
logger Logger
config *httpClientConfig

bandwidthTracker *bandwidthTracker
bandwidthTracker bandwidth.BandwidthTracker
}

var DefaultTimeoutSeconds = 30
Expand Down Expand Up @@ -117,7 +117,7 @@ func validateConfig(_ *httpClientConfig) error {
return nil
}

func buildFromConfig(logger Logger, config *httpClientConfig) (*http.Client, *bandwidthTracker, profiles.ClientProfile, error) {
func buildFromConfig(logger Logger, config *httpClientConfig) (*http.Client, bandwidth.BandwidthTracker, profiles.ClientProfile, error) {
var dialer proxy.ContextDialer
dialer = newDirectDialer(config.timeout, config.localAddr, config.dialer)

Expand All @@ -141,7 +141,12 @@ func buildFromConfig(logger Logger, config *httpClientConfig) (*http.Client, *ba
}
}

bandwidthTracker := newBandwidthTracker()
var bandwidthTracker bandwidth.BandwidthTracker
if config.enabledBandwidthTracker {
bandwidthTracker = bandwidth.NewTracker()
} else {
bandwidthTracker = bandwidth.NewNopeTracker()
}

clientProfile := config.clientProfile

Expand Down Expand Up @@ -281,7 +286,7 @@ func (c *httpClient) GetCookieJar() http.CookieJar {
}

// GetBandwidthTracker returns the bandwidth tracker
func (c *httpClient) GetBandwidthTracker() BandwidthTracker {
func (c *httpClient) GetBandwidthTracker() bandwidth.BandwidthTracker {
return c.bandwidthTracker
}

Expand Down
11 changes: 9 additions & 2 deletions client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ import (
"net"
"time"

"github.com/bogdanfinn/tls-client/profiles"

http "github.com/bogdanfinn/fhttp"
"github.com/bogdanfinn/tls-client/profiles"
)

type HttpClientOption func(config *httpClientConfig)
Expand Down Expand Up @@ -59,6 +58,8 @@ type httpClientConfig struct {
// Establish a connection to origin server via ipv4 only
disableIPV6 bool
dialer net.Dialer

enabledBandwidthTracker bool
}

// WithProxyUrl configures a HTTP client to use the specified proxy URL.
Expand Down Expand Up @@ -241,3 +242,9 @@ func WithDisableIPV6() HttpClientOption {
config.disableIPV6 = true
}
}

func WithBandwidthTracker() HttpClientOption {
return func(config *httpClientConfig) {
config.enabledBandwidthTracker = true
}
}
14 changes: 7 additions & 7 deletions roundtripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import (
"sync"
"time"

"github.com/bogdanfinn/tls-client/profiles"

http "github.com/bogdanfinn/fhttp"
"github.com/bogdanfinn/fhttp/http2"
"github.com/bogdanfinn/tls-client/bandwidth"
"github.com/bogdanfinn/tls-client/profiles"
tls "github.com/bogdanfinn/utls"
"golang.org/x/net/proxy"
)
Expand All @@ -36,7 +36,7 @@ type roundTripper struct {

forceHttp1 bool

bandwidthTracker *bandwidthTracker
bandwidthTracker bandwidth.BandwidthTracker

headerPriority *http2.PriorityParam
clientSessionCache tls.ClientSessionCache
Expand Down Expand Up @@ -100,7 +100,7 @@ func (rt *roundTripper) getTransport(req *http.Request, addr string) error {
return fmt.Errorf("invalid URL scheme: [%v]", req.URL.Scheme)
}

_, err := rt.dialTLS(context.Background(), "tcp", addr)
MashinaMashina marked this conversation as resolved.
Show resolved Hide resolved
_, err := rt.dialTLS(req.Context(), "tcp", addr)
switch err {
case errProtocolNegotiated:
case nil:
Expand Down Expand Up @@ -149,9 +149,9 @@ func (rt *roundTripper) dialTLS(ctx context.Context, network, addr string) (net.
tlsConfig.KeyLogWriter = rt.transportOptions.KeyLogWriter
}

trackedConn := newBandwidthTrackedConn(rawConn, rt.bandwidthTracker)
rawConn = rt.bandwidthTracker.TrackConnection(ctx, rawConn)

conn := tls.UClient(trackedConn, tlsConfig, rt.clientHelloId, rt.withRandomTlsExtensionOrder, rt.forceHttp1)
conn := tls.UClient(rawConn, tlsConfig, rt.clientHelloId, rt.withRandomTlsExtensionOrder, rt.forceHttp1)
if err = conn.HandshakeContext(ctx); err != nil {
_ = conn.Close()

Expand Down Expand Up @@ -311,7 +311,7 @@ func (rt *roundTripper) getDialTLSAddr(req *http.Request) string {
return net.JoinHostPort(req.URL.Host, "443")
}

func newRoundTripper(clientProfile profiles.ClientProfile, transportOptions *TransportOptions, serverNameOverwrite string, insecureSkipVerify bool, withRandomTlsExtensionOrder bool, forceHttp1 bool, certificatePins map[string][]string, badPinHandlerFunc BadPinHandlerFunc, disableIPV6 bool, bandwidthTracker *bandwidthTracker, dialer ...proxy.ContextDialer) (http.RoundTripper, error) {
func newRoundTripper(clientProfile profiles.ClientProfile, transportOptions *TransportOptions, serverNameOverwrite string, insecureSkipVerify bool, withRandomTlsExtensionOrder bool, forceHttp1 bool, certificatePins map[string][]string, badPinHandlerFunc BadPinHandlerFunc, disableIPV6 bool, bandwidthTracker bandwidth.BandwidthTracker, dialer ...proxy.ContextDialer) (http.RoundTripper, error) {
pinner, err := NewCertificatePinner(certificatePins)
if err != nil {
return nil, fmt.Errorf("can not instantiate certificate pinner: %w", err)
Expand Down