Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(transport): add support for setting quota project with envvar #1892

Merged
merged 10 commits into from
Mar 10, 2023
21 changes: 16 additions & 5 deletions internal/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"io/ioutil"
"net"
"net/http"
"os"
"time"

"golang.org/x/oauth2"
Expand All @@ -21,6 +22,8 @@ import (
"golang.org/x/oauth2/google"
)

const quotaProjectEnvVar = "GOOGLE_CLOUD_QUOTA_PROJECT"

// Creds returns credential information obtained from DialSettings, or if none, then
// it returns default credential information.
func Creds(ctx context.Context, ds *DialSettings) (*google.Credentials, error) {
Expand Down Expand Up @@ -152,14 +155,22 @@ func selfSignedJWTTokenSource(data []byte, ds *DialSettings) (oauth2.TokenSource
}
}

// QuotaProjectFromCreds returns the quota project from the JSON blob in the provided credentials.
//
// NOTE(cbro): consider promoting this to a field on google.Credentials.
func QuotaProjectFromCreds(cred *google.Credentials) string {
// GetQuotaProject retrieves quota project with precedence being: client option,
// environment variable, creds file.
func GetQuotaProject(creds *google.Credentials, clientOpt string) string {
if clientOpt != "" {
return clientOpt
}
if env := os.Getenv(quotaProjectEnvVar); env != "" {
return env
}
if creds == nil {
return ""
}
codyoss marked this conversation as resolved.
Show resolved Hide resolved
var v struct {
QuotaProject string `json:"quota_project_id"`
}
if err := json.Unmarshal(cred.JSON, &v); err != nil {
if err := json.Unmarshal(creds.JSON, &v); err != nil {
return ""
}
return v.QuotaProject
Expand Down
61 changes: 51 additions & 10 deletions internal/creds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package internal

import (
"context"
"os"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -199,10 +200,9 @@ const validServiceAccountJSON = `{
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/dumba-504%40appspot.gserviceaccount.com"
}`

func TestQuotaProjectFromCreds(t *testing.T) {
func TestGetQuotaProject(t *testing.T) {
ctx := context.Background()

cred, err := credentialsFromJSON(
emptyCred, err := credentialsFromJSON(
ctx,
[]byte(validServiceAccountJSON),
&DialSettings{
Expand All @@ -212,17 +212,13 @@ func TestQuotaProjectFromCreds(t *testing.T) {
if err != nil {
t.Fatalf("got %v, wanted no error", err)
}
if want, got := "", QuotaProjectFromCreds(cred); want != got {
t.Errorf("QuotaProjectFromCreds(validServiceAccountJSON): want %q, got %q", want, got)
}

quotaProjectJSON := []byte(`
{
"type": "authorized_user",
"quota_project_id": "foobar"
}`)

cred, err = credentialsFromJSON(
quotaCred, err := credentialsFromJSON(
ctx,
[]byte(quotaProjectJSON),
&DialSettings{
Expand All @@ -232,8 +228,53 @@ func TestQuotaProjectFromCreds(t *testing.T) {
if err != nil {
t.Fatalf("got %v, wanted no error", err)
}
if want, got := "foobar", QuotaProjectFromCreds(cred); want != got {
t.Errorf("QuotaProjectFromCreds(quotaProjectJSON): want %q, got %q", want, got)

tests := []struct {
name string
cred *google.Credentials
clientOpt string
env string
want string
}{
{
name: "empty all",
cred: nil,
want: "",
},
{
name: "empty cred",
cred: emptyCred,
want: "",
},
{
name: "from cred",
cred: quotaCred,
want: "foobar",
},
{
name: "from opt",
cred: quotaCred,
clientOpt: "clientopt",
want: "clientopt",
},
{
name: "from env",
cred: quotaCred,
env: "envProject",
want: "envProject",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
oldEnv := os.Getenv(quotaProjectEnvVar)
if tt.env != "" {
os.Setenv(quotaProjectEnvVar, tt.env)
}
if want, got := tt.want, GetQuotaProject(tt.cred, tt.clientOpt); want != got {
t.Errorf("GetQuotaProject(%v, %q): want %q, got %q", tt.cred, tt.clientOpt, want, got)
}
os.Setenv(quotaProjectEnvVar, oldEnv)
})
}
}

Expand Down
6 changes: 1 addition & 5 deletions transport/grpc/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,10 @@ func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.C
return nil, err
}

if o.QuotaProject == "" {
o.QuotaProject = internal.QuotaProjectFromCreds(creds)
}

grpcOpts = append(grpcOpts,
grpc.WithPerRPCCredentials(grpcTokenSource{
TokenSource: oauth.TokenSource{creds.TokenSource},
quotaProject: o.QuotaProject,
quotaProject: internal.GetQuotaProject(creds, o.QuotaProject),
requestReason: o.RequestReason,
}),
)
Expand Down
7 changes: 2 additions & 5 deletions transport/http/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ func newTransport(ctx context.Context, base http.RoundTripper, settings *interna
paramTransport := &parameterTransport{
base: base,
userAgent: settings.UserAgent,
quotaProject: settings.QuotaProject,
requestReason: settings.RequestReason,
}
var trans http.RoundTripper = paramTransport
Expand All @@ -74,6 +73,7 @@ func newTransport(ctx context.Context, base http.RoundTripper, settings *interna
case settings.NoAuth:
// Do nothing.
case settings.APIKey != "":
paramTransport.quotaProject = internal.GetQuotaProject(nil, settings.QuotaProject)
trans = &transport.APIKey{
Transport: trans,
Key: settings.APIKey,
Expand All @@ -83,10 +83,7 @@ func newTransport(ctx context.Context, base http.RoundTripper, settings *interna
if err != nil {
return nil, err
}
if paramTransport.quotaProject == "" {
paramTransport.quotaProject = internal.QuotaProjectFromCreds(creds)
}

paramTransport.quotaProject = internal.GetQuotaProject(creds, settings.QuotaProject)
ts := creds.TokenSource
if settings.ImpersonationConfig == nil && settings.TokenSource != nil {
ts = settings.TokenSource
Expand Down