diff --git a/auther.go b/auther.go index 7ad7325..be517af 100644 --- a/auther.go +++ b/auther.go @@ -36,27 +36,28 @@ type clock interface { Now() time.Time } -// auther adds an "OAuth" Authorization header field to requests. -type auther struct { +// DefaultAuther adds an "OAuth" Authorization header field to requests. +type DefaultAuther struct { config *Config clock clock } -func newAuther(config *Config) *auther { +// NewDefaultAuther returns a new DefaultAuther +func NewDefaultAuther(config *Config) *DefaultAuther { if config == nil { config = &Config{} } if config.Noncer == nil { config.Noncer = Base64Noncer{} } - return &auther{ + return &DefaultAuther{ config: config, } } // setRequestTokenAuthHeader adds the OAuth1 header for the request token // request (temporary credential) according to RFC 5849 2.1. -func (a *auther) setRequestTokenAuthHeader(req *http.Request) error { +func (a *DefaultAuther) setRequestTokenAuthHeader(req *http.Request) error { oauthParams := a.commonOAuthParams() oauthParams[oauthCallbackParam] = a.config.CallbackURL params, err := collectParameters(req, oauthParams) @@ -78,7 +79,7 @@ func (a *auther) setRequestTokenAuthHeader(req *http.Request) error { // setAccessTokenAuthHeader sets the OAuth1 header for the access token request // (token credential) according to RFC 5849 2.3. -func (a *auther) setAccessTokenAuthHeader(req *http.Request, requestToken, requestSecret, verifier string) error { +func (a *DefaultAuther) setAccessTokenAuthHeader(req *http.Request, requestToken, requestSecret, verifier string) error { oauthParams := a.commonOAuthParams() oauthParams[oauthTokenParam] = requestToken oauthParams[oauthVerifierParam] = verifier @@ -96,9 +97,9 @@ func (a *auther) setAccessTokenAuthHeader(req *http.Request, requestToken, reque return nil } -// setRequestAuthHeader sets the OAuth1 header for making authenticated +// SetRequestAuthHeader sets the OAuth1 header for making authenticated // requests with an AccessToken (token credential) according to RFC 5849 3.1. -func (a *auther) setRequestAuthHeader(req *http.Request, accessToken *Token) error { +func (a *DefaultAuther) SetRequestAuthHeader(req *http.Request, accessToken *Token) error { oauthParams := a.commonOAuthParams() oauthParams[oauthTokenParam] = accessToken.Token params, err := collectParameters(req, oauthParams) @@ -119,7 +120,7 @@ func (a *auther) setRequestAuthHeader(req *http.Request, accessToken *Token) err // excluding the oauth_signature parameter. This includes the realm parameter // if it was set in the config. The realm parameter will not be included in // the signature base string as specified in RFC 5849 3.4.1.3.1. -func (a *auther) commonOAuthParams() map[string]string { +func (a *DefaultAuther) commonOAuthParams() map[string]string { params := map[string]string{ oauthConsumerKeyParam: a.config.ConsumerKey, oauthSignatureMethodParam: a.signer().Name(), @@ -134,12 +135,12 @@ func (a *auther) commonOAuthParams() map[string]string { } // Returns a nonce using the configured Noncer. -func (a *auther) nonce() string { +func (a *DefaultAuther) nonce() string { return a.config.Noncer.Nonce() } // Returns the Unix epoch seconds. -func (a *auther) epoch() int64 { +func (a *DefaultAuther) epoch() int64 { if a.clock != nil { return a.clock.Now().Unix() } @@ -147,7 +148,7 @@ func (a *auther) epoch() int64 { } // Returns the Config's Signer or the default Signer. -func (a *auther) signer() Signer { +func (a *DefaultAuther) signer() Signer { if a.config.Signer != nil { return a.config.Signer } diff --git a/auther_test.go b/auther_test.go index 9088e8e..236fd46 100644 --- a/auther_test.go +++ b/auther_test.go @@ -12,11 +12,11 @@ import ( func TestCommonOAuthParams(t *testing.T) { cases := []struct { - auther *auther + DefaultAuther *DefaultAuther expectedParams map[string]string }{ { - &auther{ + &DefaultAuther{ &Config{ ConsumerKey: "some_consumer_key", Noncer: &fixedNoncer{"some_nonce"}, @@ -32,7 +32,7 @@ func TestCommonOAuthParams(t *testing.T) { }, }, { - &auther{ + &DefaultAuther{ &Config{ ConsumerKey: "some_consumer_key", Realm: "photos", @@ -52,13 +52,13 @@ func TestCommonOAuthParams(t *testing.T) { } for _, c := range cases { - assert.Equal(t, c.expectedParams, c.auther.commonOAuthParams()) + assert.Equal(t, c.expectedParams, c.DefaultAuther.commonOAuthParams()) } } func TestNonce(t *testing.T) { - auther := newAuther(nil) - nonce := auther.nonce() + DefaultAuther := NewDefaultAuther(nil) + nonce := DefaultAuther.nonce() // assert that 32 bytes (256 bites) become 44 bytes since a base64 byte // zeros the 2 high bits. 3 bytes convert to 4 base64 bytes, 40 base64 bytes // represent the first 30 of 32 bytes, = padding adds another 4 byte group. @@ -67,17 +67,17 @@ func TestNonce(t *testing.T) { } func TestEpoch(t *testing.T) { - a := newAuther(nil) + a := NewDefaultAuther(nil) // assert that a real time is used by default assert.InEpsilon(t, time.Now().Unix(), a.epoch(), 1) // assert that the fixed clock can be used for testing - a = &auther{clock: &fixedClock{time.Unix(50037133, 0)}} + a = &DefaultAuther{clock: &fixedClock{time.Unix(50037133, 0)}} assert.Equal(t, int64(50037133), a.epoch()) } func TestSigner_Default(t *testing.T) { config := &Config{ConsumerSecret: "consumer_secret"} - a := newAuther(config) + a := NewDefaultAuther(config) // echo -n "hello world" | openssl dgst -sha1 -hmac "consumer_secret&token_secret" -binary | base64 expectedSignature := "BE0uILOruKfSXd4UzYlLJDfOq08=" // assert that the default signer produces the expected HMAC-SHA1 digest @@ -92,7 +92,7 @@ func TestSigner_SHA256(t *testing.T) { config := &Config{ Signer: &HMAC256Signer{ConsumerSecret: "consumer_secret"}, } - a := newAuther(config) + a := NewDefaultAuther(config) // echo -n "hello world" | openssl dgst -sha256 -hmac "consumer_secret&token_secret" -binary | base64 expectedSignature := "pW9drXUyErU8DASWbsP2I3XZbju37AW+VzcGdYSeMo8=" // assert that the signer produces the expected HMAC-SHA256 digest @@ -118,7 +118,7 @@ func TestSigner_Custom(t *testing.T) { ConsumerSecret: "consumer_secret", Signer: &identitySigner{}, } - a := newAuther(config) + a := NewDefaultAuther(config) // assert that the custom signer is used method := a.signer().Name() digest, err := a.signer().Sign("secret", "hello world") diff --git a/config.go b/config.go index c539a00..3162e0c 100644 --- a/config.go +++ b/config.go @@ -50,11 +50,7 @@ func (c *Config) Client(ctx context.Context, t *Token) *http.Client { // NewClient returns a new http Client which signs requests via OAuth1. func NewClient(ctx context.Context, config *Config, token *Token) *http.Client { - transport := &Transport{ - Base: contextTransport(ctx), - source: StaticTokenSource(token), - auther: newAuther(config), - } + transport := newTransport(contextTransport(ctx), StaticTokenSource(token), NewDefaultAuther(config)) return &http.Client{Transport: transport} } @@ -69,7 +65,7 @@ func (c *Config) RequestToken() (requestToken, requestSecret string, err error) if err != nil { return "", "", err } - err = newAuther(c).setRequestTokenAuthHeader(req) + err = NewDefaultAuther(c).setRequestTokenAuthHeader(req) if err != nil { return "", "", err } @@ -148,7 +144,7 @@ func (c *Config) AccessToken(requestToken, requestSecret, verifier string) (acce if err != nil { return "", "", err } - err = newAuther(c).setAccessTokenAuthHeader(req, requestToken, requestSecret, verifier) + err = NewDefaultAuther(c).setAccessTokenAuthHeader(req, requestToken, requestSecret, verifier) if err != nil { return "", "", err } diff --git a/reference_test.go b/reference_test.go index 95d28cc..881623d 100644 --- a/reference_test.go +++ b/reference_test.go @@ -36,7 +36,7 @@ func TestTwitterRequestTokenAuthHeader(t *testing.T) { Noncer: &fixedNoncer{expectedNonce}, } - auther := &auther{config, &fixedClock{time.Unix(unixTimestamp, 0)}} + auther := &DefaultAuther{config, &fixedClock{time.Unix(unixTimestamp, 0)}} req, err := http.NewRequest("POST", config.Endpoint.RequestTokenURL, nil) assert.Nil(t, err) err = auther.setRequestTokenAuthHeader(req) @@ -74,7 +74,7 @@ func TestTwitterAccessTokenAuthHeader(t *testing.T) { Noncer: &fixedNoncer{expectedNonce}, } - auther := &auther{config, &fixedClock{time.Unix(unixTimestamp, 0)}} + auther := &DefaultAuther{config, &fixedClock{time.Unix(unixTimestamp, 0)}} req, err := http.NewRequest("POST", config.Endpoint.AccessTokenURL, nil) assert.Nil(t, err) err = auther.setAccessTokenAuthHeader(req, expectedRequestToken, requestTokenSecret, expectedVerifier) @@ -111,7 +111,7 @@ var twitterConfig = &Config{ } func TestTwitterParameterString(t *testing.T) { - auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}} + auther := &DefaultAuther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}} values := url.Values{} values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!") // note: the reference example is old and uses api v1 in the URL @@ -128,7 +128,7 @@ func TestTwitterParameterString(t *testing.T) { } func TestTwitterSignatureBase(t *testing.T) { - auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}} + auther := &DefaultAuther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}} values := url.Values{} values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!") // note: the reference example is old and uses api v1 in the URL @@ -151,7 +151,7 @@ func TestTwitterRequestAuthHeader(t *testing.T) { expectedSignature := PercentEncode("tnnArxj06cWHq44gCs1OSKk/jLY=") expectedTimestamp := "1318622958" - auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}} + auther := &DefaultAuther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}} values := url.Values{} values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!") @@ -159,7 +159,7 @@ func TestTwitterRequestAuthHeader(t *testing.T) { req, err := http.NewRequest("POST", "https://api.twitter.com/1/statuses/update.json?include_entities=true", strings.NewReader(values.Encode())) assert.Nil(t, err) req.Header.Set(contentType, formContentType) - err = auther.setRequestAuthHeader(req, accessToken) + err = auther.SetRequestAuthHeader(req, accessToken) // assert that request is signed and has an access token token assert.Nil(t, err) params := parseOAuthParamsOrFail(t, req.Header.Get(authorizationHeaderParam)) diff --git a/transport.go b/transport.go index c1af993..118f1f5 100644 --- a/transport.go +++ b/transport.go @@ -5,6 +5,10 @@ import ( "net/http" ) +type Auther interface { + SetRequestAuthHeader(req *http.Request, accessToken *Token) error +} + // Transport is an http.RoundTripper which makes OAuth1 HTTP requests. It // wraps a base RoundTripper and adds an Authorization header using the // token from a TokenSource. @@ -18,25 +22,43 @@ type Transport struct { // source supplies the token to use when signing a request source TokenSource // auther adds OAuth1 Authorization headers to requests - auther *auther + auther Auther +} + +// NewTransport returns a new Transport and an error +func NewTransport(baseRoundTripper http.RoundTripper, source TokenSource, auther Auther) (*Transport, error) { + t := newTransport(baseRoundTripper, source, auther) + err := t.checkValid() + if err != nil { + return nil, err + } + return t, nil +} + +// newTransport returns a new Transport without checking whether there is an error +// newTransport is only for NewTransport & NewClient +func newTransport(baseRoundTripper http.RoundTripper, source TokenSource, auther Auther) *Transport { + return &Transport{ + Base: baseRoundTripper, + source: source, + auther: auther, + } } // RoundTrip authorizes the request with a signed OAuth1 Authorization header // using the auther and TokenSource. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { - if t.source == nil { - return nil, fmt.Errorf("oauth1: Transport's source is nil") + err := t.checkValid() + if err != nil { + return nil, err } accessToken, err := t.source.Token() if err != nil { return nil, err } - if t.auther == nil { - return nil, fmt.Errorf("oauth1: Transport's auther is nil") - } // RoundTripper should not modify the given request, clone it req2 := cloneRequest(req) - err = t.auther.setRequestAuthHeader(req2, accessToken) + err = t.auther.SetRequestAuthHeader(req2, accessToken) if err != nil { return nil, err } @@ -63,3 +85,17 @@ func cloneRequest(req *http.Request) *http.Request { } return r2 } + +func (t *Transport) checkValid() error { + if t.source == nil { + return fmt.Errorf("oauth1: Transport's source is nil") + } + if t.auther == nil { + return fmt.Errorf("oauth1: Transport's auther is nil") + } + _, err := t.source.Token() + if err != nil { + return err + } + return nil +} diff --git a/transport_test.go b/transport_test.go index 116c11b..a3c8889 100644 --- a/transport_test.go +++ b/transport_test.go @@ -9,6 +9,8 @@ import ( "github.com/stretchr/testify/assert" ) +var _ Auther = &DefaultAuther{} + func TestTransport(t *testing.T) { const ( expectedToken = "access_token" @@ -34,14 +36,11 @@ func TestTransport(t *testing.T) { ConsumerSecret: "consumer_secret", Noncer: &fixedNoncer{expectedNonce}, } - auther := &auther{ + auther := &DefaultAuther{ config: config, clock: &fixedClock{time.Unix(123456789, 0)}, } - tr := &Transport{ - source: StaticTokenSource(NewToken(expectedToken, "some_secret")), - auther: auther, - } + tr := newTransport(nil, StaticTokenSource(NewToken(expectedToken, "some_secret")), auther) client := &http.Client{Transport: tr} req, err := http.NewRequest("GET", server.URL, nil) @@ -50,13 +49,83 @@ func TestTransport(t *testing.T) { assert.Nil(t, err) } +func TestNewTransport(t *testing.T) { + config := &Config{ + ConsumerKey: "consumer_key", + ConsumerSecret: "consumer_secret", + Noncer: &fixedNoncer{"some_nonce"}, + } + auther := &DefaultAuther{ + config: config, + clock: &fixedClock{time.Unix(123456789, 0)}, + } + source := StaticTokenSource(NewToken("access_token", "some_secret")) + + type testCase struct { + description string + base http.RoundTripper + source TokenSource + auther Auther + errMsg string + } + for _, tc := range []testCase{ + { + description: "default base transport", + base: nil, + source: source, + auther: auther, + errMsg: "", + }, + { + description: "custom base transport", + base: &http.Transport{}, + source: source, + auther: auther, + errMsg: "", + }, + { + description: "nil source", + base: nil, + source: nil, + auther: auther, + errMsg: "oauth1: Transport's source is nil", + }, + { + description: "empty source", + base: nil, + source: StaticTokenSource(nil), + auther: auther, + errMsg: "oauth1: Token is nil", + }, + { + description: "nil auther", + base: nil, + source: source, + auther: nil, + errMsg: "oauth1: Transport's auther is nil", + }, + } { + tr, err := NewTransport(tc.base, tc.source, tc.auther) + if tc.errMsg == "" { + assert.Nil(t, err) + if tc.base == nil { + assert.Equal(t, http.DefaultTransport, tr.base()) + } else { + assert.Equal(t, tc.base, tr.base()) + } + } else { + assert.Contains(t, err.Error(), tc.errMsg) + assert.Nil(t, tr) + } + } +} + func TestTransport_defaultBaseTransport(t *testing.T) { tr := &Transport{ Base: nil, } assert.Equal(t, http.DefaultTransport, tr.base()) } - func TestTransport_customBaseTransport(t *testing.T) { expected := &http.Transport{} tr := &Transport{ @@ -66,13 +135,10 @@ func TestTransport_customBaseTransport(t *testing.T) { } func TestTransport_nilSource(t *testing.T) { - tr := &Transport{ - source: nil, - auther: &auther{ - config: &Config{Noncer: &fixedNoncer{"any_nonce"}}, - clock: &fixedClock{time.Unix(123456789, 0)}, - }, - } + tr := newTransport(nil, nil, &DefaultAuther{ + config: &Config{Noncer: &fixedNoncer{"any_nonce"}}, + clock: &fixedClock{time.Unix(123456789, 0)}, + }) client := &http.Client{Transport: tr} resp, err := client.Get("http://example.com") assert.Nil(t, resp) @@ -82,13 +148,10 @@ func TestTransport_nilSource(t *testing.T) { } func TestTransport_emptySource(t *testing.T) { - tr := &Transport{ - source: StaticTokenSource(nil), - auther: &auther{ - config: &Config{Noncer: &fixedNoncer{"any_nonce"}}, - clock: &fixedClock{time.Unix(123456789, 0)}, - }, - } + tr := newTransport(nil, StaticTokenSource(nil), &DefaultAuther{ + config: &Config{Noncer: &fixedNoncer{"any_nonce"}}, + clock: &fixedClock{time.Unix(123456789, 0)}, + }) client := &http.Client{Transport: tr} resp, err := client.Get("http://example.com") assert.Nil(t, resp) @@ -98,10 +161,7 @@ func TestTransport_emptySource(t *testing.T) { } func TestTransport_nilAuther(t *testing.T) { - tr := &Transport{ - source: StaticTokenSource(&Token{}), - auther: nil, - } + tr := newTransport(nil, StaticTokenSource(&Token{}), nil) client := &http.Client{Transport: tr} resp, err := client.Get("http://example.com") assert.Nil(t, resp) diff --git a/verifier.go b/verifier.go index 136a127..d540a02 100644 --- a/verifier.go +++ b/verifier.go @@ -156,7 +156,7 @@ type HMACVerifier struct { // Default signer is HMAC-SHA1. func NewHMACVerifier(c *Config, tokenSecret string) *HMACVerifier { return &HMACVerifier{ - newAuther(c).signer(), + NewDefaultAuther(c).signer(), tokenSecret, } } @@ -180,9 +180,10 @@ func (v *HMACVerifier) Verify(baseString, actualSignature string) error { } // collectRequestParameters collects request parameters from -// 1. the request query, -// 2. the request body provided the body is single part & form encoded & form content type header is set, -// 3. and authorization header. +// 1. the request query, +// 2. the request body provided the body is single part & form encoded & form content type header is set, +// 3. and authorization header. +// // The returned map of collected parameter keys and values follow RFC 5849 3.4.1.3, // except duplicate parameters are not supported. func collectRequestParameters(req *http.Request) (map[string]string, string, error) {