diff --git a/client/api.go b/client/api.go index 7e484d1d15..8fce1d7d06 100644 --- a/client/api.go +++ b/client/api.go @@ -37,6 +37,7 @@ import ( issuingdispute "github.com/stripe/stripe-go/issuing/dispute" "github.com/stripe/stripe-go/issuing/transaction" "github.com/stripe/stripe-go/loginlink" + "github.com/stripe/stripe-go/oauth" "github.com/stripe/stripe-go/order" "github.com/stripe/stripe-go/orderreturn" "github.com/stripe/stripe-go/paymentintent" @@ -148,6 +149,8 @@ type API struct { IssuingTransactions *transaction.Client // LoginLinks is the client used to invoke login link related APIs. LoginLinks *loginlink.Client + // OAuth is the client used to invoke /oauth APIs. + OAuth *oauth.Client // Orders is the client used to invoke /orders APIs. Orders *order.Client // OrderReturns is the client used to invoke /order_returns APIs. @@ -234,6 +237,7 @@ func (a *API) Init(key string, backends *stripe.Backends) { if backends == nil { backends = &stripe.Backends{ API: stripe.GetBackend(stripe.APIBackend), + Connect: stripe.GetBackend(stripe.ConnectBackend), Uploads: stripe.GetBackend(stripe.UploadsBackend), } } @@ -272,8 +276,9 @@ func (a *API) Init(key string, backends *stripe.Backends) { a.IssuingDisputes = &issuingdispute.Client{B: backends.API, Key: key} a.IssuingTransactions = &transaction.Client{B: backends.API, Key: key} a.LoginLinks = &loginlink.Client{B: backends.API, Key: key} - a.Orders = &order.Client{B: backends.API, Key: key} + a.OAuth = &oauth.Client{B: backends.Connect, Key: key} a.OrderReturns = &orderreturn.Client{B: backends.API, Key: key} + a.Orders = &order.Client{B: backends.API, Key: key} a.PaymentIntents = &paymentintent.Client{B: backends.API, Key: key} a.PaymentMethods = &paymentmethod.Client{B: backends.API, Key: key} a.PaymentSource = &paymentsource.Client{B: backends.API, Key: key} diff --git a/error.go b/error.go index 9934c3c33c..71e9fa9184 100644 --- a/error.go +++ b/error.go @@ -136,6 +136,10 @@ type Error struct { SetupIntent *SetupIntent `json:"setup_intent,omitempty"` Source *PaymentSource `json:"source,omitempty"` Type ErrorType `json:"type"` + + // OAuth specific Error properties. Named OAuthError because of name conflict. + OAuthError string `json:"error,omitempty"` + OAuthErrorDescription string `json:"error_description,omitempty"` } // Error serializes the error object to JSON and returns it as a string. diff --git a/form/form.go b/form/form.go index 8c3e52746d..8008ea875f 100644 --- a/form/form.go +++ b/form/form.go @@ -538,14 +538,18 @@ func (f *Values) Add(key, val string) { f.values = append(f.values, formValue{key, val}) } -// Encode encodes the values into “URL encoded” form ("bar=baz&foo=quux"). +// Encode encodes the keys and values into “URL encoded” form +// ("bar=baz&foo=quux"). func (f *Values) Encode() string { var buf bytes.Buffer for _, v := range f.values { if buf.Len() > 0 { buf.WriteByte('&') } - buf.WriteString(url.QueryEscape(v.Key)) + key := url.QueryEscape(v.Key) + key = strings.Replace(key, "%5B", "[", -1) + key = strings.Replace(key, "%5D", "]", -1) + buf.WriteString(key) buf.WriteString("=") buf.WriteString(url.QueryEscape(v.Value)) } diff --git a/issuing_cardholder.go b/issuing_cardholder.go index c1dec90ad3..0fd38d31f6 100644 --- a/issuing_cardholder.go +++ b/issuing_cardholder.go @@ -21,7 +21,7 @@ const ( IssuingCardholderTypeIndividual IssuingCardholderType = "individual" ) -// IssuingBillingParams isis the set of parameters that can be used for billing with the Issuing APIs. +// IssuingBillingParams is the set of parameters that can be used for billing with the Issuing APIs. type IssuingBillingParams struct { Address *AddressParams `form:"address"` Name *string `form:"name"` diff --git a/oauth.go b/oauth.go new file mode 100644 index 0000000000..a89e9c2246 --- /dev/null +++ b/oauth.go @@ -0,0 +1,129 @@ +package stripe + +// OAuthScopeType is the type of OAuth scope. +type OAuthScopeType string + +// List of possible values for OAuth scopes. +const ( + OAuthScopeTypeReadOnly OAuthScopeType = "read_only" + OAuthScopeTypeReadWrite OAuthScopeType = "read_write" +) + +// OAuthTokenType is the type of token. This will always be "bearer." +type OAuthTokenType string + +// List of possible OAuthTokenType values. +const ( + OAuthTokenTypeBearer OAuthTokenType = "bearer" +) + +// OAuthStripeUserBusinessType is the business type for the Stripe oauth user. +type OAuthStripeUserBusinessType string + +// List of supported values for business type. +const ( + OAuthStripeUserBusinessTypeCorporation OAuthStripeUserBusinessType = "corporation" + OAuthStripeUserBusinessTypeLLC OAuthStripeUserBusinessType = "llc" + OAuthStripeUserBusinessTypeNonProfit OAuthStripeUserBusinessType = "non_profit" + OAuthStripeUserBusinessTypePartnership OAuthStripeUserBusinessType = "partnership" + OAuthStripeUserBusinessTypeSoleProp OAuthStripeUserBusinessType = "sole_prop" +) + +// OAuthStripeUserGender of the person who will be filling out a Stripe +// application. (International regulations require either male or female.) +type OAuthStripeUserGender string + +// The gender of the person who will be filling out a Stripe application. +// (International regulations require either male or female.) +const ( + OAuthStripeUserGenderFemale OAuthStripeUserGender = "female" + OAuthStripeUserGenderMale OAuthStripeUserGender = "male" +) + +// OAuthStripeUserParams for the stripe_user OAuth Authorize params. +type OAuthStripeUserParams struct { + BlockKana *string `form:"block_kana"` + BlockKanji *string `form:"block_kanji"` + BuildingKana *string `form:"building_kana"` + BuildingKanji *string `form:"building_kanji"` + BusinessName *string `form:"business_name"` + BusinessType *string `form:"business_type"` + City *string `form:"city"` + Country *string `form:"country"` + Currency *string `form:"currency"` + DOBDay *int64 `form:"dob_day"` + DOBMonth *int64 `form:"dob_month"` + DOBYear *int64 `form:"dob_year"` + Email *string `form:"email"` + FirstName *string `form:"first_name"` + FirstNameKana *string `form:"first_name_kana"` + FirstNameKanji *string `form:"first_name_kanji"` + Gender *string `form:"gender"` + LastName *string `form:"last_name"` + LastNameKana *string `form:"last_name_kana"` + LastNameKanji *string `form:"last_name_kanji"` + PhoneNumber *string `form:"phone_number"` + PhysicalProduct *bool `form:"physical_product"` + ProductDescription *string `form:"product_description"` + State *string `form:"state"` + StreetAddress *string `form:"street_address"` + URL *string `form:"url"` + Zip *string `form:"zip"` +} + +// AuthorizeURLParams for creating OAuth AuthorizeURL's. +type AuthorizeURLParams struct { + Params `form:"*"` + AlwaysPrompt *bool `form:"always_prompt"` + ClientID *string `form:"client_id"` + RedirectURI *string `form:"redirect_uri"` + ResponseType *string `form:"response_type"` + Scope *string `form:"scope"` + State *string `form:"state"` + StripeLanding *string `form:"stripe_landing"` + StripeUser *OAuthStripeUserParams `form:"stripe_user"` + SuggestedCapabilities []*string `form:"suggested_capabilities"` + + // Express is not sent as a parameter, but is used to modify the authorize URL + // path to use the express OAuth path. + Express *bool `form:"-"` +} + +// DeauthorizeParams for deauthorizing an account. +type DeauthorizeParams struct { + Params `form:"*"` + ClientID *string `form:"client_id"` + StripeUserID *string `form:"stripe_user_id"` +} + +// OAuthTokenParams is the set of paramaters that can be used to request +// OAuthTokens. +type OAuthTokenParams struct { + Params `form:"*"` + AssertCapabilities []*string `form:"assert_capabilities"` + ClientSecret *string `form:"client_secret"` + Code *string `form:"code"` + GrantType *string `form:"grant_type"` + RefreshToken *string `form:"refresh_token"` + Scope *string `form:"scope"` +} + +// OAuthToken is the value of the OAuthToken from OAuth flow. +// https://stripe.com/docs/connect/oauth-reference#post-token +type OAuthToken struct { + Livemode bool `json:"livemode"` + Scope OAuthScopeType `json:"scope"` + StripeUserID string `json:"stripe_user_id"` + TokenType OAuthTokenType `json:"token_type"` + + // Deprecated, please use StripeUserID + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + StripePublishableKey string `json:"stripe_publishable_key"` +} + +// Deauthorize is the value of the return from deauthorizing. +// https://stripe.com/docs/connect/oauth-reference#post-deauthorize +type Deauthorize struct { + StripeUserID string `json:"stripe_user_id"` +} diff --git a/oauth/client.go b/oauth/client.go new file mode 100644 index 0000000000..23a905695a --- /dev/null +++ b/oauth/client.go @@ -0,0 +1,77 @@ +// Package oauth provides the OAuth APIs +package oauth + +import ( + "fmt" + "net/http" + + stripe "github.com/stripe/stripe-go" + "github.com/stripe/stripe-go/form" +) + +// Client is used to invoke /oauth and related APIs. +type Client struct { + B stripe.Backend + Key string +} + +// AuthorizeURL builds an OAuth authorize URL. +func AuthorizeURL(params *stripe.AuthorizeURLParams) string { + return getC().AuthorizeURL(params) +} + +// AuthorizeURL builds an OAuth authorize URL. +func (c Client) AuthorizeURL(params *stripe.AuthorizeURLParams) string { + express := "" + if stripe.BoolValue(params.Express) { + express = "/express" + } + qs := &form.Values{} + form.AppendTo(qs, params) + return fmt.Sprintf( + "%s%s/oauth/authorize?%s", + stripe.ConnectURL, + express, + qs.Encode(), + ) +} + +// New creates an OAuth token using a code after successful redirection back. +func New(params *stripe.OAuthTokenParams) (*stripe.OAuthToken, error) { + return getC().New(params) +} + +// New creates an OAuth token using a code after successful redirection back. +func (c Client) New(params *stripe.OAuthTokenParams) (*stripe.OAuthToken, error) { + // client_secret is sent in the post body for this endpoint. + if stripe.StringValue(params.ClientSecret) == "" { + params.ClientSecret = stripe.String(stripe.Key) + } + + oauthToken := &stripe.OAuthToken{} + err := c.B.Call(http.MethodPost, "/oauth/token", c.Key, params, oauthToken) + + return oauthToken, err +} + +// Del deauthorizes a connected account. +func Del(params *stripe.DeauthorizeParams) (*stripe.Deauthorize, error) { + return getC().Del(params) +} + +// Del deauthorizes a connected account. +func (c Client) Del(params *stripe.DeauthorizeParams) (*stripe.Deauthorize, error) { + deauthorization := &stripe.Deauthorize{} + err := c.B.Call( + http.MethodPost, + "/oauth/deauthorize", + c.Key, + params, + deauthorization, + ) + return deauthorization, err +} + +func getC() Client { + return Client{stripe.GetBackend(stripe.ConnectBackend), stripe.Key} +} diff --git a/oauth/client_test.go b/oauth/client_test.go new file mode 100644 index 0000000000..93927e5595 --- /dev/null +++ b/oauth/client_test.go @@ -0,0 +1,267 @@ +package oauth + +import ( + "bytes" + "io/ioutil" + "net/http" + "testing" + + assert "github.com/stretchr/testify/require" + stripe "github.com/stripe/stripe-go" + _ "github.com/stripe/stripe-go/testing" +) + +func TestAuthorizeURL(t *testing.T) { + url := AuthorizeURL(&stripe.AuthorizeURLParams{}) + + assert.Equal(t, url, "https://connect.stripe.com/oauth/authorize?") +} + +func TestAuthorizeURLWithOptionalArgs(t *testing.T) { + url := AuthorizeURL(&stripe.AuthorizeURLParams{ + ClientID: stripe.String("ca_123"), + State: stripe.String("test-state"), + Scope: stripe.String("read_only"), + RedirectURI: stripe.String("https://t.example.com"), + ResponseType: stripe.String("test-code"), + StripeLanding: stripe.String("register"), + AlwaysPrompt: stripe.Bool(true), + Express: stripe.Bool(true), + SuggestedCapabilities: stripe.StringSlice([]string{"card_payments"}), + }) + + assert.Contains(t, url, "https://connect.stripe.com/express/oauth/authorize?") + assert.Contains(t, url, "client_id=ca_123") + assert.Contains(t, url, "redirect_uri=https%3A%2F%2Ft.example.com") + assert.Contains(t, url, "response_type=test-code") + assert.Contains(t, url, "scope=read_only") + assert.Contains(t, url, "state=test-state") + assert.Contains(t, url, "stripe_landing=register") + assert.Contains(t, url, "always_prompt=true") + assert.Contains(t, url, "suggested_capabilities[0]=card_payments") +} + +func TestAuthorizeURLWithStripeUser(t *testing.T) { + url := AuthorizeURL(&stripe.AuthorizeURLParams{ + ResponseType: stripe.String("test-code"), + StripeUser: &stripe.OAuthStripeUserParams{ + BlockKana: stripe.String("block-kana"), + BlockKanji: stripe.String("block-kanji"), + BuildingKana: stripe.String("building-kana"), + BuildingKanji: stripe.String("building-kanji"), + BusinessName: stripe.String("b-name"), + BusinessType: stripe.String("llc"), + City: stripe.String("Elko"), + Country: stripe.String("US"), + Currency: stripe.String("USD"), + DOBDay: stripe.Int64(15), + DOBMonth: stripe.Int64(10), + DOBYear: stripe.Int64(2019), + Email: stripe.String("test@example.com"), + FirstName: stripe.String("first-name"), + FirstNameKana: stripe.String("first-name-kana"), + FirstNameKanji: stripe.String("first-name-kanji"), + Gender: stripe.String("female"), + LastName: stripe.String("last-name"), + LastNameKana: stripe.String("last-name-kana"), + LastNameKanji: stripe.String("last-name-kanji"), + PhoneNumber: stripe.String("999-999-9999"), + PhysicalProduct: stripe.Bool(false), + ProductDescription: stripe.String("product-description"), + State: stripe.String("NV"), + StreetAddress: stripe.String("123 main"), + URL: stripe.String("http://example.com"), + Zip: stripe.String("12345"), + }, + }) + + assert.Contains(t, url, "response_type=test-code") + assert.Contains(t, url, "stripe_user[block_kana]=block-kana") + assert.Contains(t, url, "stripe_user[block_kanji]=block-kanji") + assert.Contains(t, url, "stripe_user[building_kana]=building-kana") + assert.Contains(t, url, "stripe_user[building_kanji]=building-kanji") + assert.Contains(t, url, "stripe_user[business_name]=b-name") + assert.Contains(t, url, "stripe_user[business_type]=llc") + assert.Contains(t, url, "stripe_user[city]=Elko") + assert.Contains(t, url, "stripe_user[country]=US") + assert.Contains(t, url, "stripe_user[currency]=USD") + assert.Contains(t, url, "stripe_user[dob_day]=15") + assert.Contains(t, url, "stripe_user[dob_month]=10") + assert.Contains(t, url, "stripe_user[dob_year]=2019") + assert.Contains(t, url, "stripe_user[email]=test%40example.com") + assert.Contains(t, url, "stripe_user[first_name]=first-name") + assert.Contains(t, url, "stripe_user[first_name_kana]=first-name-kana") + assert.Contains(t, url, "stripe_user[first_name_kanji]=first-name-kanji") + assert.Contains(t, url, "stripe_user[gender]=female") + assert.Contains(t, url, "stripe_user[last_name]=last-name") + assert.Contains(t, url, "stripe_user[last_name_kana]=last-name-kana") + assert.Contains(t, url, "stripe_user[last_name_kanji]=last-name-kanji") + assert.Contains(t, url, "stripe_user[phone_number]=999-999-9999") + assert.Contains(t, url, "stripe_user[physical_product]=false") + assert.Contains(t, url, "stripe_user[product_description]=product-description") + assert.Contains(t, url, "stripe_user[state]=NV") + assert.Contains(t, url, "stripe_user[street_address]=123+main") + assert.Contains(t, url, "stripe_user[url]=http%3A%2F%2Fexample.com") + assert.Contains(t, url, "stripe_user[zip]=12345") +} + +// RoundTripFunc. +type RoundTripFunc func(req *http.Request) *http.Response + +// RoundTrip. +func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req), nil +} + +// NewTestClient returns *http.Client with Transport replaced to avoid making +// real calls +func NewTestClient(fn RoundTripFunc) *http.Client { + return &http.Client{ + Transport: RoundTripFunc(fn), + } +} + +func StubConnectBackend(httpClient *http.Client) { + mockBackend := stripe.GetBackendWithConfig( + stripe.ConnectBackend, + &stripe.BackendConfig{ + URL: "https://localhost:12113", + HTTPClient: httpClient, + }, + ) + stripe.SetBackend(stripe.ConnectBackend, mockBackend) +} + +func TestNewOAuthToken(t *testing.T) { + stripe.Key = "sk_123" + + // stripe-mock doesn't support connect URL's so this stubs out the server. + httpClient := NewTestClient(func(req *http.Request) *http.Response { + buf := new(bytes.Buffer) + buf.ReadFrom(req.Body) + reqBody := buf.String() + + assert.Contains(t, req.URL.String(), "https://localhost:12113/oauth/token") + assert.Contains(t, reqBody, "client_secret=sk_123") + assert.Contains(t, reqBody, "grant_type=authorization_code") + assert.Contains(t, reqBody, "code=code") + assert.Contains(t, reqBody, "assert_capabilities[0]=card_payments") + + responseBody := `{ + "access_token":"sk_123", + "livemode":false, + "refresh_token":"rt_123", + "token_type":"bearer", + "stripe_publishable_key":"pk_123", + "stripe_user_id":"acct_123", + "scope":"read_write" + }` + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(responseBody)), + Header: make(http.Header), + } + }) + StubConnectBackend(httpClient) + + token, err := New(&stripe.OAuthTokenParams{ + GrantType: stripe.String("authorization_code"), + Code: stripe.String("code"), + AssertCapabilities: stripe.StringSlice([]string{"card_payments"}), + }) + assert.Nil(t, err) + assert.NotNil(t, token) + assert.Equal(t, token.AccessToken, "sk_123") + assert.Equal(t, token.Livemode, false) + assert.Equal(t, token.RefreshToken, "rt_123") + assert.Equal(t, token.TokenType, stripe.OAuthTokenTypeBearer) + assert.Equal(t, token.StripePublishableKey, "pk_123") + assert.Equal(t, token.StripeUserID, "acct_123") + assert.Equal(t, token.Scope, stripe.OAuthScopeTypeReadWrite) +} + +func TestNewOAuthTokenWithCustomKey(t *testing.T) { + stripe.Key = "sk_123" + // stripe-mock doesn't support connect URL's so this stubs out the server. + httpClient := NewTestClient(func(req *http.Request) *http.Response { + buf := new(bytes.Buffer) + buf.ReadFrom(req.Body) + reqBody := buf.String() + assert.Contains(t, reqBody, "client_secret=sk_999") + + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(`{}`)), + Header: make(http.Header), + } + }) + StubConnectBackend(httpClient) + + token, err := New(&stripe.OAuthTokenParams{ + ClientSecret: stripe.String("sk_999"), + }) + assert.Nil(t, err) + assert.NotNil(t, token) +} + +func TestNewOAuthTokenWithError(t *testing.T) { + stripe.Key = "sk_123" + // stripe-mock doesn't support connect URL's so this stubs out the server. + + responseBody := `{"error":"invalid_grant","error_description": "Authorization code does not exist"}` + httpClient := NewTestClient(func(req *http.Request) *http.Response { + buf := new(bytes.Buffer) + buf.ReadFrom(req.Body) + reqBody := buf.String() + assert.Contains(t, reqBody, "client_secret=sk_999") + + return &http.Response{ + StatusCode: 400, + Body: ioutil.NopCloser(bytes.NewBufferString(responseBody)), + Header: make(http.Header), + } + }) + StubConnectBackend(httpClient) + + token, err := New(&stripe.OAuthTokenParams{ + ClientSecret: stripe.String("sk_999"), + }) + + assert.NotNil(t, token) + assert.NotNil(t, err) + + stripeErr := err.(*stripe.Error) + assert.Equal(t, 400, stripeErr.HTTPStatusCode) + assert.Equal(t, "Authorization code does not exist", stripeErr.OAuthErrorDescription) + assert.Equal(t, "invalid_grant", stripeErr.OAuthError) +} + +func TestDeauthorize(t *testing.T) { + stripe.Key = "sk_123" + + // stripe-mock doesn't support connect URL's so this stubs out the server. + httpClient := NewTestClient(func(req *http.Request) *http.Response { + buf := new(bytes.Buffer) + buf.ReadFrom(req.Body) + reqBody := buf.String() + assert.Contains(t, req.URL.String(), "https://localhost:12113/oauth/deauthorize") + assert.Contains(t, reqBody, "client_id=sk_999") + assert.Contains(t, reqBody, "stripe_user_id=acct_123") + + resBody := `{"stripe_user_id": "acct_123"}` + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(resBody)), + Header: make(http.Header), + } + }) + StubConnectBackend(httpClient) + + deauthorization, err := Del(&stripe.DeauthorizeParams{ + ClientID: stripe.String("sk_999"), + StripeUserID: stripe.String("acct_123"), + }) + assert.Nil(t, err) + assert.NotNil(t, deauthorization) + assert.Equal(t, deauthorization.StripeUserID, "acct_123") +} diff --git a/stripe.go b/stripe.go index 706f70fe34..5d9556d37a 100644 --- a/stripe.go +++ b/stripe.go @@ -36,6 +36,13 @@ const ( // APIURL is the URL of the API service backend. APIURL string = "https://api.stripe.com" + // ConnectURL is the URL for OAuth. + ConnectURL string = "https://connect.stripe.com" + + // ConnectBackend is a constant representing the connect service backend for + // OAuth. + ConnectBackend SupportedBackend = "connect" + // UnknownPlatform is the string returned as the system name if we couldn't get // one from `uname`. UnknownPlatform string = "unknown platform" @@ -451,8 +458,18 @@ func (s *BackendImplementation) Do(req *http.Request, body *bytes.Buffer, v inte // ResponseToError converts a stripe response to an Error. func (s *BackendImplementation) ResponseToError(res *http.Response, resBody []byte) error { var raw rawError - if err := s.UnmarshalJSONVerbose(res.StatusCode, resBody, &raw); err != nil { - return err + if s.Type == ConnectBackend { + // If this is an OAuth request, deserialize as Error because OAuth errors + // are a different shape from the standard API errors. + var topLevelError rawErrorInternal + if err := s.UnmarshalJSONVerbose(res.StatusCode, resBody, &topLevelError); err != nil { + return err + } + raw.E = &topLevelError + } else { + if err := s.UnmarshalJSONVerbose(res.StatusCode, resBody, &raw); err != nil { + return err + } } // no error in resBody @@ -599,8 +616,8 @@ func (s *BackendImplementation) sleepTime(numRetries int) time.Duration { // Backends are the currently supported endpoints. type Backends struct { - API, Uploads Backend - mu sync.RWMutex + API, Connect, Uploads Backend + mu sync.RWMutex } // SupportedBackend is an enumeration of supported Stripe endpoints. @@ -689,6 +706,8 @@ func GetBackend(backendType SupportedBackend) Backend { switch backendType { case APIBackend: backend = backends.API + case ConnectBackend: + backend = backends.Connect case UploadsBackend: backend = backends.Uploads } @@ -715,6 +734,8 @@ func GetBackend(backendType SupportedBackend) Backend { switch backendType { case APIBackend: backends.API = backend + case ConnectBackend: + backends.Connect = backend case UploadsBackend: backends.Uploads = backend } @@ -758,6 +779,15 @@ func GetBackendWithConfig(backendType SupportedBackend, config *BackendConfig) B config.URL = normalizeURL(config.URL) + return newBackendImplementation(backendType, config) + + case ConnectBackend: + if config.URL == "" { + config.URL = ConnectURL + } + + config.URL = normalizeURL(config.URL) + return newBackendImplementation(backendType, config) } @@ -791,9 +821,11 @@ func Int64Slice(v []int64) []*int64 { // should only need to use this for testing purposes or on App Engine. func NewBackends(httpClient *http.Client) *Backends { apiConfig := &BackendConfig{HTTPClient: httpClient} + connectConfig := &BackendConfig{HTTPClient: httpClient} uploadConfig := &BackendConfig{HTTPClient: httpClient} return &Backends{ API: GetBackendWithConfig(APIBackend, apiConfig), + Connect: GetBackendWithConfig(ConnectBackend, connectConfig), Uploads: GetBackendWithConfig(UploadsBackend, uploadConfig), } } @@ -839,6 +871,8 @@ func SetBackend(backend SupportedBackend, b Backend) { switch backend { case APIBackend: backends.API = b + case ConnectBackend: + backends.Connect = b case UploadsBackend: backends.Uploads = b }