Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bandwidth tracker feature #113

Merged
merged 2 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions bandwidth_tracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package tls_client

import (
"net"
"sync/atomic"
)

type btConn struct {
net.Conn
tracker *bandwidthTracker
}

func (bt *btConn) Read(p []byte) (n int, err error) {
n, err = bt.Conn.Read(p)
bt.tracker.addReadBytes(int64(n))
return n, err
}

func (bt *btConn) Write(p []byte) (n int, err error) {
n, err = bt.Conn.Write(p)
bt.tracker.addWriteBytes(int64(n))
return n, err
}

func newBandwidthTrackedConn(conn net.Conn, tracker *bandwidthTracker) *btConn {
return &btConn{
Conn: conn,
tracker: tracker,
}
}

type BandwidthTracker interface {
GetTotalBandwidth() int64
GetWriteBytes() int64
GetReadBytes() int64
}

type bandwidthTracker struct {
writeBytes atomic.Int64
readBytes atomic.Int64
}

func (bt *bandwidthTracker) GetWriteBytes() int64 {
return bt.writeBytes.Load()
}

func (bt *bandwidthTracker) GetReadBytes() int64 {
return bt.readBytes.Load()
}

func (bt *bandwidthTracker) GetTotalBandwidth() int64 {
return bt.readBytes.Load() + bt.writeBytes.Load()
}

func (bt *bandwidthTracker) addWriteBytes(n int64) {
bt.writeBytes.Add(n)
}

func (bt *bandwidthTracker) addReadBytes(n int64) {
bt.readBytes.Add(n)
}

func newBandwidthTracker() *bandwidthTracker {
return &bandwidthTracker{}
}

var _ BandwidthTracker = (*bandwidthTracker)(nil)
34 changes: 23 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type HttpClient interface {
Get(url string) (resp *http.Response, err error)
Head(url string) (resp *http.Response, err error)
Post(url, contentType string, body io.Reader) (resp *http.Response, err error)

GetBandwidthTracker() BandwidthTracker
}

// Interface guards are a cheap way to make sure all methods are implemented, this is a static check and does not affect runtime performance.
Expand All @@ -45,6 +47,8 @@ type httpClient struct {
headerLck sync.Mutex
logger Logger
config *httpClientConfig

bandwidthTracker *bandwidthTracker
}

var DefaultTimeoutSeconds = 30
Expand Down Expand Up @@ -81,7 +85,7 @@ func NewHttpClient(logger Logger, options ...HttpClientOption) (HttpClient, erro
return nil, err
}

client, clientProfile, err := buildFromConfig(logger, config)
client, bandwidthTracker, clientProfile, err := buildFromConfig(logger, config)
if err != nil {
return nil, err
}
Expand All @@ -101,25 +105,26 @@ func NewHttpClient(logger Logger, options ...HttpClientOption) (HttpClient, erro
}

return &httpClient{
Client: *client,
logger: logger,
config: config,
headerLck: sync.Mutex{},
Client: *client,
logger: logger,
config: config,
headerLck: sync.Mutex{},
bandwidthTracker: bandwidthTracker,
}, nil
}

func validateConfig(_ *httpClientConfig) error {
return nil
}

func buildFromConfig(logger Logger, config *httpClientConfig) (*http.Client, profiles.ClientProfile, error) {
func buildFromConfig(logger Logger, config *httpClientConfig) (*http.Client, *bandwidthTracker, profiles.ClientProfile, error) {
var dialer proxy.ContextDialer
dialer = newDirectDialer(config.timeout, config.localAddr, config.dialer)

if config.proxyUrl != "" {
proxyDialer, err := newConnectDialer(config.proxyUrl, config.timeout, config.localAddr, config.dialer, logger)
if err != nil {
return nil, profiles.ClientProfile{}, err
return nil, nil, profiles.ClientProfile{}, err
}

dialer = proxyDialer
Expand All @@ -136,11 +141,13 @@ func buildFromConfig(logger Logger, config *httpClientConfig) (*http.Client, pro
}
}

bandwidthTracker := newBandwidthTracker()

clientProfile := config.clientProfile

transport, err := newRoundTripper(clientProfile, config.transportOptions, config.serverNameOverwrite, config.insecureSkipVerify, config.withRandomTlsExtensionOrder, config.forceHttp1, config.certificatePins, config.badPinHandler, config.disableIPV6, dialer)
transport, err := newRoundTripper(clientProfile, config.transportOptions, config.serverNameOverwrite, config.insecureSkipVerify, config.withRandomTlsExtensionOrder, config.forceHttp1, config.certificatePins, config.badPinHandler, config.disableIPV6, bandwidthTracker, dialer)
if err != nil {
return nil, clientProfile, err
return nil, nil, clientProfile, err
}

client := &http.Client{
Expand All @@ -153,7 +160,7 @@ func buildFromConfig(logger Logger, config *httpClientConfig) (*http.Client, pro
client.Jar = config.cookieJar
}

return client, clientProfile, nil
return client, bandwidthTracker, clientProfile, nil
}

// CloseIdleConnections closes all idle connections of the underlying http client.
Expand Down Expand Up @@ -230,7 +237,7 @@ func (c *httpClient) applyProxy() error {
dialer = proxyDialer
}

transport, err := newRoundTripper(c.config.clientProfile, c.config.transportOptions, c.config.serverNameOverwrite, c.config.insecureSkipVerify, c.config.withRandomTlsExtensionOrder, c.config.forceHttp1, c.config.certificatePins, c.config.badPinHandler, c.config.disableIPV6, dialer)
transport, err := newRoundTripper(c.config.clientProfile, c.config.transportOptions, c.config.serverNameOverwrite, c.config.insecureSkipVerify, c.config.withRandomTlsExtensionOrder, c.config.forceHttp1, c.config.certificatePins, c.config.badPinHandler, c.config.disableIPV6, c.bandwidthTracker, dialer)
if err != nil {
return err
}
Expand Down Expand Up @@ -273,6 +280,11 @@ func (c *httpClient) GetCookieJar() http.CookieJar {
return c.Jar
}

// GetBandwidthTracker returns the bandwidth tracker
func (c *httpClient) GetBandwidthTracker() BandwidthTracker {
return c.bandwidthTracker
}

// Do issues a given HTTP request and returns the corresponding response.
//
// If the returned error is nil, the response contains a non-nil body, which the user is expected to close.
Expand Down
9 changes: 7 additions & 2 deletions roundtripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ type roundTripper struct {

forceHttp1 bool

bandwidthTracker *bandwidthTracker

headerPriority *http2.PriorityParam
clientSessionCache tls.ClientSessionCache

Expand Down Expand Up @@ -147,7 +149,9 @@ func (rt *roundTripper) dialTLS(ctx context.Context, network, addr string) (net.
tlsConfig.KeyLogWriter = rt.transportOptions.KeyLogWriter
}

conn := tls.UClient(rawConn, tlsConfig, rt.clientHelloId, rt.withRandomTlsExtensionOrder, rt.forceHttp1)
trackedConn := newBandwidthTrackedConn(rawConn, rt.bandwidthTracker)

conn := tls.UClient(trackedConn, tlsConfig, rt.clientHelloId, rt.withRandomTlsExtensionOrder, rt.forceHttp1)
if err = conn.HandshakeContext(ctx); err != nil {
_ = conn.Close()

Expand Down Expand Up @@ -307,7 +311,7 @@ func (rt *roundTripper) getDialTLSAddr(req *http.Request) string {
return net.JoinHostPort(req.URL.Host, "443")
}

func newRoundTripper(clientProfile profiles.ClientProfile, transportOptions *TransportOptions, serverNameOverwrite string, insecureSkipVerify bool, withRandomTlsExtensionOrder bool, forceHttp1 bool, certificatePins map[string][]string, badPinHandlerFunc BadPinHandlerFunc, disableIPV6 bool, dialer ...proxy.ContextDialer) (http.RoundTripper, error) {
func newRoundTripper(clientProfile profiles.ClientProfile, transportOptions *TransportOptions, serverNameOverwrite string, insecureSkipVerify bool, withRandomTlsExtensionOrder bool, forceHttp1 bool, certificatePins map[string][]string, badPinHandlerFunc BadPinHandlerFunc, disableIPV6 bool, bandwidthTracker *bandwidthTracker, dialer ...proxy.ContextDialer) (http.RoundTripper, error) {
pinner, err := NewCertificatePinner(certificatePins)
if err != nil {
return nil, fmt.Errorf("can not instantiate certificate pinner: %w", err)
Expand Down Expand Up @@ -341,6 +345,7 @@ func newRoundTripper(clientProfile profiles.ClientProfile, transportOptions *Tra
cachedTransports: make(map[string]http.RoundTripper),
cachedConnections: make(map[string]net.Conn),
disableIPV6: disableIPV6,
bandwidthTracker: bandwidthTracker,
}

if len(dialer) > 0 {
Expand Down