diff --git a/src/auth0-session/handlers/callback.ts b/src/auth0-session/handlers/callback.ts index 8328afbd0..d5a7b6f94 100644 --- a/src/auth0-session/handlers/callback.ts +++ b/src/auth0-session/handlers/callback.ts @@ -1,7 +1,7 @@ import { IncomingMessage, ServerResponse } from 'http'; import urlJoin from 'url-join'; import { BadRequest } from 'http-errors'; -import { Config } from '../config'; +import { AuthorizationParameters, Config } from '../config'; import { ClientFactory } from '../client'; import TransientStore from '../transient-store'; import { decodeState } from '../hooks/get-login-state'; @@ -17,6 +17,8 @@ export type CallbackOptions = { afterCallback?: AfterCallback; redirectUri?: string; + + authorizationParams?: Partial; }; export type HandleCallback = (req: IncomingMessage, res: ServerResponse, options?: CallbackOptions) => Promise; @@ -40,12 +42,17 @@ export default function callbackHandlerFactory( const code_verifier = transientCookieHandler.read('code_verifier', req, res); const nonce = transientCookieHandler.read('nonce', req, res); - tokenSet = await client.callback(redirectUri, callbackParams, { - max_age: max_age !== undefined ? +max_age : undefined, - code_verifier, - nonce, - state: expectedState - }); + tokenSet = await client.callback( + redirectUri, + callbackParams, + { + max_age: max_age !== undefined ? +max_age : undefined, + code_verifier, + nonce, + state: expectedState + }, + { exchangeBody: options?.authorizationParams } + ); } catch (err) { throw new BadRequest(err.message); } diff --git a/src/handlers/callback.ts b/src/handlers/callback.ts index 95a7a0be3..38bb45b86 100644 --- a/src/handlers/callback.ts +++ b/src/handlers/callback.ts @@ -1,6 +1,6 @@ import { strict as assert } from 'assert'; import { NextApiResponse, NextApiRequest } from 'next'; -import { HandleCallback as BaseHandleCallback } from '../auth0-session'; +import { AuthorizationParameters, HandleCallback as BaseHandleCallback } from '../auth0-session'; import { Session } from '../session'; import { assertReqRes } from '../utils/assert'; import { NextConfig } from '../config'; @@ -85,6 +85,11 @@ export interface CallbackOptions { * organizations, it should match {@Link LoginOptions.authorizationParams}. */ organization?: string; + + /** + * This is useful for sending custom query parameters in the body of the code exchange request for use in rules. + */ + authorizationParams?: Partial; } /** diff --git a/src/session/get-access-token.ts b/src/session/get-access-token.ts index 09ba11011..05d2062c1 100644 --- a/src/session/get-access-token.ts +++ b/src/session/get-access-token.ts @@ -4,7 +4,7 @@ import { ClientFactory } from '../auth0-session'; import { AccessTokenError } from '../utils/errors'; import { intersect, match } from '../utils/array'; import { Session, SessionCache, fromTokenSet } from '../session'; -import { NextConfig } from '../config'; +import { AuthorizationParameters, NextConfig } from '../config'; export type AfterRefresh = (req: NextApiRequest, res: NextApiResponse, session: Session) => Promise | Session; @@ -52,6 +52,11 @@ export interface AccessTokenRequest { * ``` */ afterRefresh?: AfterRefresh; + + /** + * This is useful for sending custom query parameters in the body of the refresh grant request for use in rules. + */ + authorizationParams?: Partial; } /** @@ -140,7 +145,9 @@ export default function accessTokenFactory( (session.refreshToken && accessTokenRequest && accessTokenRequest.refresh) ) { const client = await getClient(); - const tokenSet = await client.refresh(session.refreshToken); + const tokenSet = await client.refresh(session.refreshToken, { + exchangeBody: accessTokenRequest?.authorizationParams + }); // Update the session. const newSession = fromTokenSet(tokenSet, config); diff --git a/tests/handlers/callback.test.ts b/tests/handlers/callback.test.ts index 235f66abe..5b2a5bba8 100644 --- a/tests/handlers/callback.test.ts +++ b/tests/handlers/callback.test.ts @@ -6,6 +6,7 @@ import { get, post, toSignedCookieJar } from '../auth0-session/fixtures/helpers' import { encodeState } from '../../src/auth0-session/hooks/get-login-state'; import { setup, teardown } from '../fixtures/setup'; import { Session, AfterCallback } from '../../src'; +import nock from 'nock'; const callback = (baseUrl: string, body: any, cookieJar?: CookieJar): Promise => post(baseUrl, `/api/auth/callback`, { @@ -361,4 +362,46 @@ describe('callback handler', () => { expect(session.user.org_id).toEqual('foo'); }); + + test('should pass custom params to the token exchange', async () => { + const baseUrl = await setup(withoutApi, { + callbackOptions: { + authorizationParams: { foo: 'bar' } + } + }); + const state = encodeState({ returnTo: baseUrl }); + const cookieJar = toSignedCookieJar( + { + state, + nonce: '__test_nonce__' + }, + baseUrl + ); + const spy = jest.fn(); + + nock(`${withoutApi.issuerBaseURL}`) + .post('/oauth/token', /grant_type=authorization_code/) + .reply(200, (_, body) => { + spy(body); + return { + access_token: 'eyJz93a...k4laUWw', + expires_in: 750, + scope: 'read:foo delete:foo', + refresh_token: 'GEbRxBN...edjnXbL', + id_token: makeIdToken({ iss: `${withoutApi.issuerBaseURL}/` }), + token_type: 'Bearer' + }; + }); + + const { res } = await callback( + baseUrl, + { + state, + code: 'foobar' + }, + cookieJar + ); + expect(res.statusCode).toBe(302); + expect(spy).toHaveBeenCalledWith(expect.stringContaining('foo=bar')); + }); }); diff --git a/tests/session/get-access-token.test.ts b/tests/session/get-access-token.test.ts index 80c2a5334..22cab24e7 100644 --- a/tests/session/get-access-token.test.ts +++ b/tests/session/get-access-token.test.ts @@ -3,6 +3,8 @@ import { withApi } from '../fixtures/default-settings'; import { get } from '../auth0-session/fixtures/helpers'; import { Session } from '../../src'; import { refreshTokenExchange, refreshTokenRotationExchange } from '../fixtures/oidc-nocks'; +import { makeIdToken } from '../auth0-session/fixtures/cert'; +import nock from 'nock'; describe('get access token', () => { afterEach(teardown); @@ -242,4 +244,36 @@ describe('get access token', () => { const { idToken: newIdToken } = await get(baseUrl, '/api/session', { cookieJar }); expect(newIdToken).toBeUndefined(); }); + + test('should pass custom auth params in refresh grant request body', async () => { + const idToken = makeIdToken({ + iss: `${withApi.issuerBaseURL}/`, + aud: withApi.clientID, + email: 'john@test.com', + name: 'john doe', + sub: '123' + }); + + const spy = jest.fn(); + nock(`${withApi.issuerBaseURL}`) + .post('/oauth/token', /grant_type=refresh_token/) + .reply(200, (_, body) => { + spy(body); + return { + access_token: 'new-token', + id_token: idToken, + token_type: 'Bearer', + expires_in: 750, + scope: 'read:foo write:foo' + }; + }); + + const baseUrl = await setup(withApi, { + getAccessTokenOptions: { refresh: true, authorizationParams: { baz: 'qux' } } + }); + const cookieJar = await login(baseUrl); + const { accessToken } = await get(baseUrl, '/api/access-token', { cookieJar }); + expect(accessToken).toEqual('new-token'); + expect(spy).toHaveBeenCalledWith(expect.stringContaining('baz=qux')); + }); });