From c27e39ff69aa63165de74d71cdc52d1d5ac17418 Mon Sep 17 00:00:00 2001 From: John-Paul Drawneek <163003532+jpd4emis@users.noreply.github.com> Date: Tue, 3 Sep 2024 10:56:48 +0100 Subject: [PATCH] fix: Rotate OAuth token when expired --- src/auth/auth.test.ts | 133 ++++++++++++++++++++++++++++++++++++++++++ src/auth/auth.ts | 64 +++++++------------- 2 files changed, 155 insertions(+), 42 deletions(-) create mode 100644 src/auth/auth.test.ts diff --git a/src/auth/auth.test.ts b/src/auth/auth.test.ts new file mode 100644 index 0000000..058ca7b --- /dev/null +++ b/src/auth/auth.test.ts @@ -0,0 +1,133 @@ +import { mocked } from "jest-mock"; +import { mockServices } from '@backstage/backend-test-utils'; + +import { getAuthToken, loadAuthConfig } from './auth'; +import { RootConfigService } from "@backstage/backend-plugin-api"; + +global.fetch = jest.fn() as jest.Mock; + +function mockedResponse(status: number, body: any): Promise { + return Promise.resolve({ + json: () => Promise.resolve(body), + status + } as Response); +} + +describe('PagerDuty Auth', () => { + const logger = mockServices.rootLogger(); + let config : RootConfigService; + + beforeAll(() => { + jest.useFakeTimers(); + }); + + describe('getAuthToken', () => { + config = mockServices.rootConfig({ + data: { + pagerDuty: { + oauth: { + clientId: 'foobar', + clientSecret: 'super-secret-wow', + subDomain: 'EU', + } + + } + } + }); + + it('Get token with legacy OAuth config', async () => { + mocked(fetch).mockReturnValue( + mockedResponse(200, { access_token: 'sometoken', token_type: "bearer", expires_in: 86400 }) + ); + jest.setSystemTime(new Date(2024, 9, 1, 9, 0)); + await loadAuthConfig(config, logger); + + const result = await getAuthToken(); + expect(result).toEqual('Bearer sometoken'); + }); + + it('Get token with account OAuth config', async () => { + config = mockServices.rootConfig({ + data: { + pagerDuty: { + accounts: [ + { + id: 'test1', + oauth: { + clientId: 'foobar', + clientSecret: 'super-secret-wow', + subDomain: 'EU', + } + } + ] + } + } + }); + mocked(fetch).mockReturnValue( + mockedResponse(200, { access_token: 'sometoken', token_type: "bearer", expires_in: 86400 }) + ); + jest.setSystemTime(new Date(2024, 9, 1, 9, 0)); + await loadAuthConfig(config, logger); + + const defaultResult = await getAuthToken(); + expect(defaultResult).toEqual('Bearer sometoken'); + const accountResult = await getAuthToken('test1'); + expect(accountResult).toEqual('Bearer sometoken'); + }); + + it('Get refreshed token with legacy OAuth config', async () => { + mocked(fetch).mockReturnValueOnce( + mockedResponse(200, { access_token: 'sometoken1', token_type: "bearer", expires_in: 86400 }) + ); + mocked(fetch).mockReturnValueOnce( + mockedResponse(200, { access_token: 'sometoken2', token_type: "bearer", expires_in: 86400 }) + ); + jest.setSystemTime(new Date(2024, 9, 1, 9, 0)); + await loadAuthConfig(config, logger); + + const before = await getAuthToken(); + expect(before).toEqual('Bearer sometoken1'); + + jest.setSystemTime(new Date(2024, 9, 2, 9, 1)); + const result = await getAuthToken(); + expect(result).toEqual('Bearer sometoken2'); + }); + + it('Get legacy token', async () => { + config = mockServices.rootConfig({ + data: { + pagerDuty: { + apiToken: 'some-api-token', + } + } + }); + await loadAuthConfig(config, logger); + + const result = await getAuthToken(); + expect(result).toEqual('Token token=some-api-token'); + }); + + it('Get account token', async () => { + config = mockServices.rootConfig({ + data: { + pagerDuty: { + accounts: [ + { + id: 'test2', + apiToken: 'some-api-token', + } + ] + } + } + }); + await loadAuthConfig(config, logger); + + const defaultResult = await getAuthToken(); + expect(defaultResult).toEqual('Token token=some-api-token'); + const accountResult = await getAuthToken('test2'); + expect(accountResult).toEqual('Token token=some-api-token'); + const noResult = await getAuthToken('test1'); + expect(noResult).toEqual(''); + }); + }); +}); diff --git a/src/auth/auth.ts b/src/auth/auth.ts index 567a51a..62e442b 100644 --- a/src/auth/auth.ts +++ b/src/auth/auth.ts @@ -16,54 +16,34 @@ type Auth = { let authPersistence: Auth; let isLegacyConfig = false; -export async function getAuthToken(accountId? : string): Promise { +async function checkForOAuthToken(tokenId: string): Promise { + if (authPersistence.accountTokens[tokenId]?.authToken !== '' && + authPersistence.accountTokens[tokenId]?.authToken.includes('Bearer')) { + if (authPersistence.accountTokens[tokenId].authTokenExpiryDate > Date.now()) { + return true + } + authPersistence.logger.info('OAuth token expired, renewing'); + await loadAuthConfig(authPersistence.config, authPersistence.logger); + return authPersistence.accountTokens[tokenId].authTokenExpiryDate > Date.now() + } + return false +} +export async function getAuthToken(accountId? : string): Promise { // if authPersistence is not initialized, load the auth config if (!authPersistence?.accountTokens) { await loadAuthConfig(authPersistence.config, authPersistence.logger); } - if(isLegacyConfig){ - if ( - (authPersistence.accountTokens.default.authToken !== '' && - authPersistence.accountTokens.default.authToken.includes('Bearer') && - authPersistence.accountTokens.default.authTokenExpiryDate > Date.now()) // case where OAuth token is still valid - || - (authPersistence.accountTokens.default.authToken !== '' && - authPersistence.accountTokens.default.authToken.includes('Token'))) { // case where API token is used - + if (isLegacyConfig && authPersistence.accountTokens.default.authToken !== '' + && (await checkForOAuthToken('default') || authPersistence.accountTokens.default.authToken.includes('Token'))) { return authPersistence.accountTokens.default.authToken; - } } - else { - // check if accountId is provided - if (accountId && accountId !== '') { - if ( - (authPersistence.accountTokens[accountId].authToken !== '' && - authPersistence.accountTokens[accountId].authToken.includes('Bearer') && - authPersistence.accountTokens[accountId].authTokenExpiryDate > Date.now()) // case where OAuth token is still valid - || - (authPersistence.accountTokens[accountId].authToken !== '' && - authPersistence.accountTokens[accountId].authToken.includes('Token'))) { // case where API token is used - - return authPersistence.accountTokens[accountId].authToken; - } - } - - else { // return default account token if accountId is not provided - const defaultFallback = authPersistence.defaultAccount ?? ""; + const key = accountId && accountId !== '' ? accountId : authPersistence.defaultAccount ?? ''; - if ( - (authPersistence.accountTokens[defaultFallback].authToken !== '' && - authPersistence.accountTokens[defaultFallback].authToken.includes('Bearer') && - authPersistence.accountTokens[defaultFallback].authTokenExpiryDate > Date.now()) // case where OAuth token is still valid - || - (authPersistence.accountTokens[defaultFallback].authToken !== '' && - authPersistence.accountTokens[defaultFallback].authToken.includes('Token'))) { // case where API token is used - - return authPersistence.accountTokens[defaultFallback].authToken; - } - } + if (authPersistence.accountTokens[key]?.authToken !== '' + && (await checkForOAuthToken(key) || authPersistence.accountTokens[key]?.authToken.includes('Token'))) { + return authPersistence.accountTokens[key].authToken; } return ''; @@ -119,7 +99,7 @@ export async function loadAuthConfig(config : RootConfigService, logger: LoggerS else { // new accounts config is present logger.info('New PagerDuty accounts configuration found in config file.'); isLegacyConfig = false; - const accounts = config.getOptional('pagerDuty.accounts'); + const accounts = config.getOptional('pagerDuty.accounts') || []; if(accounts && accounts?.length === 1){ @@ -127,7 +107,7 @@ export async function loadAuthConfig(config : RootConfigService, logger: LoggerS authPersistence.defaultAccount = accounts[0].id; } - accounts?.forEach(async account => { + await Promise.all(accounts.map(async account => { const maskedAccountId = maskString(account.id); if(account.isDefault && !authPersistence.defaultAccount){ @@ -161,7 +141,7 @@ export async function loadAuthConfig(config : RootConfigService, logger: LoggerS logger.info(`PagerDuty API token loaded successfully for account ${maskedAccountId}.`); } - }); + })); if(!authPersistence.defaultAccount){ logger.error('No default account found in config file. One account must be marked as default.');