Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/export auther #3

Merged
merged 10 commits into from
Aug 30, 2023
27 changes: 14 additions & 13 deletions auther.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,28 @@ type clock interface {
Now() time.Time
}

// auther adds an "OAuth" Authorization header field to requests.
type auther struct {
// DefaultAuther adds an "OAuth" Authorization header field to requests.
type DefaultAuther struct {
config *Config
clock clock
}

func newAuther(config *Config) *auther {
// NewDefaultAuther returns a new DefaultAuther
func NewDefaultAuther(config *Config) *DefaultAuther {
if config == nil {
config = &Config{}
}
if config.Noncer == nil {
config.Noncer = Base64Noncer{}
}
return &auther{
return &DefaultAuther{
config: config,
}
}

// setRequestTokenAuthHeader adds the OAuth1 header for the request token
// request (temporary credential) according to RFC 5849 2.1.
func (a *auther) setRequestTokenAuthHeader(req *http.Request) error {
func (a *DefaultAuther) setRequestTokenAuthHeader(req *http.Request) error {
oauthParams := a.commonOAuthParams()
oauthParams[oauthCallbackParam] = a.config.CallbackURL
params, err := collectParameters(req, oauthParams)
Expand All @@ -78,7 +79,7 @@ func (a *auther) setRequestTokenAuthHeader(req *http.Request) error {

// setAccessTokenAuthHeader sets the OAuth1 header for the access token request
// (token credential) according to RFC 5849 2.3.
func (a *auther) setAccessTokenAuthHeader(req *http.Request, requestToken, requestSecret, verifier string) error {
func (a *DefaultAuther) setAccessTokenAuthHeader(req *http.Request, requestToken, requestSecret, verifier string) error {
oauthParams := a.commonOAuthParams()
oauthParams[oauthTokenParam] = requestToken
oauthParams[oauthVerifierParam] = verifier
Expand All @@ -96,9 +97,9 @@ func (a *auther) setAccessTokenAuthHeader(req *http.Request, requestToken, reque
return nil
}

// setRequestAuthHeader sets the OAuth1 header for making authenticated
// SetRequestAuthHeader sets the OAuth1 header for making authenticated
// requests with an AccessToken (token credential) according to RFC 5849 3.1.
func (a *auther) setRequestAuthHeader(req *http.Request, accessToken *Token) error {
func (a *DefaultAuther) SetRequestAuthHeader(req *http.Request, accessToken *Token) error {
oauthParams := a.commonOAuthParams()
oauthParams[oauthTokenParam] = accessToken.Token
params, err := collectParameters(req, oauthParams)
Expand All @@ -119,7 +120,7 @@ func (a *auther) setRequestAuthHeader(req *http.Request, accessToken *Token) err
// excluding the oauth_signature parameter. This includes the realm parameter
// if it was set in the config. The realm parameter will not be included in
// the signature base string as specified in RFC 5849 3.4.1.3.1.
func (a *auther) commonOAuthParams() map[string]string {
func (a *DefaultAuther) commonOAuthParams() map[string]string {
params := map[string]string{
oauthConsumerKeyParam: a.config.ConsumerKey,
oauthSignatureMethodParam: a.signer().Name(),
Expand All @@ -134,20 +135,20 @@ func (a *auther) commonOAuthParams() map[string]string {
}

// Returns a nonce using the configured Noncer.
func (a *auther) nonce() string {
func (a *DefaultAuther) nonce() string {
return a.config.Noncer.Nonce()
}

// Returns the Unix epoch seconds.
func (a *auther) epoch() int64 {
func (a *DefaultAuther) epoch() int64 {
if a.clock != nil {
return a.clock.Now().Unix()
}
return time.Now().Unix()
}

// Returns the Config's Signer or the default Signer.
func (a *auther) signer() Signer {
func (a *DefaultAuther) signer() Signer {
if a.config.Signer != nil {
return a.config.Signer
}
Expand Down Expand Up @@ -193,7 +194,7 @@ func sortParameters(params map[string]string, format string) []string {
return pairs
}

// collectParameters collects request parameters from the request query, OAuth
// CollectParameters collects request parameters from the request query, OAuth
// parameters (which should exclude oauth_signature), and the request body
// provided the body is single part, form encoded, and the form content type
// header is set. The returned map of collected parameter keys and values
Expand Down
22 changes: 11 additions & 11 deletions auther_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import (

func TestCommonOAuthParams(t *testing.T) {
cases := []struct {
auther *auther
DefaultAuther *DefaultAuther
expectedParams map[string]string
}{
{
&auther{
&DefaultAuther{
&Config{
ConsumerKey: "some_consumer_key",
Noncer: &fixedNoncer{"some_nonce"},
Expand All @@ -32,7 +32,7 @@ func TestCommonOAuthParams(t *testing.T) {
},
},
{
&auther{
&DefaultAuther{
&Config{
ConsumerKey: "some_consumer_key",
Realm: "photos",
Expand All @@ -52,13 +52,13 @@ func TestCommonOAuthParams(t *testing.T) {
}

for _, c := range cases {
assert.Equal(t, c.expectedParams, c.auther.commonOAuthParams())
assert.Equal(t, c.expectedParams, c.DefaultAuther.commonOAuthParams())
}
}

func TestNonce(t *testing.T) {
auther := newAuther(nil)
nonce := auther.nonce()
DefaultAuther := NewDefaultAuther(nil)
nonce := DefaultAuther.nonce()
// assert that 32 bytes (256 bites) become 44 bytes since a base64 byte
// zeros the 2 high bits. 3 bytes convert to 4 base64 bytes, 40 base64 bytes
// represent the first 30 of 32 bytes, = padding adds another 4 byte group.
Expand All @@ -67,17 +67,17 @@ func TestNonce(t *testing.T) {
}

func TestEpoch(t *testing.T) {
a := newAuther(nil)
a := NewDefaultAuther(nil)
// assert that a real time is used by default
assert.InEpsilon(t, time.Now().Unix(), a.epoch(), 1)
// assert that the fixed clock can be used for testing
a = &auther{clock: &fixedClock{time.Unix(50037133, 0)}}
a = &DefaultAuther{clock: &fixedClock{time.Unix(50037133, 0)}}
assert.Equal(t, int64(50037133), a.epoch())
}

func TestSigner_Default(t *testing.T) {
config := &Config{ConsumerSecret: "consumer_secret"}
a := newAuther(config)
a := NewDefaultAuther(config)
// echo -n "hello world" | openssl dgst -sha1 -hmac "consumer_secret&token_secret" -binary | base64
expectedSignature := "BE0uILOruKfSXd4UzYlLJDfOq08="
// assert that the default signer produces the expected HMAC-SHA1 digest
Expand All @@ -92,7 +92,7 @@ func TestSigner_SHA256(t *testing.T) {
config := &Config{
Signer: &HMAC256Signer{ConsumerSecret: "consumer_secret"},
}
a := newAuther(config)
a := NewDefaultAuther(config)
// echo -n "hello world" | openssl dgst -sha256 -hmac "consumer_secret&token_secret" -binary | base64
expectedSignature := "pW9drXUyErU8DASWbsP2I3XZbju37AW+VzcGdYSeMo8="
// assert that the signer produces the expected HMAC-SHA256 digest
Expand All @@ -118,7 +118,7 @@ func TestSigner_Custom(t *testing.T) {
ConsumerSecret: "consumer_secret",
Signer: &identitySigner{},
}
a := newAuther(config)
a := NewDefaultAuther(config)
// assert that the custom signer is used
method := a.signer().Name()
digest, err := a.signer().Sign("secret", "hello world")
Expand Down
8 changes: 4 additions & 4 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ func (c *Config) Client(ctx context.Context, t *Token) *http.Client {
func NewClient(ctx context.Context, config *Config, token *Token) *http.Client {
transport := &Transport{
Base: contextTransport(ctx),
source: StaticTokenSource(token),
auther: newAuther(config),
Source: StaticTokenSource(token),
Auther: NewDefaultAuther(config),
}
return &http.Client{Transport: transport}
}
Expand All @@ -69,7 +69,7 @@ func (c *Config) RequestToken() (requestToken, requestSecret string, err error)
if err != nil {
return "", "", err
}
err = newAuther(c).setRequestTokenAuthHeader(req)
err = NewDefaultAuther(c).setRequestTokenAuthHeader(req)
if err != nil {
return "", "", err
}
Expand Down Expand Up @@ -148,7 +148,7 @@ func (c *Config) AccessToken(requestToken, requestSecret, verifier string) (acce
if err != nil {
return "", "", err
}
err = newAuther(c).setAccessTokenAuthHeader(req, requestToken, requestSecret, verifier)
err = NewDefaultAuther(c).setAccessTokenAuthHeader(req, requestToken, requestSecret, verifier)
if err != nil {
return "", "", err
}
Expand Down
12 changes: 6 additions & 6 deletions reference_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestTwitterRequestTokenAuthHeader(t *testing.T) {
Noncer: &fixedNoncer{expectedNonce},
}

auther := &auther{config, &fixedClock{time.Unix(unixTimestamp, 0)}}
auther := &DefaultAuther{config, &fixedClock{time.Unix(unixTimestamp, 0)}}
req, err := http.NewRequest("POST", config.Endpoint.RequestTokenURL, nil)
assert.Nil(t, err)
err = auther.setRequestTokenAuthHeader(req)
Expand Down Expand Up @@ -74,7 +74,7 @@ func TestTwitterAccessTokenAuthHeader(t *testing.T) {
Noncer: &fixedNoncer{expectedNonce},
}

auther := &auther{config, &fixedClock{time.Unix(unixTimestamp, 0)}}
auther := &DefaultAuther{config, &fixedClock{time.Unix(unixTimestamp, 0)}}
req, err := http.NewRequest("POST", config.Endpoint.AccessTokenURL, nil)
assert.Nil(t, err)
err = auther.setAccessTokenAuthHeader(req, expectedRequestToken, requestTokenSecret, expectedVerifier)
Expand Down Expand Up @@ -111,7 +111,7 @@ var twitterConfig = &Config{
}

func TestTwitterParameterString(t *testing.T) {
auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
auther := &DefaultAuther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
values := url.Values{}
values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!")
// note: the reference example is old and uses api v1 in the URL
Expand All @@ -128,7 +128,7 @@ func TestTwitterParameterString(t *testing.T) {
}

func TestTwitterSignatureBase(t *testing.T) {
auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
auther := &DefaultAuther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
values := url.Values{}
values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!")
// note: the reference example is old and uses api v1 in the URL
Expand All @@ -151,15 +151,15 @@ func TestTwitterRequestAuthHeader(t *testing.T) {
expectedSignature := PercentEncode("tnnArxj06cWHq44gCs1OSKk/jLY=")
expectedTimestamp := "1318622958"

auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
auther := &DefaultAuther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
values := url.Values{}
values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!")

accessToken := &Token{expectedTwitterOAuthToken, oauthTokenSecret}
req, err := http.NewRequest("POST", "https://api.twitter.com/1/statuses/update.json?include_entities=true", strings.NewReader(values.Encode()))
assert.Nil(t, err)
req.Header.Set(contentType, formContentType)
err = auther.setRequestAuthHeader(req, accessToken)
err = auther.SetRequestAuthHeader(req, accessToken)
// assert that request is signed and has an access token token
assert.Nil(t, err)
params := parseOAuthParamsOrFail(t, req.Header.Get(authorizationHeaderParam))
Expand Down
20 changes: 12 additions & 8 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import (
"net/http"
)

type Auther interface {
SetRequestAuthHeader(req *http.Request, accessToken *Token) error
}

// Transport is an http.RoundTripper which makes OAuth1 HTTP requests. It
// wraps a base RoundTripper and adds an Authorization header using the
// token from a TokenSource.
Expand All @@ -15,28 +19,28 @@ type Transport struct {
// Base is the base RoundTripper used to make HTTP requests. If nil, then
// http.DefaultTransport is used
Base http.RoundTripper
// source supplies the token to use when signing a request
source TokenSource
// auther adds OAuth1 Authorization headers to requests
auther *auther
// Source supplies the token to use when signing a request
Source TokenSource
// Auther adds OAuth1 Authorization headers to requests
hhuang-rayark marked this conversation as resolved.
Show resolved Hide resolved
Auther Auther
raymond-chia marked this conversation as resolved.
Show resolved Hide resolved
}

// RoundTrip authorizes the request with a signed OAuth1 Authorization header
// using the auther and TokenSource.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
if t.source == nil {
if t.Source == nil {
return nil, fmt.Errorf("oauth1: Transport's source is nil")
}
accessToken, err := t.source.Token()
accessToken, err := t.Source.Token()
if err != nil {
return nil, err
}
if t.auther == nil {
if t.Auther == nil {
return nil, fmt.Errorf("oauth1: Transport's auther is nil")
}
// RoundTripper should not modify the given request, clone it
req2 := cloneRequest(req)
err = t.auther.setRequestAuthHeader(req2, accessToken)
err = t.Auther.SetRequestAuthHeader(req2, accessToken)
if err != nil {
return nil, err
}
Expand Down
20 changes: 11 additions & 9 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"github.com/stretchr/testify/assert"
)

var _ Auther = &DefaultAuther{}

func TestTransport(t *testing.T) {
const (
expectedToken = "access_token"
Expand All @@ -34,13 +36,13 @@ func TestTransport(t *testing.T) {
ConsumerSecret: "consumer_secret",
Noncer: &fixedNoncer{expectedNonce},
}
auther := &auther{
auther := &DefaultAuther{
config: config,
clock: &fixedClock{time.Unix(123456789, 0)},
}
tr := &Transport{
source: StaticTokenSource(NewToken(expectedToken, "some_secret")),
auther: auther,
Source: StaticTokenSource(NewToken(expectedToken, "some_secret")),
Auther: auther,
}
client := &http.Client{Transport: tr}

Expand All @@ -67,8 +69,8 @@ func TestTransport_customBaseTransport(t *testing.T) {

func TestTransport_nilSource(t *testing.T) {
tr := &Transport{
source: nil,
auther: &auther{
Source: nil,
Auther: &DefaultAuther{
config: &Config{Noncer: &fixedNoncer{"any_nonce"}},
clock: &fixedClock{time.Unix(123456789, 0)},
},
Expand All @@ -83,8 +85,8 @@ func TestTransport_nilSource(t *testing.T) {

func TestTransport_emptySource(t *testing.T) {
tr := &Transport{
source: StaticTokenSource(nil),
auther: &auther{
Source: StaticTokenSource(nil),
Auther: &DefaultAuther{
config: &Config{Noncer: &fixedNoncer{"any_nonce"}},
clock: &fixedClock{time.Unix(123456789, 0)},
},
Expand All @@ -99,8 +101,8 @@ func TestTransport_emptySource(t *testing.T) {

func TestTransport_nilAuther(t *testing.T) {
tr := &Transport{
source: StaticTokenSource(&Token{}),
auther: nil,
Source: StaticTokenSource(&Token{}),
Auther: nil,
}
client := &http.Client{Transport: tr}
resp, err := client.Get("http://example.com")
Expand Down
9 changes: 5 additions & 4 deletions verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ type HMACVerifier struct {
// Default signer is HMAC-SHA1.
func NewHMACVerifier(c *Config, tokenSecret string) *HMACVerifier {
return &HMACVerifier{
newAuther(c).signer(),
NewDefaultAuther(c).signer(),
tokenSecret,
}
}
Expand All @@ -180,9 +180,10 @@ func (v *HMACVerifier) Verify(baseString, actualSignature string) error {
}

// collectRequestParameters collects request parameters from
// 1. the request query,
// 2. the request body provided the body is single part & form encoded & form content type header is set,
// 3. and authorization header.
// 1. the request query,
// 2. the request body provided the body is single part & form encoded & form content type header is set,
// 3. and authorization header.
//
// The returned map of collected parameter keys and values follow RFC 5849 3.4.1.3,
// except duplicate parameters are not supported.
func collectRequestParameters(req *http.Request) (map[string]string, string, error) {
Expand Down