-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(services-bff): Refresh token with polling (#16872)
* Implement refresh token lock with polling * fixes after merge * Fixes after self review * Fix imports * Add tests for token refresh service * Fix tests * chore: nx format:write update dirty files * Add tests * Prevent flaky test and always clean up redis cache * Update refresh token flow with logging in mind * Simplify error service logic * Fix after accidental remove of refresh * Add comments --------- Co-authored-by: andes-it <[email protected]> Co-authored-by: kodiakhq[bot] <49736102+kodiakhq[bot]@users.noreply.github.com>
- Loading branch information
1 parent
253bfe7
commit 918cf3b
Showing
15 changed files
with
1,056 additions
and
205 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
258 changes: 258 additions & 0 deletions
258
apps/services/bff/src/app/modules/auth/token-refresh.service.spec.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,258 @@ | ||
import { LOGGER_PROVIDER } from '@island.is/logging' | ||
import { Test } from '@nestjs/testing' | ||
import { CacheService } from '../cache/cache.service' | ||
import { IdsService } from '../ids/ids.service' | ||
import { TokenResponse } from '../ids/ids.types' | ||
import { AuthService } from './auth.service' | ||
import { CachedTokenResponse } from './auth.types' | ||
import { TokenRefreshService } from './token-refresh.service' | ||
|
||
const delay = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)) | ||
|
||
jest.mock('uuid', () => ({ | ||
v4: jest.fn().mockReturnValue('fake_uuid'), | ||
})) | ||
|
||
const mockLogger = { | ||
error: jest.fn(), | ||
warn: jest.fn(), | ||
} | ||
|
||
const mockCacheStore = new Map() | ||
|
||
const mockTokenResponse: CachedTokenResponse = { | ||
id_token: 'mock.id.token', | ||
expires_in: 3600, | ||
token_type: 'Bearer', | ||
scope: 'openid profile offline_access', | ||
scopes: ['openid', 'profile', 'offline_access'], | ||
userProfile: { | ||
sid: 'test-session-id', | ||
nationalId: '1234567890', | ||
name: 'Test User', | ||
idp: 'test-idp', | ||
subjectType: 'person', | ||
delegationType: [], | ||
locale: 'is', | ||
birthdate: '1990-01-01', | ||
}, | ||
accessTokenExp: Date.now() + 3600000, // Current time + 1 hour in milliseconds | ||
encryptedAccessToken: 'encrypted.access.token', | ||
encryptedRefreshToken: 'encrypted.refresh.token', | ||
} | ||
|
||
// When mocking IdsService.refreshToken response, we need TokenResponse type: | ||
const mockIdsTokenResponse: TokenResponse = { | ||
id_token: 'mock.id.token', | ||
access_token: 'mock.access.token', | ||
refresh_token: 'mock.refresh.token', | ||
expires_in: 3600, | ||
token_type: 'Bearer', | ||
scope: 'openid profile offline_access', | ||
} | ||
|
||
describe('TokenRefreshService', () => { | ||
let service: TokenRefreshService | ||
let authService: AuthService | ||
let idsService: IdsService | ||
let cacheService: CacheService | ||
const testSid = 'test-sid' | ||
const testRefreshToken = 'test-refresh-token' | ||
const refreshInProgressPrefix = 'refresh_token_in_progress' | ||
const refreshInProgressKey = `${refreshInProgressPrefix}:${testSid}` | ||
|
||
beforeEach(async () => { | ||
const module = await Test.createTestingModule({ | ||
providers: [ | ||
TokenRefreshService, | ||
{ | ||
provide: LOGGER_PROVIDER, | ||
useValue: mockLogger, | ||
}, | ||
{ | ||
provide: AuthService, | ||
useValue: { | ||
updateTokenCache: jest.fn().mockResolvedValue(mockTokenResponse), | ||
}, | ||
}, | ||
{ | ||
provide: IdsService, | ||
useValue: { | ||
refreshToken: jest.fn().mockResolvedValue(mockTokenResponse), | ||
}, | ||
}, | ||
{ | ||
provide: CacheService, | ||
useValue: { | ||
save: jest.fn().mockImplementation(async ({ key, value }) => { | ||
mockCacheStore.set(key, value) | ||
}), | ||
get: jest | ||
.fn() | ||
.mockImplementation(async (key) => mockCacheStore.get(key)), | ||
delete: jest | ||
.fn() | ||
.mockImplementation(async (key) => mockCacheStore.delete(key)), | ||
createSessionKeyType: jest.fn((type, sid) => `${type}_${sid}`), | ||
}, | ||
}, | ||
], | ||
}).compile() | ||
|
||
service = module.get<TokenRefreshService>(TokenRefreshService) | ||
authService = module.get<AuthService>(AuthService) | ||
idsService = module.get<IdsService>(IdsService) | ||
cacheService = module.get<CacheService>(CacheService) | ||
}) | ||
|
||
afterEach(() => { | ||
mockCacheStore.clear() | ||
jest.clearAllMocks() | ||
}) | ||
|
||
describe('refreshToken', () => { | ||
it('should successfully refresh token when no refresh is in progress', async () => { | ||
// Act | ||
const result = await service.refreshToken({ | ||
sid: testSid, | ||
encryptedRefreshToken: testRefreshToken, | ||
}) | ||
|
||
// Assert | ||
expect(idsService.refreshToken).toHaveBeenCalledWith(testRefreshToken) | ||
expect(authService.updateTokenCache).toHaveBeenCalledWith( | ||
mockTokenResponse, | ||
) | ||
expect(result).toEqual(mockTokenResponse) | ||
}) | ||
|
||
it('should wait for ongoing refresh and return cached result', async () => { | ||
// Arrange | ||
await cacheService.save({ | ||
key: refreshInProgressKey, | ||
value: true, | ||
ttl: 3000, | ||
}) | ||
|
||
// Simulate another service updating the token while we wait | ||
setTimeout(async () => { | ||
await cacheService.delete(refreshInProgressKey) | ||
await cacheService.save({ | ||
key: `current_${testSid}`, | ||
value: mockTokenResponse, | ||
ttl: 3600, | ||
}) | ||
}, 500) | ||
|
||
// Act | ||
const result = await service.refreshToken({ | ||
sid: testSid, | ||
encryptedRefreshToken: testRefreshToken, | ||
}) | ||
|
||
// Assert | ||
expect(result).toEqual(mockTokenResponse) | ||
expect(idsService.refreshToken).not.toHaveBeenCalled() | ||
}) | ||
|
||
it('should retry refresh if polling times out', async () => { | ||
// Arrange | ||
await cacheService.save({ | ||
key: refreshInProgressKey, | ||
value: true, | ||
ttl: 3000, | ||
}) | ||
|
||
// Act | ||
const result = await service.refreshToken({ | ||
sid: testSid, | ||
encryptedRefreshToken: testRefreshToken, | ||
}) | ||
|
||
// Assert | ||
expect(mockLogger.warn).toHaveBeenCalled() | ||
expect(idsService.refreshToken).toHaveBeenCalledWith(testRefreshToken) | ||
expect(result).toEqual(mockTokenResponse) | ||
}) | ||
|
||
it('should handle refresh token failure', async () => { | ||
// Arrange | ||
const error = new Error('Refresh token failed') | ||
jest.spyOn(idsService, 'refreshToken').mockRejectedValueOnce(error) | ||
|
||
// Act | ||
const cachedTokenResponse = await service.refreshToken({ | ||
sid: testSid, | ||
encryptedRefreshToken: testRefreshToken, | ||
}) | ||
// | ||
expect(cachedTokenResponse).toBe(null) | ||
|
||
expect(mockLogger.warn).toHaveBeenCalledWith( | ||
`Token refresh failed for sid: ${testSid}`, | ||
) | ||
}) | ||
|
||
it('should prevent concurrent refresh token requests', async () => { | ||
// Arrange | ||
const refreshPromises = [] | ||
const refreshCount = 5 | ||
let firstRequestStarted = false | ||
|
||
// Mock cache.get to make sure first request get in progress lock and other requests waits | ||
jest.spyOn(cacheService, 'get').mockImplementation(async (key) => { | ||
if (key.includes(refreshInProgressPrefix)) { | ||
return firstRequestStarted | ||
} | ||
return mockTokenResponse | ||
}) | ||
|
||
// Mock cache.save to track first request | ||
jest.spyOn(cacheService, 'save').mockImplementation(async ({ key }) => { | ||
if (key.includes(refreshInProgressPrefix)) { | ||
firstRequestStarted = true | ||
// Add delay after setting lock | ||
await delay(50) | ||
} | ||
}) | ||
|
||
// Mock cache.delete to clear the lock | ||
jest.spyOn(cacheService, 'delete').mockImplementation(async (key) => { | ||
if (key.includes(refreshInProgressPrefix)) { | ||
firstRequestStarted = false | ||
} | ||
}) | ||
|
||
// Act | ||
// First request | ||
refreshPromises.push( | ||
service.refreshToken({ | ||
sid: testSid, | ||
encryptedRefreshToken: testRefreshToken, | ||
}), | ||
) | ||
|
||
// Wait a tick to ensure first request starts | ||
await delay(10) | ||
|
||
// Remaining requests | ||
for (let i = 1; i < refreshCount; i++) { | ||
refreshPromises.push( | ||
service.refreshToken({ | ||
sid: testSid, | ||
encryptedRefreshToken: testRefreshToken, | ||
}), | ||
) | ||
} | ||
|
||
// Wait for all promises to resolve | ||
const results = await Promise.all(refreshPromises) | ||
|
||
// Assert | ||
expect(idsService.refreshToken).toHaveBeenCalledTimes(1) | ||
results.forEach((result) => { | ||
expect(result).toEqual(mockTokenResponse) | ||
}) | ||
}) | ||
}) | ||
}) |
Oops, something went wrong.