Skip to content

Commit

Permalink
Add user agent suffix feature flag (#3297)
Browse files Browse the repository at this point in the history
* Add user agent suffix feature flag

* unecessary concat
  • Loading branch information
dustin-decker authored Sep 13, 2024
1 parent 213bf7e commit 7e78ca3
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 15 deletions.
10 changes: 9 additions & 1 deletion pkg/common/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/hashicorp/go-retryablehttp"
"github.com/trufflesecurity/trufflehog/v3/pkg/feature"
)

var caCerts = []string{
Expand Down Expand Up @@ -88,8 +89,15 @@ type CustomTransport struct {
T http.RoundTripper
}

func userAgent() string {
if len(feature.UserAgentSuffix.Load()) > 0 {
return "TruffleHog " + feature.UserAgentSuffix.Load()
}
return "TruffleHog"
}

func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Add("User-Agent", "TruffleHog")
req.Header.Add("User-Agent", userAgent())
return t.T.RoundTrip(req)
}

Expand Down
25 changes: 16 additions & 9 deletions pkg/detectors/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,30 @@ import (
"net"
"net/http"
"time"

"github.com/trufflesecurity/trufflehog/v3/pkg/feature"
)

var DetectorHttpClientWithNoLocalAddresses *http.Client
var DetectorHttpClientWithLocalAddresses *http.Client

const DefaultResponseTimeout = 5 * time.Second
const DefaultUserAgent = "TruffleHog"

func userAgent() string {
if len(feature.UserAgentSuffix.Load()) > 0 {
return "TruffleHog " + feature.UserAgentSuffix.Load()
}
return "TruffleHog"
}

func init() {
DetectorHttpClientWithLocalAddresses = NewDetectorHttpClient(
WithTransport(NewDetectorTransport(DefaultUserAgent, nil)),
WithTransport(NewDetectorTransport(nil)),
WithTimeout(DefaultResponseTimeout),
WithNoFollowRedirects(),
)
DetectorHttpClientWithNoLocalAddresses = NewDetectorHttpClient(
WithTransport(NewDetectorTransport(DefaultUserAgent, nil)),
WithTransport(NewDetectorTransport(nil)),
WithTimeout(DefaultResponseTimeout),
WithNoFollowRedirects(),
WithNoLocalIP(),
Expand All @@ -41,12 +49,11 @@ func WithNoFollowRedirects() ClientOption {
}

type detectorTransport struct {
T http.RoundTripper
userAgent string
T http.RoundTripper
}

func (t *detectorTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Add("User-Agent", t.userAgent)
req.Header.Add("User-Agent", userAgent())
return t.T.RoundTrip(req)
}

Expand All @@ -55,7 +62,7 @@ var defaultDialer = &net.Dialer{
KeepAlive: 5 * time.Second,
}

func NewDetectorTransport(userAgent string, T http.RoundTripper) http.RoundTripper {
func NewDetectorTransport(T http.RoundTripper) http.RoundTripper {
if T == nil {
T = &http.Transport{
Proxy: http.ProxyFromEnvironment,
Expand All @@ -67,7 +74,7 @@ func NewDetectorTransport(userAgent string, T http.RoundTripper) http.RoundTripp
ExpectContinueTimeout: 1 * time.Second,
}
}
return &detectorTransport{T: T, userAgent: userAgent}
return &detectorTransport{T: T}
}

func isLocalIP(ip net.IP) bool {
Expand Down Expand Up @@ -143,7 +150,7 @@ func WithTimeout(timeout time.Duration) ClientOption {

func NewDetectorHttpClient(opts ...ClientOption) *http.Client {
httpClient := &http.Client{
Transport: NewDetectorTransport(DefaultUserAgent, nil),
Transport: NewDetectorTransport(nil),
Timeout: DefaultResponseTimeout,
}

Expand Down
31 changes: 28 additions & 3 deletions pkg/feature/feature.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,32 @@ package feature
import "sync/atomic"

var (
ForceSkipBinaries = atomic.Bool{}
ForceSkipArchives = atomic.Bool{}
SkipAdditionalRefs = atomic.Bool{}
ForceSkipBinaries atomic.Bool
ForceSkipArchives atomic.Bool
SkipAdditionalRefs atomic.Bool
UserAgentSuffix AtomicString
)

type AtomicString struct {
value atomic.Value
}

// Load returns the current value of the atomic string
func (as *AtomicString) Load() string {
if v := as.value.Load(); v != nil {
return v.(string)
}
return ""
}

// Store sets the value of the atomic string
func (as *AtomicString) Store(newValue string) {
as.value.Store(newValue)
}

// Swap atomically swaps the current string with a new one and returns the old value
func (as *AtomicString) Swap(newValue string) string {
oldValue := as.Load()
as.Store(newValue)
return oldValue
}
4 changes: 2 additions & 2 deletions pkg/sources/github/github_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func TestSource_Token(t *testing.T) {
s.Init(ctx, "github integration test source", 0, 0, false, conn, 1)
s.filteredRepoCache = s.newFilteredRepoCache(ctx, memory.New[string](), nil, nil)

err = s.enumerateWithApp(ctx, s.connector.(*appConnector).InstallationClient())
err = s.enumerateWithApp(ctx, s.connector.(*appConnector).InstallationClient(), noopReporter())
assert.NoError(t, err)

_, _, err = s.cloneRepo(ctx, "https://github.com/truffle-test-integration-org/another-test-repo.git")
Expand Down Expand Up @@ -631,7 +631,7 @@ func TestSource_paginateGists(t *testing.T) {
}
chunksCh := make(chan *sources.Chunk, 5)
go func() {
assert.NoError(t, s.addUserGistsToCache(ctx, tt.user))
assert.NoError(t, s.addUserGistsToCache(ctx, tt.user, noopReporter()))
chunksCh <- &sources.Chunk{}
}()
var wantedRepo string
Expand Down

0 comments on commit 7e78ca3

Please sign in to comment.