Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/export auther #3

Merged
merged 10 commits into from
Aug 30, 2023
25 changes: 13 additions & 12 deletions auther.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,28 @@ type clock interface {
Now() time.Time
}

// auther adds an "OAuth" Authorization header field to requests.
type auther struct {
// DefaultAuther adds an "OAuth" Authorization header field to requests.
type DefaultAuther struct {
config *Config
clock clock
}

func newAuther(config *Config) *auther {
// NewDefaultAuther returns a new DefaultAuther
func NewDefaultAuther(config *Config) *DefaultAuther {
if config == nil {
config = &Config{}
}
if config.Noncer == nil {
config.Noncer = Base64Noncer{}
}
return &auther{
return &DefaultAuther{
config: config,
}
}

// setRequestTokenAuthHeader adds the OAuth1 header for the request token
// request (temporary credential) according to RFC 5849 2.1.
func (a *auther) setRequestTokenAuthHeader(req *http.Request) error {
func (a *DefaultAuther) setRequestTokenAuthHeader(req *http.Request) error {
oauthParams := a.commonOAuthParams()
oauthParams[oauthCallbackParam] = a.config.CallbackURL
params, err := collectParameters(req, oauthParams)
Expand All @@ -78,7 +79,7 @@ func (a *auther) setRequestTokenAuthHeader(req *http.Request) error {

// setAccessTokenAuthHeader sets the OAuth1 header for the access token request
// (token credential) according to RFC 5849 2.3.
func (a *auther) setAccessTokenAuthHeader(req *http.Request, requestToken, requestSecret, verifier string) error {
func (a *DefaultAuther) setAccessTokenAuthHeader(req *http.Request, requestToken, requestSecret, verifier string) error {
oauthParams := a.commonOAuthParams()
oauthParams[oauthTokenParam] = requestToken
oauthParams[oauthVerifierParam] = verifier
Expand All @@ -96,9 +97,9 @@ func (a *auther) setAccessTokenAuthHeader(req *http.Request, requestToken, reque
return nil
}

// setRequestAuthHeader sets the OAuth1 header for making authenticated
// SetRequestAuthHeader sets the OAuth1 header for making authenticated
// requests with an AccessToken (token credential) according to RFC 5849 3.1.
func (a *auther) setRequestAuthHeader(req *http.Request, accessToken *Token) error {
func (a *DefaultAuther) SetRequestAuthHeader(req *http.Request, accessToken *Token) error {
oauthParams := a.commonOAuthParams()
oauthParams[oauthTokenParam] = accessToken.Token
params, err := collectParameters(req, oauthParams)
Expand All @@ -119,7 +120,7 @@ func (a *auther) setRequestAuthHeader(req *http.Request, accessToken *Token) err
// excluding the oauth_signature parameter. This includes the realm parameter
// if it was set in the config. The realm parameter will not be included in
// the signature base string as specified in RFC 5849 3.4.1.3.1.
func (a *auther) commonOAuthParams() map[string]string {
func (a *DefaultAuther) commonOAuthParams() map[string]string {
params := map[string]string{
oauthConsumerKeyParam: a.config.ConsumerKey,
oauthSignatureMethodParam: a.signer().Name(),
Expand All @@ -134,20 +135,20 @@ func (a *auther) commonOAuthParams() map[string]string {
}

// Returns a nonce using the configured Noncer.
func (a *auther) nonce() string {
func (a *DefaultAuther) nonce() string {
return a.config.Noncer.Nonce()
}

// Returns the Unix epoch seconds.
func (a *auther) epoch() int64 {
func (a *DefaultAuther) epoch() int64 {
if a.clock != nil {
return a.clock.Now().Unix()
}
return time.Now().Unix()
}

// Returns the Config's Signer or the default Signer.
func (a *auther) signer() Signer {
func (a *DefaultAuther) signer() Signer {
if a.config.Signer != nil {
return a.config.Signer
}
Expand Down
22 changes: 11 additions & 11 deletions auther_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import (

func TestCommonOAuthParams(t *testing.T) {
cases := []struct {
auther *auther
DefaultAuther *DefaultAuther
expectedParams map[string]string
}{
{
&auther{
&DefaultAuther{
&Config{
ConsumerKey: "some_consumer_key",
Noncer: &fixedNoncer{"some_nonce"},
Expand All @@ -32,7 +32,7 @@ func TestCommonOAuthParams(t *testing.T) {
},
},
{
&auther{
&DefaultAuther{
&Config{
ConsumerKey: "some_consumer_key",
Realm: "photos",
Expand All @@ -52,13 +52,13 @@ func TestCommonOAuthParams(t *testing.T) {
}

for _, c := range cases {
assert.Equal(t, c.expectedParams, c.auther.commonOAuthParams())
assert.Equal(t, c.expectedParams, c.DefaultAuther.commonOAuthParams())
}
}

func TestNonce(t *testing.T) {
auther := newAuther(nil)
nonce := auther.nonce()
DefaultAuther := NewDefaultAuther(nil)
nonce := DefaultAuther.nonce()
// assert that 32 bytes (256 bites) become 44 bytes since a base64 byte
// zeros the 2 high bits. 3 bytes convert to 4 base64 bytes, 40 base64 bytes
// represent the first 30 of 32 bytes, = padding adds another 4 byte group.
Expand All @@ -67,17 +67,17 @@ func TestNonce(t *testing.T) {
}

func TestEpoch(t *testing.T) {
a := newAuther(nil)
a := NewDefaultAuther(nil)
// assert that a real time is used by default
assert.InEpsilon(t, time.Now().Unix(), a.epoch(), 1)
// assert that the fixed clock can be used for testing
a = &auther{clock: &fixedClock{time.Unix(50037133, 0)}}
a = &DefaultAuther{clock: &fixedClock{time.Unix(50037133, 0)}}
assert.Equal(t, int64(50037133), a.epoch())
}

func TestSigner_Default(t *testing.T) {
config := &Config{ConsumerSecret: "consumer_secret"}
a := newAuther(config)
a := NewDefaultAuther(config)
// echo -n "hello world" | openssl dgst -sha1 -hmac "consumer_secret&token_secret" -binary | base64
expectedSignature := "BE0uILOruKfSXd4UzYlLJDfOq08="
// assert that the default signer produces the expected HMAC-SHA1 digest
Expand All @@ -92,7 +92,7 @@ func TestSigner_SHA256(t *testing.T) {
config := &Config{
Signer: &HMAC256Signer{ConsumerSecret: "consumer_secret"},
}
a := newAuther(config)
a := NewDefaultAuther(config)
// echo -n "hello world" | openssl dgst -sha256 -hmac "consumer_secret&token_secret" -binary | base64
expectedSignature := "pW9drXUyErU8DASWbsP2I3XZbju37AW+VzcGdYSeMo8="
// assert that the signer produces the expected HMAC-SHA256 digest
Expand All @@ -118,7 +118,7 @@ func TestSigner_Custom(t *testing.T) {
ConsumerSecret: "consumer_secret",
Signer: &identitySigner{},
}
a := newAuther(config)
a := NewDefaultAuther(config)
// assert that the custom signer is used
method := a.signer().Name()
digest, err := a.signer().Sign("secret", "hello world")
Expand Down
10 changes: 3 additions & 7 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,7 @@ func (c *Config) Client(ctx context.Context, t *Token) *http.Client {

// NewClient returns a new http Client which signs requests via OAuth1.
func NewClient(ctx context.Context, config *Config, token *Token) *http.Client {
transport := &Transport{
Base: contextTransport(ctx),
source: StaticTokenSource(token),
auther: newAuther(config),
}
transport := newTransport(contextTransport(ctx), StaticTokenSource(token), NewDefaultAuther(config))
return &http.Client{Transport: transport}
}

Expand All @@ -69,7 +65,7 @@ func (c *Config) RequestToken() (requestToken, requestSecret string, err error)
if err != nil {
return "", "", err
}
err = newAuther(c).setRequestTokenAuthHeader(req)
err = NewDefaultAuther(c).setRequestTokenAuthHeader(req)
if err != nil {
return "", "", err
}
Expand Down Expand Up @@ -148,7 +144,7 @@ func (c *Config) AccessToken(requestToken, requestSecret, verifier string) (acce
if err != nil {
return "", "", err
}
err = newAuther(c).setAccessTokenAuthHeader(req, requestToken, requestSecret, verifier)
err = NewDefaultAuther(c).setAccessTokenAuthHeader(req, requestToken, requestSecret, verifier)
if err != nil {
return "", "", err
}
Expand Down
12 changes: 6 additions & 6 deletions reference_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestTwitterRequestTokenAuthHeader(t *testing.T) {
Noncer: &fixedNoncer{expectedNonce},
}

auther := &auther{config, &fixedClock{time.Unix(unixTimestamp, 0)}}
auther := &DefaultAuther{config, &fixedClock{time.Unix(unixTimestamp, 0)}}
req, err := http.NewRequest("POST", config.Endpoint.RequestTokenURL, nil)
assert.Nil(t, err)
err = auther.setRequestTokenAuthHeader(req)
Expand Down Expand Up @@ -74,7 +74,7 @@ func TestTwitterAccessTokenAuthHeader(t *testing.T) {
Noncer: &fixedNoncer{expectedNonce},
}

auther := &auther{config, &fixedClock{time.Unix(unixTimestamp, 0)}}
auther := &DefaultAuther{config, &fixedClock{time.Unix(unixTimestamp, 0)}}
req, err := http.NewRequest("POST", config.Endpoint.AccessTokenURL, nil)
assert.Nil(t, err)
err = auther.setAccessTokenAuthHeader(req, expectedRequestToken, requestTokenSecret, expectedVerifier)
Expand Down Expand Up @@ -111,7 +111,7 @@ var twitterConfig = &Config{
}

func TestTwitterParameterString(t *testing.T) {
auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
auther := &DefaultAuther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
values := url.Values{}
values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!")
// note: the reference example is old and uses api v1 in the URL
Expand All @@ -128,7 +128,7 @@ func TestTwitterParameterString(t *testing.T) {
}

func TestTwitterSignatureBase(t *testing.T) {
auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
auther := &DefaultAuther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
values := url.Values{}
values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!")
// note: the reference example is old and uses api v1 in the URL
Expand All @@ -151,15 +151,15 @@ func TestTwitterRequestAuthHeader(t *testing.T) {
expectedSignature := PercentEncode("tnnArxj06cWHq44gCs1OSKk/jLY=")
expectedTimestamp := "1318622958"

auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
auther := &DefaultAuther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
values := url.Values{}
values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!")

accessToken := &Token{expectedTwitterOAuthToken, oauthTokenSecret}
req, err := http.NewRequest("POST", "https://api.twitter.com/1/statuses/update.json?include_entities=true", strings.NewReader(values.Encode()))
assert.Nil(t, err)
req.Header.Set(contentType, formContentType)
err = auther.setRequestAuthHeader(req, accessToken)
err = auther.SetRequestAuthHeader(req, accessToken)
// assert that request is signed and has an access token token
assert.Nil(t, err)
params := parseOAuthParamsOrFail(t, req.Header.Get(authorizationHeaderParam))
Expand Down
50 changes: 43 additions & 7 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import (
"net/http"
)

type Auther interface {
SetRequestAuthHeader(req *http.Request, accessToken *Token) error
}

// Transport is an http.RoundTripper which makes OAuth1 HTTP requests. It
// wraps a base RoundTripper and adds an Authorization header using the
// token from a TokenSource.
Expand All @@ -18,25 +22,43 @@ type Transport struct {
// source supplies the token to use when signing a request
source TokenSource
// auther adds OAuth1 Authorization headers to requests
auther *auther
auther Auther
}

// NewTransport returns a new Transport and an error
func NewTransport(baseRoundTripper http.RoundTripper, source TokenSource, auther Auther) (*Transport, error) {
t := newTransport(baseRoundTripper, source, auther)
err := t.checkValid()
if err != nil {
return nil, err
}
return t, nil
}

// newTransport returns a new Transport without checking whether there is an error
raymond-chia marked this conversation as resolved.
Show resolved Hide resolved
// newTransport is only for NewTransport & NewClient
func newTransport(baseRoundTripper http.RoundTripper, source TokenSource, auther Auther) *Transport {
return &Transport{
Base: baseRoundTripper,
source: source,
auther: auther,
}
}

// RoundTrip authorizes the request with a signed OAuth1 Authorization header
// using the auther and TokenSource.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
if t.source == nil {
return nil, fmt.Errorf("oauth1: Transport's source is nil")
err := t.checkValid()
if err != nil {
return nil, err
}
accessToken, err := t.source.Token()
if err != nil {
return nil, err
}
if t.auther == nil {
return nil, fmt.Errorf("oauth1: Transport's auther is nil")
}
// RoundTripper should not modify the given request, clone it
req2 := cloneRequest(req)
err = t.auther.setRequestAuthHeader(req2, accessToken)
err = t.auther.SetRequestAuthHeader(req2, accessToken)
if err != nil {
return nil, err
}
Expand All @@ -63,3 +85,17 @@ func cloneRequest(req *http.Request) *http.Request {
}
return r2
}

func (t *Transport) checkValid() error {
if t.source == nil {
return fmt.Errorf("oauth1: Transport's source is nil")
}
if t.auther == nil {
return fmt.Errorf("oauth1: Transport's auther is nil")
}
_, err := t.source.Token()
if err != nil {
return err
}
return nil
}
Loading