Skip to content

Commit

Permalink
Merge pull request #223 from biomage-ltd/compute-clusters-pipeline
Browse files Browse the repository at this point in the history
Add authorization on the pipeline parameters
  • Loading branch information
marcellp authored Sep 13, 2021
2 parents a651219 + 370a1da commit dfda76f
Show file tree
Hide file tree
Showing 11 changed files with 16,307 additions and 254 deletions.
16,435 changes: 16,198 additions & 237 deletions package-lock.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"node-fetch": "^2.6.1",
"node-jq": "^2.0.0",
"object-hash": "^2.0.3",
"promise.any": "^2.0.2",
"sns-validator": "^0.3.4",
"socket.io": "^3.1.2",
"socket.io-redis": "^6.0.1",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ const createTask = (taskName, context) => {
return task;
};

const getQCParams = (task, stepArgs) => {
const getQCParams = (task, context, stepArgs) => {
const { perSample, uploadCountMatrix } = stepArgs;
return {
...task,
...perSample ? { 'sampleUuid.$': '$.sampleUuid' } : { sampleUuid: '' },
...uploadCountMatrix ? { uploadCountMatrix: true } : { uploadCountMatrix: false },
authJWT: context.authJWT,
};
};

Expand All @@ -45,7 +46,7 @@ const buildParams = (task, context, stepArgs) => {
let processParams;

if (task.processName === QC_PROCESS_NAME) {
processParams = getQCParams(task, stepArgs);
processParams = getQCParams(task, context, stepArgs);
} else if (task.processName === GEM2S_PROCESS_NAME) {
processParams = getGem2SParams(task, context);
}
Expand All @@ -62,7 +63,6 @@ const createNewStep = (context, step, stepArgs) => {
const task = createTask(taskName, context);
const params = buildParams(task, context, stepArgs);


return {
...step,
Type: 'Task',
Expand Down
3 changes: 2 additions & 1 deletion src/api/general-services/pipeline-manage/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ const buildStateMachineDefinition = (skeleton, context) => {
return stateMachine;
};

const createQCPipeline = async (experimentId, processingConfigUpdates) => {
const createQCPipeline = async (experimentId, processingConfigUpdates, authJWT) => {
const accountId = await config.awsAccountIdPromise;
const roleArn = `arn:aws:iam::${accountId}:role/state-machine-role-${config.clusterEnv}`;
logger.log(`Fetching processing settings for ${experimentId}`);
Expand Down Expand Up @@ -229,6 +229,7 @@ const createQCPipeline = async (experimentId, processingConfigUpdates) => {
pipelineArtifacts: await getPipelineArtifacts(),
clusterInfo: await getClusterInfo(),
processingConfig: mergedProcessingConfig,
authJWT,
};


Expand Down
20 changes: 13 additions & 7 deletions src/api/route-services/gem2s.js
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Gem2sService {
io.sockets.emit(`ExperimentUpdates-${experimentId}`, response);
}

static async generateGem2sParams(experimentId) {
static async generateGem2sParams(experimentId, authJWT) {
const experiment = await (new ExperimentService()).getExperimentData(experimentId);
const { samples } = await (new SamplesService()).getSamplesByExperimentId(experimentId);
const {
Expand All @@ -66,6 +66,7 @@ class Gem2sService {
input: { type: experiment.meta.type },
sampleIds: experiment.sampleIds,
sampleNames: experiment.sampleIds.map((sampleId) => samples[sampleId].name),
authJWT,
};

if (metadataKeys.length) {
Expand Down Expand Up @@ -115,8 +116,8 @@ class Gem2sService {
return gem2sStatus !== RUNNING;
}

static async gem2sCreate(experimentId) {
const { taskParams, hashParams } = await this.generateGem2sParams(experimentId);
static async gem2sCreate(experimentId, authJWT) {
const { taskParams, hashParams } = await this.generateGem2sParams(experimentId, authJWT);

const paramsHash = crypto
.createHash('sha1')
Expand Down Expand Up @@ -149,17 +150,22 @@ class Gem2sService {
// Fail hard if there was an error.
await validateRequest(message, 'GEM2SResponse.v1.yaml');

const messageForClient = _.cloneDeep(message);

const {
experimentId, taskName, item,
} = messageForClient;
experimentId, taskName, item, authJWT,
} = message;

await pipelineHook.run(taskName, {
experimentId,
item,
authJWT,
});


const messageForClient = _.cloneDeep(message);

// Make sure authJWT doesn't get back to the client
delete messageForClient.authJWT;

await this.sendUpdateToSubscribed(experimentId, messageForClient, io);
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/api/routes/gem2s.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ module.exports = {
async (req, res, next) => {
const { experimentId } = req.params;

Gem2sService.gem2sCreate(experimentId).then((response) => res.json(response)).catch(next);
Gem2sService.gem2sCreate(experimentId, req.headers.authorization)
.then((response) => res.json(response))
.catch(next);
},
],

Expand Down
7 changes: 6 additions & 1 deletion src/api/routes/pipelines.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ module.exports = {
async (req, res) => {
const { processingConfig } = req.body;

const data = await createQCPipeline(req.params.experimentId, processingConfig || []);
const data = await createQCPipeline(
req.params.experimentId,
processingConfig || [],
req.headers.authorization,
);

const experimentService = new ExperimentService();
await experimentService.saveQCHandle(req.params.experimentId, data);
res.json(data);
Expand Down
5 changes: 3 additions & 2 deletions src/loaders/express.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ const http = require('http');
const AWSXRay = require('aws-xray-sdk');
const _ = require('lodash');
const config = require('../config');
const { authenticationMiddlewareExpress } = require('../utils/authMiddlewares');

const { authenticationMiddlewareExpress, checkAuthExpiredMiddleware } = require('../utils/authMiddlewares');

module.exports = async (app) => {
// Useful if you're behind a reverse proxy (Heroku, Bluemix, AWS ELB, Nginx, etc)
Expand Down Expand Up @@ -109,6 +108,8 @@ module.exports = async (app) => {

app.use(authMw);

app.use(checkAuthExpiredMiddleware);

app.use(OpenApiValidator.middleware({
apiSpec: path.join(__dirname, '..', 'specs', 'api.yaml'),
validateRequests: true,
Expand Down
5 changes: 5 additions & 0 deletions src/utils/__mocks__/authMiddlewares.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ const expressAuthenticationOnlyMiddleware = async (req, res, next) => {
next();
};

const checkAuthExpiredMiddleware = async (req, res, next) => {
next();
};

const authenticationMiddlewareSocketIO = async () => true;
const authorize = async () => true;

Expand All @@ -17,5 +21,6 @@ module.exports = {
authenticationMiddlewareSocketIO,
expressAuthorizationMiddleware,
expressAuthenticationOnlyMiddleware,
checkAuthExpiredMiddleware,
authorize,
};
71 changes: 71 additions & 0 deletions src/utils/authMiddlewares.js
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
// See details at https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html
// for how JWT verification works with Cognito.
const promiseAny = require('promise.any');

const AWSXRay = require('aws-xray-sdk');
const fetch = require('node-fetch');
const jwt = require('jsonwebtoken');
const jwtExpress = require('express-jwt');
const jwkToPem = require('jwk-to-pem');
const util = require('util');
const dns = require('dns').promises;

const config = require('../config');

Expand Down Expand Up @@ -41,6 +43,7 @@ const authenticationMiddlewareExpress = async (app) => {
// so during authorization we can check if the user parameter is actually present.
// If not, the user was not authenticated.
credentialsRequired: false,
ignoreExpiration: true,
// JWT tokens are JSON files that are signed using a key.
// We need to make sure that the issuer in the token is correct:
// we verify that the signed JWT includes our own user pool.
Expand Down Expand Up @@ -79,6 +82,73 @@ const authenticationMiddlewareExpress = async (app) => {
});
};


// eslint-disable-next-line no-useless-escape
const INTERNAL_DOMAINS_REGEX = new RegExp('((\.compute\.internal)|(\.svc\.local))$');

const checkAuthExpiredMiddleware = (req, res, next) => {
const isReqFromLocalhost = async () => {
const ip = req.connection.remoteAddress;
const host = req.get('host');

if (ip === '127.0.0.1' || ip === '::ffff:127.0.0.1' || ip === '::1' || host.indexOf('localhost') !== -1) {
return true;
}

throw new Error('ip address is not localhost');
};

const isReqFromCluster = async () => {
const domains = await dns.reverse(req.ip);

if (!domains.some((domain) => INTERNAL_DOMAINS_REGEX.test(domain))) {
throw new Error('ip address does not come from internal sources');
}

return true;
};

if (!req.user) {
return next();
}

// JWT `exp` returns seconds since UNIX epoch, conver to milliseconds for this
const timeLeft = (req.user.exp * 1000) - Date.now();

// ignore if JWT is still valid
if (timeLeft > 0) {
return next();
}

// send error if JWT is older than the limit
if (timeLeft < -(7 * 1000 * 60 * 60)) {
return next(new UnauthenticatedError('token has expired'));
}

// check if we should ignore expired jwt token for this path and request type
const longTimeoutEndpoints = [{ urlMatcher: /^\/v1\/experiments\/.{32}\/cellSets$/, method: 'PATCH' }];
const isEndpointIgnored = longTimeoutEndpoints.some(
({ urlMatcher, method }) => (
req.method.toLowerCase() === method.toLowerCase() && urlMatcher.test(req.url)
),
);

// if endpoint is not in ignore list, the JWT is too old, send an error accordingly
if (!isEndpointIgnored) {
return next(new UnauthenticatedError('token has expired'));
}

promiseAny([isReqFromCluster(), isReqFromLocalhost()])
.then(() => {
next();
})
.catch(() => {
next(new UnauthenticatedError('token has expired'));
});

return null;
};

/**
* Authentication middleware for Socket.IO requests. Resolves with
* the JWT claim if the authentication was successful, or rejects with
Expand Down Expand Up @@ -195,5 +265,6 @@ module.exports = {
authenticationMiddlewareSocketIO,
expressAuthorizationMiddleware,
expressAuthenticationOnlyMiddleware,
checkAuthExpiredMiddleware,
authorize,
};
4 changes: 2 additions & 2 deletions src/utils/hooks/runQCPipeline.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ const { createQCPipeline } = require('../../api/general-services/pipeline-manage
const experimentService = new ExperimentService();

const runQCPipeline = async (payload) => {
const { experimentId } = payload;
const { experimentId, authJWT } = payload;

const qcHandle = await createQCPipeline(experimentId);
const qcHandle = await createQCPipeline(experimentId, [], authJWT);
await experimentService.saveQCHandle(experimentId, qcHandle);
};

Expand Down

0 comments on commit dfda76f

Please sign in to comment.