diff --git a/cmd/root.go b/cmd/root.go index 2ee7996c..1a9cce0e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -48,13 +48,13 @@ var ( //go:embed version.txt versionString string // metadataString indiciates additional build or distribution metadata. - metadataString string - userAgent string + metadataString string + defaultUserAgent string ) func init() { versionString = semanticVersion() - userAgent = "alloy-db-auth-proxy/" + versionString + defaultUserAgent = "alloy-db-auth-proxy/" + versionString } // semanticVersion returns the version of the proxy including an compile-time @@ -98,6 +98,7 @@ type Command struct { httpAddress string httpPort string quiet bool + otherUserAgents string // impersonationChain is a comma separated list of one or more service // accounts. The first entry in the chain is the impersonation target. Any @@ -320,7 +321,7 @@ func NewCommand(opts ...Option) *Command { logger: logger, cleanup: func() error { return nil }, conf: &proxy.Config{ - UserAgent: userAgent, + UserAgent: defaultUserAgent, }, } for _, o := range opts { @@ -355,6 +356,8 @@ func NewCommand(opts ...Option) *Command { pflags := cmd.PersistentFlags() // Global-only flags + pflags.StringVar(&c.otherUserAgents, "user-agent", "", + "Space separated list of additional user agents, e.g. cloud-sql-proxy-operator/0.0.1") pflags.StringVarP(&c.conf.Token, "token", "t", "", "Bearer token used for authorization.") pflags.StringVarP(&c.conf.CredentialsFile, "credentials-file", "c", "", @@ -517,6 +520,11 @@ func parseConfig(cmd *Command, conf *proxy.Config, args []string) error { cmd.logger.Infof("Ignoring --disable-traces as --telemetry-project was not set") } + if userHasSet("user-agent") { + defaultUserAgent += " " + cmd.otherUserAgents + conf.UserAgent = defaultUserAgent + } + if cmd.impersonationChain != "" { accts := strings.Split(cmd.impersonationChain, ",") conf.ImpersonateTarget = accts[0] diff --git a/cmd/root_test.go b/cmd/root_test.go index 0bd3d40c..e90ad692 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -22,6 +22,7 @@ import ( "net/http" "os" "path/filepath" + "strings" "sync" "testing" "time" @@ -52,7 +53,7 @@ func invokeProxyCommand(args []string) (*Command, error) { func withDefaults(c *proxy.Config) *proxy.Config { if c.UserAgent == "" { - c.UserAgent = userAgent + c.UserAgent = defaultUserAgent } if c.Addr == "" { c.Addr = "127.0.0.1" @@ -77,6 +78,43 @@ func withDefaults(c *proxy.Config) *proxy.Config { return c } +func TestUserAgentWithVersionEnvVar(t *testing.T) { + os.Setenv("ALLOYDB_PROXY_USER_AGENT", "some-runtime/0.0.1") + defer os.Unsetenv("ALLOYDB_PROXY_USER_AGENT") + + cmd, err := invokeProxyCommand([]string{ + "projects/proj/locations/region/clusters/clust/instances/inst", + }) + if err != nil { + t.Fatalf("want error = nil, got = %v", err) + } + + want := "some-runtime/0.0.1" + got := cmd.conf.UserAgent + if !strings.Contains(got, want) { + t.Errorf("expected user agent to contain: %v; got: %v", want, got) + } +} + +func TestUserAgent(t *testing.T) { + cmd, err := invokeProxyCommand( + []string{ + "--user-agent", + "some-runtime/0.0.1", + "projects/proj/locations/region/clusters/clust/instances/inst", + }, + ) + if err != nil { + t.Fatalf("want error = nil, got = %v", err) + } + + want := "some-runtime/0.0.1" + got := cmd.conf.UserAgent + if !strings.Contains(got, want) { + t.Errorf("expected userAgent to contain: %v; got: %v", want, got) + } +} + func TestNewCommandArguments(t *testing.T) { tcs := []struct { desc string