Skip to content

Commit

Permalink
more tests and refactor oauth1 code to be consistent with oauth2
Browse files Browse the repository at this point in the history
  • Loading branch information
netroy committed Nov 12, 2024
1 parent 7aed26f commit a49f61d
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 146 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import Container from 'typedi';
import { Time } from '@/constants';
import { OAuth1CredentialController } from '@/controllers/oauth/oauth1-credential.controller';
import { CredentialsHelper } from '@/credentials-helper';
import { CredentialsEntity } from '@/databases/entities/credentials-entity';
import type { CredentialsEntity } from '@/databases/entities/credentials-entity';
import type { User } from '@/databases/entities/user';
import { CredentialsRepository } from '@/databases/repositories/credentials.repository';
import { SharedCredentialsRepository } from '@/databases/repositories/shared-credentials.repository';
Expand Down Expand Up @@ -81,6 +81,7 @@ describe('OAuth1CredentialController', () => {
credentialsHelper.applyDefaultsAndOverwrites.mockReturnValueOnce({
requestTokenUrl: 'https://example.domain/oauth/request_token',
authUrl: 'https://example.domain/oauth/authorize',
accessTokenUrl: 'https://example.domain/oauth/access_token',
signatureMethod: 'HMAC-SHA1',
});
nock('https://example.domain')
Expand Down Expand Up @@ -164,15 +165,15 @@ describe('OAuth1CredentialController', () => {

expect(res.render).toHaveBeenCalledWith('oauth-error-callback', {
error: {
message: 'OAuth1 callback failed because of insufficient permissions',
message: 'OAuth callback failed because of insufficient permissions',
},
});
expect(credentialsRepository.findOneBy).toHaveBeenCalledTimes(1);
expect(credentialsRepository.findOneBy).toHaveBeenCalledWith({ id: '1' });
});

it('should render the error page when state differs from the stored state in the credential', async () => {
credentialsRepository.findOneBy.mockResolvedValue(new CredentialsEntity());
credentialsRepository.findOneBy.mockResolvedValue(credential);
credentialsHelper.getDecrypted.mockResolvedValue({ csrfSecret: 'invalid' });

const req = mock<OAuthRequest.OAuth1Credential.Callback>();
Expand All @@ -187,13 +188,13 @@ describe('OAuth1CredentialController', () => {

expect(res.render).toHaveBeenCalledWith('oauth-error-callback', {
error: {
message: 'The OAuth1 callback state is invalid!',
message: 'The OAuth callback state is invalid!',
},
});
});

it('should render the error page when state is older than 5 minutes', async () => {
credentialsRepository.findOneBy.mockResolvedValue(new CredentialsEntity());
credentialsRepository.findOneBy.mockResolvedValue(credential);
credentialsHelper.getDecrypted.mockResolvedValue({ csrfSecret });
jest.spyOn(Csrf.prototype, 'verify').mockReturnValueOnce(true);

Expand All @@ -211,9 +212,52 @@ describe('OAuth1CredentialController', () => {

expect(res.render).toHaveBeenCalledWith('oauth-error-callback', {
error: {
message: 'The OAuth1 state expired. Please try again.',
message: 'The OAuth callback state is invalid!',
},
});
});

it('should exchange the code for a valid token, and save it to DB', async () => {
credentialsRepository.findOneBy.mockResolvedValue(credential);
credentialsHelper.getDecrypted.mockResolvedValue({ csrfSecret });
credentialsHelper.applyDefaultsAndOverwrites.mockReturnValueOnce({
requestTokenUrl: 'https://example.domain/oauth/request_token',
accessTokenUrl: 'https://example.domain/oauth/access_token',
signatureMethod: 'HMAC-SHA1',
});
jest.spyOn(Csrf.prototype, 'verify').mockReturnValueOnce(true);
nock('https://example.domain')
.post('/oauth/access_token', {
oauth_token: 'token',
oauth_verifier: 'verifier',
})
.once()
.reply(200, 'access_token=new_token');
cipher.encrypt.mockReturnValue('encrypted');

const req = mock<OAuthRequest.OAuth1Credential.Callback>();
const res = mock<Response>();
req.query = {
oauth_verifier: 'verifier',
oauth_token: 'token',
state: validState,
} as OAuthRequest.OAuth1Credential.Callback['query'];

await controller.handleCallback(req, res);

expect(cipher.encrypt).toHaveBeenCalledWith({
oauthTokenData: { access_token: 'new_token' },
});
expect(credentialsRepository.update).toHaveBeenCalledWith(
'1',
expect.objectContaining({
data: 'encrypted',
id: '1',
name: 'Test Credential',
type: 'oAuth1Api',
}),
);
expect(res.render).toHaveBeenCalledWith('oauth-callback');
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ describe('OAuth2CredentialController', () => {
beforeEach(() => {
jest.setSystemTime(new Date(timestamp));
jest.resetAllMocks();

credentialsHelper.applyDefaultsAndOverwrites.mockReturnValue({
clientId: 'test-client-id',
clientSecret: 'oauth-secret',
authUrl: 'https://example.domain/o/oauth2/v2/auth',
accessTokenUrl: 'https://example.domain/token',
});
});

describe('getAuthUri', () => {
Expand All @@ -78,10 +85,6 @@ describe('OAuth2CredentialController', () => {
jest.spyOn(Csrf.prototype, 'create').mockReturnValueOnce('token');
sharedCredentialsRepository.findCredentialForUser.mockResolvedValueOnce(credential);
credentialsHelper.getDecrypted.mockResolvedValueOnce({});
credentialsHelper.applyDefaultsAndOverwrites.mockReturnValue({
clientId: 'test-client-id',
authUrl: 'https://example.domain/o/oauth2/v2/auth',
});
cipher.encrypt.mockReturnValue('encrypted');

const req = mock<OAuthRequest.OAuth2Credential.Auth>({ user, query: { id: '1' } });
Expand Down Expand Up @@ -159,7 +162,7 @@ describe('OAuth2CredentialController', () => {

expect(res.render).toHaveBeenCalledWith('oauth-error-callback', {
error: {
message: 'OAuth2 callback failed because of insufficient permissions',
message: 'OAuth callback failed because of insufficient permissions',
},
});
expect(credentialsRepository.findOneBy).toHaveBeenCalledTimes(1);
Expand All @@ -178,7 +181,7 @@ describe('OAuth2CredentialController', () => {
await controller.handleCallback(req, res);
expect(res.render).toHaveBeenCalledWith('oauth-error-callback', {
error: {
message: 'The OAuth2 callback state is invalid!',
message: 'The OAuth callback state is invalid!',
},
});
expect(externalHooks.run).not.toHaveBeenCalled();
Expand All @@ -200,20 +203,43 @@ describe('OAuth2CredentialController', () => {

expect(res.render).toHaveBeenCalledWith('oauth-error-callback', {
error: {
message: 'The OAuth2 state expired. Please try again.',
message: 'The OAuth callback state is invalid!',
},
});
expect(externalHooks.run).not.toHaveBeenCalled();
});

it('should exchange the code for a valid token, and save it to DB', async () => {
it('should render the error page when code exchange fails', async () => {
credentialsRepository.findOneBy.mockResolvedValueOnce(credential);
credentialsHelper.getDecrypted.mockResolvedValueOnce({ csrfSecret });
credentialsHelper.applyDefaultsAndOverwrites.mockReturnValue({
clientId: 'test-client-id',
clientSecret: 'oauth-secret',
accessTokenUrl: 'https://example.domain/token',
jest.spyOn(Csrf.prototype, 'verify').mockReturnValueOnce(true);
nock('https://example.domain')
.post(
'/token',
'code=code&grant_type=authorization_code&redirect_uri=http%3A%2F%2Flocalhost%3A5678%2Frest%2Foauth2-credential%2Fcallback',
)
.reply(403, { error: 'Code could not be exchanged' });

const req = mock<OAuthRequest.OAuth2Credential.Callback>({
query: { code: 'code', state: validState },
originalUrl: '?code=code',
});
const res = mock<Response>();

await controller.handleCallback(req, res);

expect(externalHooks.run).toHaveBeenCalled();
expect(res.render).toHaveBeenCalledWith('oauth-error-callback', {
error: {
message: 'Code could not be exchanged',
reason: '{"error":"Code could not be exchanged"}',
},
});
});

it('should exchange the code for a valid token, and save it to DB', async () => {
credentialsRepository.findOneBy.mockResolvedValueOnce(credential);
credentialsHelper.getDecrypted.mockResolvedValueOnce({ csrfSecret });
jest.spyOn(Csrf.prototype, 'verify').mockReturnValueOnce(true);
nock('https://example.domain')
.post(
Expand Down
49 changes: 41 additions & 8 deletions packages/cli/src/controllers/oauth/abstract-oauth.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,27 @@ import type { ICredentialDataDecryptedObject, IWorkflowExecuteAdditionalData } f
import { jsonParse, ApplicationError } from 'n8n-workflow';
import { Service } from 'typedi';

import { RESPONSE_ERROR_MESSAGES } from '@/constants';
import { RESPONSE_ERROR_MESSAGES, Time } from '@/constants';
import { CredentialsHelper } from '@/credentials-helper';
import type { CredentialsEntity } from '@/databases/entities/credentials-entity';
import { CredentialsRepository } from '@/databases/repositories/credentials.repository';
import { SharedCredentialsRepository } from '@/databases/repositories/shared-credentials.repository';
import { AuthError } from '@/errors/response-errors/auth.error';
import { BadRequestError } from '@/errors/response-errors/bad-request.error';
import { NotFoundError } from '@/errors/response-errors/not-found.error';
import { ExternalHooks } from '@/external-hooks';
import type { ICredentialsDb } from '@/interfaces';
import { Logger } from '@/logging/logger.service';
import type { OAuthRequest } from '@/requests';
import type { AuthenticatedRequest, OAuthRequest } from '@/requests';
import { UrlService } from '@/services/url.service';
import * as WorkflowExecuteAdditionalData from '@/workflow-execute-additional-data';

export interface CsrfStateParam {
type CsrfStateParam = {
cid: string;
token: string;
createdAt: number;
userId?: string;
}
};

// TODO: Flip this flag in v2
// https://linear.app/n8n/issue/CAT-329
Expand Down Expand Up @@ -136,23 +137,55 @@ export abstract class AbstractOAuthController {
return [csrfSecret, Buffer.from(JSON.stringify(state)).toString('base64')];
}

protected decodeCsrfState(encodedState: string): CsrfStateParam {
protected decodeCsrfState(encodedState: string, req: AuthenticatedRequest): CsrfStateParam {
const errorMessage = 'Invalid state format';
const decoded = jsonParse<CsrfStateParam>(Buffer.from(encodedState, 'base64').toString(), {
errorMessage,
});

if (typeof decoded.cid !== 'string' || typeof decoded.token !== 'string') {
throw new ApplicationError(errorMessage);
}

if (decoded.userId !== req.user?.id) {
throw new AuthError('Unauthorized');
}

return decoded;
}

protected verifyCsrfState(decrypted: ICredentialDataDecryptedObject, state: CsrfStateParam) {
const token = new Csrf();
return (
decrypted.csrfSecret === undefined ||
!token.verify(decrypted.csrfSecret as string, state.token)
if (decrypted.csrfSecret === undefined) return false;
if (!token.verify(decrypted.csrfSecret as string, state.token)) return false;
if (!state.createdAt || Date.now() - state.createdAt > 5 * Time.minutes.toMilliseconds)
return false;
return true;
}

protected async resolveCredential<T>(
req: OAuthRequest.OAuth1Credential.Callback | OAuthRequest.OAuth2Credential.Callback,
): Promise<[ICredentialsDb, ICredentialDataDecryptedObject, T]> {
const { state: encodedState } = req.query;
const state = this.decodeCsrfState(encodedState, req);
const credential = await this.getCredentialWithoutUser(state.cid);
if (!credential) {
throw new ApplicationError('OAuth callback failed because of insufficient permissions');
}

const additionalData = await this.getAdditionalData();
const decryptedDataOriginal = await this.getDecryptedData(credential, additionalData);
const oauthCredentials = this.applyDefaultsAndOverwrites<T>(
credential,
decryptedDataOriginal,
additionalData,
);

if (!this.verifyCsrfState(decryptedDataOriginal, state)) {
throw new ApplicationError('The OAuth callback state is invalid!');
}

return [credential, decryptedDataOriginal, oauthCredentials];
}

protected renderCallbackError(res: Response, message: string, reason?: string) {
Expand Down
Loading

0 comments on commit a49f61d

Please sign in to comment.