Skip to content

Commit

Permalink
feat(services-bff): Refresh token with polling (#16872)
Browse files Browse the repository at this point in the history
* Implement refresh token lock with polling

* fixes after merge

* Fixes after self review

* Fix imports

* Add tests for token refresh service

* Fix tests

* chore: nx format:write update dirty files

* Add tests

* Prevent flaky test and always clean up redis cache

* Update refresh token flow with logging in mind

* Simplify error service logic

* Fix after accidental remove of refresh

* Add comments

---------

Co-authored-by: andes-it <[email protected]>
Co-authored-by: kodiakhq[bot] <49736102+kodiakhq[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 29, 2024
1 parent 253bfe7 commit 918cf3b
Show file tree
Hide file tree
Showing 15 changed files with 1,056 additions and 205 deletions.
23 changes: 8 additions & 15 deletions apps/services/bff/src/app/modules/auth/auth.controller.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import jwt from 'jsonwebtoken'
import request from 'supertest'
import { setupTestServer } from '../../../../test/setupTestServer'
import {
mockedTokensResponse as tokensResponse,
SID_VALUE,
SESSION_COOKIE_NAME,
ALGORITM_TYPE,
SESSION_COOKIE_NAME,
SID_VALUE,
getLoginSearchParmsFn,
mockedTokensResponse as tokensResponse,
} from '../../../../test/sharedConstants'
import { environment } from '../../../environment'
import { BffConfig } from '../../bff.config'
import { IdsService } from '../ids/ids.service'
import { ParResponse } from '../ids/ids.types'
Expand Down Expand Up @@ -58,17 +59,9 @@ const parResponse: ParResponse = {
const allowedTargetLinkUri = 'http://test-client.com/testclient'

const mockIdsService = {
getPar: jest.fn().mockResolvedValue({
type: 'success',
data: parResponse,
}),
getTokens: jest.fn().mockResolvedValue({
type: 'success',
data: tokensResponse,
}),
revokeToken: jest.fn().mockResolvedValue({
type: 'success',
}),
getPar: jest.fn().mockResolvedValue(parResponse),
getTokens: jest.fn().mockResolvedValue(tokensResponse),
revokeToken: jest.fn().mockResolvedValue(undefined),
getLoginSearchParams: jest.fn().mockImplementation(getLoginSearchParmsFn),
}

Expand All @@ -89,7 +82,7 @@ describe('AuthController', () => {
})

mockConfig = app.get<ConfigType<typeof BffConfig>>(BffConfig.KEY)
baseUrlWithKey = `${mockConfig.clientBaseUrl}${process.env.BFF_CLIENT_KEY_PATH}`
baseUrlWithKey = `${mockConfig.clientBaseUrl}${environment.keyPath}`

server = request(app.getHttpServer())
})
Expand Down
36 changes: 10 additions & 26 deletions apps/services/bff/src/app/modules/auth/auth.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
UnauthorizedException,
} from '@nestjs/common'
import { ConfigType } from '@nestjs/config'
import { CookieOptions, Request, Response } from 'express'
import type { Request, Response } from 'express'
import jwksClient from 'jwks-rsa'
import { jwtDecode } from 'jwt-decode'

Expand All @@ -26,6 +26,7 @@ import {
CreateErrorQueryStrArgs,
createErrorQueryStr,
} from '../../utils/create-error-query-str'
import { getCookieOptions } from '../../utils/get-cookie-options'
import { validateUri } from '../../utils/validate-uri'
import { CacheService } from '../cache/cache.service'
import { IdsService } from '../ids/ids.service'
Expand Down Expand Up @@ -55,17 +56,6 @@ export class AuthService {
this.baseUrl = this.config.ids.issuer
}

private getCookieOptions(): CookieOptions {
return {
httpOnly: true,
secure: true,
// The lax setting allows cookies to be sent on top-level navigations (such as redirects),
// while still providing some protection against CSRF attacks.
sameSite: 'lax',
path: environment.keyPath,
}
}

/**
* Creates the client base URL with the path appended.
*/
Expand Down Expand Up @@ -212,12 +202,8 @@ export class AuthService {
prompt,
})

if (parResponse.type === 'error') {
throw parResponse.data
}

searchParams = new URLSearchParams({
request_uri: parResponse.data.request_uri,
request_uri: parResponse.request_uri,
client_id: this.config.ids.clientId,
})
} else {
Expand Down Expand Up @@ -297,13 +283,7 @@ export class AuthService {
codeVerifier: loginAttemptData.codeVerifier,
})

if (tokenResponse.type === 'error') {
throw tokenResponse.data
}

const updatedTokenResponse = await this.updateTokenCache(
tokenResponse.data,
)
const updatedTokenResponse = await this.updateTokenCache(tokenResponse)

// Clean up the login attempt from the cache since we have a successful login.
this.cacheService
Expand All @@ -312,11 +292,15 @@ export class AuthService {
this.logger.warn(err)
})

// Clear any existing session cookie first
// This prevents multiple session cookies being set.
res.clearCookie(SESSION_COOKIE_NAME, getCookieOptions())

// Create session cookie with successful login session id
res.cookie(
SESSION_COOKIE_NAME,
updatedTokenResponse.userProfile.sid,
this.getCookieOptions(),
getCookieOptions(),
)

// Check if there is an old session cookie and clean up the cache
Expand Down Expand Up @@ -424,7 +408,7 @@ export class AuthService {
* - Delete the current login from the cache
* - Clear the session cookie
*/
res.clearCookie(SESSION_COOKIE_NAME, this.getCookieOptions())
res.clearCookie(SESSION_COOKIE_NAME, getCookieOptions())

this.cacheService
.delete(currentLoginCacheKey)
Expand Down
258 changes: 258 additions & 0 deletions apps/services/bff/src/app/modules/auth/token-refresh.service.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
import { LOGGER_PROVIDER } from '@island.is/logging'
import { Test } from '@nestjs/testing'
import { CacheService } from '../cache/cache.service'
import { IdsService } from '../ids/ids.service'
import { TokenResponse } from '../ids/ids.types'
import { AuthService } from './auth.service'
import { CachedTokenResponse } from './auth.types'
import { TokenRefreshService } from './token-refresh.service'

const delay = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms))

jest.mock('uuid', () => ({
v4: jest.fn().mockReturnValue('fake_uuid'),
}))

const mockLogger = {
error: jest.fn(),
warn: jest.fn(),
}

const mockCacheStore = new Map()

const mockTokenResponse: CachedTokenResponse = {
id_token: 'mock.id.token',
expires_in: 3600,
token_type: 'Bearer',
scope: 'openid profile offline_access',
scopes: ['openid', 'profile', 'offline_access'],
userProfile: {
sid: 'test-session-id',
nationalId: '1234567890',
name: 'Test User',
idp: 'test-idp',
subjectType: 'person',
delegationType: [],
locale: 'is',
birthdate: '1990-01-01',
},
accessTokenExp: Date.now() + 3600000, // Current time + 1 hour in milliseconds
encryptedAccessToken: 'encrypted.access.token',
encryptedRefreshToken: 'encrypted.refresh.token',
}

// When mocking IdsService.refreshToken response, we need TokenResponse type:
const mockIdsTokenResponse: TokenResponse = {
id_token: 'mock.id.token',
access_token: 'mock.access.token',
refresh_token: 'mock.refresh.token',
expires_in: 3600,
token_type: 'Bearer',
scope: 'openid profile offline_access',
}

describe('TokenRefreshService', () => {
let service: TokenRefreshService
let authService: AuthService
let idsService: IdsService
let cacheService: CacheService
const testSid = 'test-sid'
const testRefreshToken = 'test-refresh-token'
const refreshInProgressPrefix = 'refresh_token_in_progress'
const refreshInProgressKey = `${refreshInProgressPrefix}:${testSid}`

beforeEach(async () => {
const module = await Test.createTestingModule({
providers: [
TokenRefreshService,
{
provide: LOGGER_PROVIDER,
useValue: mockLogger,
},
{
provide: AuthService,
useValue: {
updateTokenCache: jest.fn().mockResolvedValue(mockTokenResponse),
},
},
{
provide: IdsService,
useValue: {
refreshToken: jest.fn().mockResolvedValue(mockTokenResponse),
},
},
{
provide: CacheService,
useValue: {
save: jest.fn().mockImplementation(async ({ key, value }) => {
mockCacheStore.set(key, value)
}),
get: jest
.fn()
.mockImplementation(async (key) => mockCacheStore.get(key)),
delete: jest
.fn()
.mockImplementation(async (key) => mockCacheStore.delete(key)),
createSessionKeyType: jest.fn((type, sid) => `${type}_${sid}`),
},
},
],
}).compile()

service = module.get<TokenRefreshService>(TokenRefreshService)
authService = module.get<AuthService>(AuthService)
idsService = module.get<IdsService>(IdsService)
cacheService = module.get<CacheService>(CacheService)
})

afterEach(() => {
mockCacheStore.clear()
jest.clearAllMocks()
})

describe('refreshToken', () => {
it('should successfully refresh token when no refresh is in progress', async () => {
// Act
const result = await service.refreshToken({
sid: testSid,
encryptedRefreshToken: testRefreshToken,
})

// Assert
expect(idsService.refreshToken).toHaveBeenCalledWith(testRefreshToken)
expect(authService.updateTokenCache).toHaveBeenCalledWith(
mockTokenResponse,
)
expect(result).toEqual(mockTokenResponse)
})

it('should wait for ongoing refresh and return cached result', async () => {
// Arrange
await cacheService.save({
key: refreshInProgressKey,
value: true,
ttl: 3000,
})

// Simulate another service updating the token while we wait
setTimeout(async () => {
await cacheService.delete(refreshInProgressKey)
await cacheService.save({
key: `current_${testSid}`,
value: mockTokenResponse,
ttl: 3600,
})
}, 500)

// Act
const result = await service.refreshToken({
sid: testSid,
encryptedRefreshToken: testRefreshToken,
})

// Assert
expect(result).toEqual(mockTokenResponse)
expect(idsService.refreshToken).not.toHaveBeenCalled()
})

it('should retry refresh if polling times out', async () => {
// Arrange
await cacheService.save({
key: refreshInProgressKey,
value: true,
ttl: 3000,
})

// Act
const result = await service.refreshToken({
sid: testSid,
encryptedRefreshToken: testRefreshToken,
})

// Assert
expect(mockLogger.warn).toHaveBeenCalled()
expect(idsService.refreshToken).toHaveBeenCalledWith(testRefreshToken)
expect(result).toEqual(mockTokenResponse)
})

it('should handle refresh token failure', async () => {
// Arrange
const error = new Error('Refresh token failed')
jest.spyOn(idsService, 'refreshToken').mockRejectedValueOnce(error)

// Act
const cachedTokenResponse = await service.refreshToken({
sid: testSid,
encryptedRefreshToken: testRefreshToken,
})
//
expect(cachedTokenResponse).toBe(null)

expect(mockLogger.warn).toHaveBeenCalledWith(
`Token refresh failed for sid: ${testSid}`,
)
})

it('should prevent concurrent refresh token requests', async () => {
// Arrange
const refreshPromises = []
const refreshCount = 5
let firstRequestStarted = false

// Mock cache.get to make sure first request get in progress lock and other requests waits
jest.spyOn(cacheService, 'get').mockImplementation(async (key) => {
if (key.includes(refreshInProgressPrefix)) {
return firstRequestStarted
}
return mockTokenResponse
})

// Mock cache.save to track first request
jest.spyOn(cacheService, 'save').mockImplementation(async ({ key }) => {
if (key.includes(refreshInProgressPrefix)) {
firstRequestStarted = true
// Add delay after setting lock
await delay(50)
}
})

// Mock cache.delete to clear the lock
jest.spyOn(cacheService, 'delete').mockImplementation(async (key) => {
if (key.includes(refreshInProgressPrefix)) {
firstRequestStarted = false
}
})

// Act
// First request
refreshPromises.push(
service.refreshToken({
sid: testSid,
encryptedRefreshToken: testRefreshToken,
}),
)

// Wait a tick to ensure first request starts
await delay(10)

// Remaining requests
for (let i = 1; i < refreshCount; i++) {
refreshPromises.push(
service.refreshToken({
sid: testSid,
encryptedRefreshToken: testRefreshToken,
}),
)
}

// Wait for all promises to resolve
const results = await Promise.all(refreshPromises)

// Assert
expect(idsService.refreshToken).toHaveBeenCalledTimes(1)
results.forEach((result) => {
expect(result).toEqual(mockTokenResponse)
})
})
})
})
Loading

0 comments on commit 918cf3b

Please sign in to comment.