From 8b1f3edfb6d1bc046012591d87dd358d125098ad Mon Sep 17 00:00:00 2001 From: Matt McClean Date: Wed, 12 Jun 2019 13:32:08 +0100 Subject: [PATCH] feat(stepfunctions): add support for AmazonSageMaker APIs (#2808) Have updated the package `@aws-cdk/aws-stepfunctions-tasks` to include support for SageMaker APIs as per documentation here: https://docs.aws.amazon.com/step-functions/latest/dg/connect-sagemaker.html Includes support for the following Amazon SageMaker API calls: * `CreateTrainingJob` * `CreateTransformJob` Partially remediates #1314 --- .../aws-stepfunctions-tasks/lib/index.ts | 5 +- .../lib/sagemaker-task-base-types.ts | 473 ++++++++++++++++++ .../lib/sagemaker-train-task.ts | 285 +++++++++++ .../lib/sagemaker-transform-task.ts | 226 +++++++++ .../aws-stepfunctions-tasks/package-lock.json | 47 ++ .../aws-stepfunctions-tasks/package.json | 2 + .../test/sagemaker-training-job.test.ts | 303 +++++++++++ .../test/sagemaker-transform-job.test.ts | 190 +++++++ packages/@aws-cdk/aws-stepfunctions/README.md | 30 ++ 9 files changed, 1560 insertions(+), 1 deletion(-) create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts index 0decc8f601c18..41f2533ba0149 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts @@ -5,4 +5,7 @@ export * from './run-ecs-task-base-types'; export * from './publish-to-topic'; export * from './send-to-queue'; export * from './run-ecs-ec2-task'; -export * from './run-ecs-fargate-task'; \ No newline at end of file +export * from './run-ecs-fargate-task'; +export * from './sagemaker-task-base-types'; +export * from './sagemaker-train-task'; +export * from './sagemaker-transform-task'; \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts new file mode 100644 index 0000000000000..0a8e6bf365795 --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts @@ -0,0 +1,473 @@ +import ec2 = require('@aws-cdk/aws-ec2'); +import kms = require('@aws-cdk/aws-kms'); + +// +// Create Training Job types +// + +/** + * @experimental + */ +export interface AlgorithmSpecification { + + /** + * Name of the algorithm resource to use for the training job. + */ + readonly algorithmName?: string; + + /** + * List of metric definition objects. Each object specifies the metric name and regular expressions used to parse algorithm logs. + */ + readonly metricDefinitions?: MetricDefinition[]; + + /** + * Registry path of the Docker image that contains the training algorithm. + */ + readonly trainingImage?: string; + + /** + * Input mode that the algorithm supports. + * + * @default is 'File' mode + */ + readonly trainingInputMode?: InputMode; +} + +/** + * Describes the training, validation or test dataset and the Amazon S3 location where it is stored. + * + * @experimental + */ +export interface Channel { + + /** + * Name of the channel + */ + readonly channelName: string; + + /** + * Compression type if training data is compressed + */ + readonly compressionType?: CompressionType; + + /** + * Content type + */ + readonly contentType?: string; + + /** + * Location of the data channel + */ + readonly dataSource: DataSource; + + /** + * Input mode to use for the data channel in a training job. + */ + readonly inputMode?: InputMode; + + /** + * Record wrapper type + */ + readonly recordWrapperType?: RecordWrapperType; + + /** + * Shuffle config option for input data in a channel. + */ + readonly shuffleConfig?: ShuffleConfig; +} + +/** + * Configuration for a shuffle option for input data in a channel. + * + * @experimental + */ +export interface ShuffleConfig { + /** + * Determines the shuffling order. + */ + readonly seed: number; +} + +/** + * Location of the channel data. + * + * @experimental + */ +export interface DataSource { + /** + * S3 location of the data source that is associated with a channel. + */ + readonly s3DataSource: S3DataSource; +} + +/** + * S3 location of the channel data. + * + * @experimental + */ +export interface S3DataSource { + /** + * List of one or more attribute names to use that are found in a specified augmented manifest file. + */ + readonly attributeNames?: string[]; + + /** + * S3 Data Distribution Type + */ + readonly s3DataDistributionType?: S3DataDistributionType; + + /** + * S3 Data Type + */ + readonly s3DataType?: S3DataType; + + /** + * S3 Uri + */ + readonly s3Uri: string; +} + +/** + * @experimental + */ +export interface OutputDataConfig { + /** + * Optional KMS encryption key that Amazon SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side encryption. + */ + readonly encryptionKey?: kms.IKey; + + /** + * Identifies the S3 path where you want Amazon SageMaker to store the model artifacts. + */ + readonly s3OutputPath: string; +} + +export interface StoppingCondition { + /** + * The maximum length of time, in seconds, that the training or compilation job can run. + */ + readonly maxRuntimeInSeconds?: number; +} + +export interface ResourceConfig { + + /** + * The number of ML compute instances to use. + * + * @default 1 instance. + */ + readonly instanceCount: number; + + /** + * ML compute instance type. + * + * @default is the 'm4.xlarge' instance type. + */ + readonly instanceType: ec2.InstanceType; + + /** + * KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) that run the training job. + */ + readonly volumeKmsKeyId?: kms.IKey; + + /** + * Size of the ML storage volume that you want to provision. + * + * @default 10 GB EBS volume. + */ + readonly volumeSizeInGB: number; +} + +/** + * + * @experimental + */ +export interface VpcConfig { + /** + * VPC security groups. + */ + readonly securityGroups: ec2.ISecurityGroup[]; + + /** + * VPC id + */ + readonly vpc: ec2.Vpc; + + /** + * VPC subnets. + */ + readonly subnets: ec2.ISubnet[]; +} + +/** + * Specifies the metric name and regular expressions used to parse algorithm logs. + * + * @experimental + */ +export interface MetricDefinition { + + /** + * Name of the metric. + */ + readonly name: string; + + /** + * Regular expression that searches the output of a training job and gets the value of the metric. + */ + readonly regex: string; +} + +/** + * S3 Data Type. + */ +export enum S3DataType { + /** + * Manifest File Data Type + */ + ManifestFile = 'ManifestFile', + + /** + * S3 Prefix Data Type + */ + S3Prefix = 'S3Prefix', + + /** + * Augmented Manifest File Data Type + */ + AugmentedManifestFile = 'AugmentedManifestFile' +} + +/** + * S3 Data Distribution Type. + */ +export enum S3DataDistributionType { + /** + * Fully replicated S3 Data Distribution Type + */ + FullyReplicated = 'FullyReplicated', + + /** + * Sharded By S3 Key Data Distribution Type + */ + ShardedByS3Key = 'ShardedByS3Key' +} + +/** + * Define the format of the input data. + */ +export enum RecordWrapperType { + /** + * None record wrapper type + */ + None = 'None', + + /** + * RecordIO record wrapper type + */ + RecordIO = 'RecordIO' +} + +/** + * Input mode that the algorithm supports. + */ +export enum InputMode { + /** + * Pipe mode + */ + Pipe = 'Pipe', + + /** + * File mode. + */ + File = 'File' +} + +/** + * Compression type of the data. + */ +export enum CompressionType { + /** + * None compression type + */ + None = 'None', + + /** + * Gzip compression type + */ + Gzip = 'Gzip' +} + +// +// Create Transform Job types +// + +/** + * Dataset to be transformed and the Amazon S3 location where it is stored. + * + * @experimental + */ +export interface TransformInput { + + /** + * The compression type of the transform data. + */ + readonly compressionType?: CompressionType; + + /** + * Multipurpose internet mail extension (MIME) type of the data. + */ + readonly contentType?: string; + + /** + * S3 location of the channel data + */ + readonly transformDataSource: TransformDataSource; + + /** + * Method to use to split the transform job's data files into smaller batches. + */ + readonly splitType?: SplitType; +} + +/** + * S3 location of the input data that the model can consume. + * + * @experimental + */ +export interface TransformDataSource { + + /** + * S3 location of the input data + */ + readonly s3DataSource: TransformS3DataSource; +} + +/** + * Location of the channel data. + * + * @experimental + */ +export interface TransformS3DataSource { + + /** + * S3 Data Type + * + * @default 'S3Prefix' + */ + readonly s3DataType?: S3DataType; + + /** + * Identifies either a key name prefix or a manifest. + */ + readonly s3Uri: string; +} + +/** + * S3 location where you want Amazon SageMaker to save the results from the transform job. + * + * @experimental + */ +export interface TransformOutput { + + /** + * MIME type used to specify the output data. + */ + readonly accept?: string; + + /** + * Defines how to assemble the results of the transform job as a single S3 object. + */ + readonly assembleWith?: AssembleWith; + + /** + * AWS KMS key that Amazon SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side encryption. + */ + readonly encryptionKey?: kms.Key; + + /** + * S3 path where you want Amazon SageMaker to store the results of the transform job. + */ + readonly s3OutputPath: string; +} + +/** + * ML compute instances for the transform job. + * + * @experimental + */ +export interface TransformResources { + + /** + * Nmber of ML compute instances to use in the transform job + */ + readonly instanceCount: number; + + /** + * ML compute instance type for the transform job. + */ + readonly instanceType: ec2.InstanceType; + + /** + * AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s). + */ + readonly volumeKmsKeyId?: kms.Key; +} + +/** + * Specifies the number of records to include in a mini-batch for an HTTP inference request. + */ +export enum BatchStrategy { + + /** + * Fits multiple records in a mini-batch. + */ + MultiRecord = 'MultiRecord', + + /** + * Use a single record when making an invocation request. + */ + SingleRecord = 'SingleRecord' +} + +/** + * Method to use to split the transform job's data files into smaller batches. + */ +export enum SplitType { + + /** + * Input data files are not split, + */ + None = 'None', + + /** + * Split records on a newline character boundary. + */ + Line = 'Line', + + /** + * Split using MXNet RecordIO format. + */ + RecordIO = 'RecordIO', + + /** + * Split using TensorFlow TFRecord format. + */ + TFRecord = 'TFRecord' +} + +/** + * How to assemble the results of the transform job as a single S3 object. + */ +export enum AssembleWith { + + /** + * Concatenate the results in binary format. + */ + None = 'None', + + /** + * Add a newline character at the end of every transformed record. + */ + Line = 'Line' + +} \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts new file mode 100644 index 0000000000000..9d173384cdf0e --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts @@ -0,0 +1,285 @@ +import ec2 = require('@aws-cdk/aws-ec2'); +import iam = require('@aws-cdk/aws-iam'); +import sfn = require('@aws-cdk/aws-stepfunctions'); +import { Construct, Stack } from '@aws-cdk/cdk'; +import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig, + S3DataType, StoppingCondition, VpcConfig, } from './sagemaker-task-base-types'; + +/** + * @experimental + */ +export interface SagemakerTrainProps { + + /** + * Training Job Name. + */ + readonly trainingJobName: string; + + /** + * Role for thte Training Job. + */ + readonly role?: iam.IRole; + + /** + * Specify if the task is synchronous or asychronous. + * + * @default false + */ + readonly synchronous?: boolean; + + /** + * Identifies the training algorithm to use. + */ + readonly algorithmSpecification: AlgorithmSpecification; + + /** + * Hyperparameters to be used for the train job. + */ + readonly hyperparameters?: {[key: string]: any}; + + /** + * Describes the various datasets (e.g. train, validation, test) and the Amazon S3 location where stored. + */ + readonly inputDataConfig: Channel[]; + + /** + * Tags to be applied to the train job. + */ + readonly tags?: {[key: string]: any}; + + /** + * Identifies the Amazon S3 location where you want Amazon SageMaker to save the results of model training. + */ + readonly outputDataConfig: OutputDataConfig; + + /** + * Identifies the resources, ML compute instances, and ML storage volumes to deploy for model training. + */ + readonly resourceConfig?: ResourceConfig; + + /** + * Sets a time limit for training. + */ + readonly stoppingCondition?: StoppingCondition; + + /** + * Specifies the VPC that you want your training job to connect to. + */ + readonly vpcConfig?: VpcConfig; +} + +/** + * Class representing the SageMaker Create Training Job task. + */ +export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsTask { + + /** + * Allows specify security group connections for instances of this fleet. + */ + public readonly connections: ec2.Connections = new ec2.Connections(); + + /** + * The execution role for the Sagemaker training job. + * + * @default new role for Amazon SageMaker to assume is automatically created. + */ + public readonly role: iam.IRole; + + /** + * The Algorithm Specification + */ + private readonly algorithmSpecification: AlgorithmSpecification; + + /** + * The Input Data Config. + */ + private readonly inputDataConfig: Channel[]; + + /** + * The resource config for the task. + */ + private readonly resourceConfig: ResourceConfig; + + /** + * The resource config for the task. + */ + private readonly stoppingCondition: StoppingCondition; + + constructor(scope: Construct, private readonly props: SagemakerTrainProps) { + + // set the default resource config if not defined. + this.resourceConfig = props.resourceConfig || { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.M4, ec2.InstanceSize.XLarge), + volumeSizeInGB: 10 + }; + + // set the stopping condition if not defined + this.stoppingCondition = props.stoppingCondition || { + maxRuntimeInSeconds: 3600 + }; + + // set the sagemaker role or create new one + this.role = props.role || new iam.Role(scope, 'SagemakerRole', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicyArns: [ + new iam.AwsManagedPolicy('AmazonSageMakerFullAccess', scope).policyArn + ] + }); + + // set the input mode to 'File' if not defined + this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ? + ( props.algorithmSpecification ) : + ( { ...props.algorithmSpecification, trainingInputMode: InputMode.File } ); + + // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined + this.inputDataConfig = props.inputDataConfig.map(config => { + if (!config.dataSource.s3DataSource.s3DataType) { + return Object.assign({}, config, { dataSource: { s3DataSource: + { ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3Prefix } } }); + } else { + return config; + } + }); + + // add the security groups to the connections object + if (this.props.vpcConfig) { + this.props.vpcConfig.securityGroups.forEach(sg => this.connections.addSecurityGroup(sg)); + } + } + + public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig { + return { + resourceArn: 'arn:aws:states:::sagemaker:createTrainingJob' + (this.props.synchronous ? '.sync' : ''), + parameters: sfn.FieldUtils.renderObject(this.renderParameters()), + policyStatements: this.makePolicyStatements(task), + }; + } + + private renderParameters(): {[key: string]: any} { + return { + TrainingJobName: this.props.trainingJobName, + RoleArn: this.role.roleArn, + ...(this.renderAlgorithmSpecification(this.algorithmSpecification)), + ...(this.renderInputDataConfig(this.inputDataConfig)), + ...(this.renderOutputDataConfig(this.props.outputDataConfig)), + ...(this.renderResourceConfig(this.resourceConfig)), + ...(this.renderStoppingCondition(this.stoppingCondition)), + ...(this.renderHyperparameters(this.props.hyperparameters)), + ...(this.renderTags(this.props.tags)), + ...(this.renderVpcConfig(this.props.vpcConfig)), + }; + } + + private renderAlgorithmSpecification(spec: AlgorithmSpecification): {[key: string]: any} { + return { + AlgorithmSpecification: { + TrainingInputMode: spec.trainingInputMode, + ...(spec.trainingImage) ? { TrainingImage: spec.trainingImage } : {}, + ...(spec.algorithmName) ? { AlgorithmName: spec.algorithmName } : {}, + ...(spec.metricDefinitions) ? + { MetricDefinitions: spec.metricDefinitions + .map(metric => ({ Name: metric.name, Regex: metric.regex })) } : {} + } + }; + } + + private renderInputDataConfig(config: Channel[]): {[key: string]: any} { + return { + InputDataConfig: config.map(channel => ({ + ChannelName: channel.channelName, + DataSource: { + S3DataSource: { + S3Uri: channel.dataSource.s3DataSource.s3Uri, + S3DataType: channel.dataSource.s3DataSource.s3DataType, + ...(channel.dataSource.s3DataSource.s3DataDistributionType) ? + { S3DataDistributionType: channel.dataSource.s3DataSource.s3DataDistributionType} : {}, + ...(channel.dataSource.s3DataSource.attributeNames) ? + { AtttributeNames: channel.dataSource.s3DataSource.attributeNames } : {}, + } + }, + ...(channel.compressionType) ? { CompressionType: channel.compressionType } : {}, + ...(channel.contentType) ? { ContentType: channel.contentType } : {}, + ...(channel.inputMode) ? { InputMode: channel.inputMode } : {}, + ...(channel.recordWrapperType) ? { RecordWrapperType: channel.recordWrapperType } : {}, + })) + }; + } + + private renderOutputDataConfig(config: OutputDataConfig): {[key: string]: any} { + return { + OutputDataConfig: { + S3OutputPath: config.s3OutputPath, + ...(config.encryptionKey) ? { KmsKeyId: config.encryptionKey.keyArn } : {}, + } + }; + } + + private renderResourceConfig(config: ResourceConfig): {[key: string]: any} { + return { + ResourceConfig: { + InstanceCount: config.instanceCount, + InstanceType: 'ml.' + config.instanceType, + VolumeSizeInGB: config.volumeSizeInGB, + ...(config.volumeKmsKeyId) ? { VolumeKmsKeyId: config.volumeKmsKeyId.keyArn } : {}, + } + }; + } + + private renderStoppingCondition(config: StoppingCondition): {[key: string]: any} { + return { + StoppingCondition: { + MaxRuntimeInSeconds: config.maxRuntimeInSeconds + } + }; + } + + private renderHyperparameters(params: {[key: string]: any} | undefined): {[key: string]: any} { + return (params) ? { HyperParameters: params } : {}; + } + + private renderTags(tags: {[key: string]: any} | undefined): {[key: string]: any} { + return (tags) ? { Tags: Object.keys(tags).map(key => ({ Key: key, Value: tags[key] })) } : {}; + } + + private renderVpcConfig(config: VpcConfig | undefined): {[key: string]: any} { + return (config) ? { VpcConfig: { + SecurityGroupIds: config.securityGroups.map(sg => ( sg.securityGroupId )), + Subnets: config.subnets.map(subnet => ( subnet.subnetId )), + }} : {}; + } + + private makePolicyStatements(task: sfn.Task): iam.PolicyStatement[] { + const stack = Stack.of(task); + + // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html + const policyStatements = [ + new iam.PolicyStatement() + .addActions('sagemaker:CreateTrainingJob', 'sagemaker:DescribeTrainingJob', 'sagemaker:StopTrainingJob') + .addResource(stack.formatArn({ + service: 'sagemaker', + resource: 'training-job', + resourceName: '*' + })), + new iam.PolicyStatement() + .addAction('sagemaker:ListTags') + .addAllResources(), + new iam.PolicyStatement() + .addAction('iam:PassRole') + .addResources(this.role.roleArn) + .addCondition('StringEquals', { "iam:PassedToService": "sagemaker.amazonaws.com" }) + ]; + + if (this.props.synchronous) { + policyStatements.push(new iam.PolicyStatement() + .addActions("events:PutTargets", "events:PutRule", "events:DescribeRule") + .addResource(stack.formatArn({ + service: 'events', + resource: 'rule', + resourceName: 'StepFunctionsGetEventsForSageMakerTrainingJobsRule' + }))); + } + + return policyStatements; + } +} \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts new file mode 100644 index 0000000000000..bfd365f88a293 --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts @@ -0,0 +1,226 @@ +import ec2 = require('@aws-cdk/aws-ec2'); +import iam = require('@aws-cdk/aws-iam'); +import sfn = require('@aws-cdk/aws-stepfunctions'); +import { Construct, Stack } from '@aws-cdk/cdk'; +import { BatchStrategy, S3DataType, TransformInput, TransformOutput, TransformResources } from './sagemaker-task-base-types'; + +/** + * @experimental + */ +export interface SagemakerTransformProps { + + /** + * Training Job Name. + */ + readonly transformJobName: string; + + /** + * Role for thte Training Job. + */ + readonly role?: iam.IRole; + + /** + * Specify if the task is synchronous or asychronous. + */ + readonly synchronous?: boolean; + + /** + * Number of records to include in a mini-batch for an HTTP inference request. + */ + readonly batchStrategy?: BatchStrategy; + + /** + * Environment variables to set in the Docker container. + */ + readonly environment?: {[key: string]: any}; + + /** + * Maximum number of parallel requests that can be sent to each instance in a transform job. + */ + readonly maxConcurrentTransforms?: number; + + /** + * Maximum allowed size of the payload, in MB. + */ + readonly maxPayloadInMB?: number; + + /** + * Name of the model that you want to use for the transform job. + */ + readonly modelName: string; + + /** + * Tags to be applied to the train job. + */ + readonly tags?: {[key: string]: any}; + + /** + * Dataset to be transformed and the Amazon S3 location where it is stored. + */ + readonly transformInput: TransformInput; + + /** + * S3 location where you want Amazon SageMaker to save the results from the transform job. + */ + readonly transformOutput: TransformOutput; + + /** + * ML compute instances for the transform job. + */ + readonly transformResources?: TransformResources; +} + +/** + * Class representing the SageMaker Create Training Job task. + * + * @experimental + */ +export class SagemakerTransformTask implements sfn.IStepFunctionsTask { + + /** + * The execution role for the Sagemaker training job. + * + * @default new role for Amazon SageMaker to assume is automatically created. + */ + public readonly role: iam.IRole; + + /** + * Dataset to be transformed and the Amazon S3 location where it is stored. + */ + private readonly transformInput: TransformInput; + + /** + * ML compute instances for the transform job. + */ + private readonly transformResources: TransformResources; + + constructor(scope: Construct, private readonly props: SagemakerTransformProps) { + + // set the sagemaker role or create new one + this.role = props.role || new iam.Role(scope, 'SagemakerRole', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicyArns: [ + new iam.AwsManagedPolicy('AmazonSageMakerFullAccess', scope).policyArn + ] + }); + + // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined + this.transformInput = (props.transformInput.transformDataSource.s3DataSource.s3DataType) ? (props.transformInput) : + Object.assign({}, props.transformInput, + { transformDataSource: + { s3DataSource: + { ...props.transformInput.transformDataSource.s3DataSource, + s3DataType: S3DataType.S3Prefix + } + } + }); + + // set the default value for the transform resources + this.transformResources = props.transformResources || { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.M4, ec2.InstanceSize.XLarge), + }; + } + + public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig { + return { + resourceArn: 'arn:aws:states:::sagemaker:createTransformJob' + (this.props.synchronous ? '.sync' : ''), + parameters: sfn.FieldUtils.renderObject(this.renderParameters()), + policyStatements: this.makePolicyStatements(task), + }; + } + + private renderParameters(): {[key: string]: any} { + return { + ...(this.props.batchStrategy) ? { BatchStrategy: this.props.batchStrategy } : {}, + ...(this.renderEnvironment(this.props.environment)), + ...(this.props.maxConcurrentTransforms) ? { MaxConcurrentTransforms: this.props.maxConcurrentTransforms } : {}, + ...(this.props.maxPayloadInMB) ? { MaxPayloadInMB: this.props.maxPayloadInMB } : {}, + ModelName: this.props.modelName, + ...(this.renderTags(this.props.tags)), + ...(this.renderTransformInput(this.transformInput)), + TransformJobName: this.props.transformJobName, + ...(this.renderTransformOutput(this.props.transformOutput)), + ...(this.renderTransformResources(this.transformResources)), + }; + } + + private renderTransformInput(input: TransformInput): {[key: string]: any} { + return { + TransformInput: { + ...(input.compressionType) ? { CompressionType: input.compressionType } : {}, + ...(input.contentType) ? { ContentType: input.contentType } : {}, + DataSource: { + S3DataSource: { + S3Uri: input.transformDataSource.s3DataSource.s3Uri, + S3DataType: input.transformDataSource.s3DataSource.s3DataType, + } + }, + ...(input.splitType) ? { SplitType: input.splitType } : {}, + } + }; + } + + private renderTransformOutput(output: TransformOutput): {[key: string]: any} { + return { + TransformOutput: { + S3OutputPath: output.s3OutputPath, + ...(output.encryptionKey) ? { KmsKeyId: output.encryptionKey.keyArn } : {}, + ...(output.accept) ? { Accept: output.accept } : {}, + ...(output.assembleWith) ? { AssembleWith: output.assembleWith } : {}, + } + }; + } + + private renderTransformResources(resources: TransformResources): {[key: string]: any} { + return { + TransformResources: { + InstanceCount: resources.instanceCount, + InstanceType: 'ml.' + resources.instanceType, + ...(resources.volumeKmsKeyId) ? { VolumeKmsKeyId: resources.volumeKmsKeyId.keyArn } : {}, + } + }; + } + + private renderEnvironment(environment: {[key: string]: any} | undefined): {[key: string]: any} { + return (environment) ? { Environment: environment } : {}; + } + + private renderTags(tags: {[key: string]: any} | undefined): {[key: string]: any} { + return (tags) ? { Tags: Object.keys(tags).map(key => ({ Key: key, Value: tags[key] })) } : {}; + } + + private makePolicyStatements(task: sfn.Task): iam.PolicyStatement[] { + const stack = Stack.of(task); + + // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html + const policyStatements = [ + new iam.PolicyStatement() + .addActions('sagemaker:CreateTransformJob', 'sagemaker:DescribeTransformJob', 'sagemaker:StopTransformJob') + .addResource(stack.formatArn({ + service: 'sagemaker', + resource: 'transform-job', + resourceName: '*' + })), + new iam.PolicyStatement() + .addAction('sagemaker:ListTags') + .addAllResources(), + new iam.PolicyStatement() + .addAction('iam:PassRole') + .addResources(this.role.roleArn) + .addCondition('StringEquals', { "iam:PassedToService": "sagemaker.amazonaws.com" }) + ]; + + if (this.props.synchronous) { + policyStatements.push(new iam.PolicyStatement() + .addActions("events:PutTargets", "events:PutRule", "events:DescribeRule") + .addResource(stack.formatArn({ + service: 'events', + resource: 'rule', + resourceName: 'StepFunctionsGetEventsForSageMakerTransformJobsRule' + }))); + } + + return policyStatements; + } +} \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/package-lock.json b/packages/@aws-cdk/aws-stepfunctions-tasks/package-lock.json index 2aad8cca5e64d..9fe62e9ec1ff3 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/package-lock.json +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/package-lock.json @@ -4,6 +4,53 @@ "lockfileVersion": 1, "requires": true, "dependencies": { + "@aws-cdk/aws-kms": { + "version": "0.33.0", + "resolved": "https://registry.npmjs.org/@aws-cdk/aws-kms/-/aws-kms-0.33.0.tgz", + "integrity": "sha512-Yj1i/kqcpLu4LMIfIrk8F1Znereh7kL05+j7Ho0gy+HjwMbODIxB0BcyTiJ/5CjHkJPsR8Tl2GPyerZ6OGk/Dw==", + "requires": { + "@aws-cdk/aws-iam": "^0.33.0", + "@aws-cdk/cdk": "^0.33.0" + }, + "dependencies": { + "@aws-cdk/aws-iam": { + "version": "0.33.0", + "resolved": "https://registry.npmjs.org/@aws-cdk/aws-iam/-/aws-iam-0.33.0.tgz", + "integrity": "sha512-d6HVkScJlG3a0rwWO0LgmZCTndze1c2cpoIezJINZ+sXPyMQiWWyFQDVTDC3LxPPUalG9t42gr2139d2zbfX6w==", + "requires": { + "@aws-cdk/cdk": "^0.33.0", + "@aws-cdk/region-info": "^0.33.0" + } + }, + "@aws-cdk/cdk": { + "version": "0.33.0", + "resolved": "https://registry.npmjs.org/@aws-cdk/cdk/-/cdk-0.33.0.tgz", + "integrity": "sha512-ARfTC6ZTg1r2FOWntYo4kZ3S/Fju2vAagQavll56BJ3EPCxfYbPnIAWu3oFiSzg/4XQ345tbAZP1GSVZsF4RJw==", + "requires": { + "@aws-cdk/cx-api": "^0.33.0" + } + } + } + }, + "@aws-cdk/cx-api": { + "version": "0.33.0", + "resolved": "https://registry.npmjs.org/@aws-cdk/cx-api/-/cx-api-0.33.0.tgz", + "integrity": "sha512-PvPO1quhrezUyYtyi3kEq4CHjmg5TccWQrU4khmTrP9bmb7sNKCmR7ish1VHcA2FBaNjtAj0PgdA+2/+Q+Pzrw==", + "requires": { + "semver": "^6.0.0" + }, + "dependencies": { + "semver": { + "version": "6.1.0", + "bundled": true + } + } + }, + "@aws-cdk/region-info": { + "version": "0.33.0", + "resolved": "https://registry.npmjs.org/@aws-cdk/region-info/-/region-info-0.33.0.tgz", + "integrity": "sha512-Sy0gXDqzGNuOYAF7edd5rlY3iChVSfjaaZ+bONyClF7gulkYv4jehYkQ1ShATl8XsVRedtCOwSU+mDo/tu8npA==" + }, "@babel/code-frame": { "version": "7.0.0", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.0.0.tgz", diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/package.json b/packages/@aws-cdk/aws-stepfunctions-tasks/package.json index 301a38e1808e6..e300ba03f1ef3 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/package.json +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/package.json @@ -82,6 +82,7 @@ "@aws-cdk/aws-ec2": "^0.34.0", "@aws-cdk/aws-ecs": "^0.34.0", "@aws-cdk/aws-iam": "^0.34.0", + "@aws-cdk/aws-kms": "^0.34.0", "@aws-cdk/aws-lambda": "^0.34.0", "@aws-cdk/aws-sns": "^0.34.0", "@aws-cdk/aws-sqs": "^0.34.0", @@ -94,6 +95,7 @@ "@aws-cdk/aws-ec2": "^0.34.0", "@aws-cdk/aws-ecs": "^0.34.0", "@aws-cdk/aws-iam": "^0.34.0", + "@aws-cdk/aws-kms": "^0.34.0", "@aws-cdk/aws-lambda": "^0.34.0", "@aws-cdk/aws-sns": "^0.34.0", "@aws-cdk/aws-sqs": "^0.34.0", diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts new file mode 100644 index 0000000000000..dd8de65a04552 --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts @@ -0,0 +1,303 @@ +import '@aws-cdk/assert/jest'; +import ec2 = require('@aws-cdk/aws-ec2'); +import iam = require('@aws-cdk/aws-iam'); +import kms = require('@aws-cdk/aws-kms'); +import sfn = require('@aws-cdk/aws-stepfunctions'); +import cdk = require('@aws-cdk/cdk'); +import tasks = require('../lib'); + +let stack: cdk.Stack; + +beforeEach(() => { + // GIVEN + stack = new cdk.Stack(); + }); + +test('create basic training job', () => { + // WHEN + const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask(stack, { + trainingJobName: "MyTrainJob", + algorithmSpecification: { + algorithmName: "BlazingText", + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3Uri: "s3://mybucket/mytrainpath" + } + } + } + ], + outputDataConfig: { + s3OutputPath: 's3://mybucket/myoutputpath' + }, + })}); + + // THEN + expect(stack.resolve(task.toStateJson())).toEqual({ + Type: 'Task', + Resource: 'arn:aws:states:::sagemaker:createTrainingJob', + End: true, + Parameters: { + AlgorithmSpecification: { + AlgorithmName: 'BlazingText', + TrainingInputMode: 'File', + }, + InputDataConfig: [ + { + ChannelName: 'train', + DataSource: { + S3DataSource: { + S3DataType: 'S3Prefix', + S3Uri: 's3://mybucket/mytrainpath' + } + } + } + ], + OutputDataConfig: { + S3OutputPath: 's3://mybucket/myoutputpath' + }, + ResourceConfig: { + InstanceCount: 1, + InstanceType: 'ml.m4.xlarge', + VolumeSizeInGB: 10 + }, + RoleArn: { "Fn::GetAtt": [ "SagemakerRole5FDB64E1", "Arn" ] }, + StoppingCondition: { + MaxRuntimeInSeconds: 3600 + }, + TrainingJobName: 'MyTrainJob', + }, + }); +}); + +test('create complex training job', () => { + // WHEN + const kmsKey = new kms.Key(stack, 'Key'); + const vpc = new ec2.Vpc(stack, "VPC"); + const securityGroup = new ec2.SecurityGroup(stack, 'SecurityGroup', { vpc, description: 'My SG' }); + securityGroup.addIngressRule(new ec2.AnyIPv4(), new ec2.TcpPort(22), 'allow ssh access from the world'); + + const role = new iam.Role(stack, 'Role', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicyArns: [ + new iam.AwsManagedPolicy('AmazonSageMakerFullAccess', stack).policyArn + ], + }); + + const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask(stack, { + trainingJobName: "MyTrainJob", + synchronous: true, + role, + algorithmSpecification: { + algorithmName: "BlazingText", + trainingInputMode: tasks.InputMode.File, + metricDefinitions: [ + { + name: 'mymetric', regex: 'regex_pattern' + } + ] + }, + hyperparameters: { + lr: "0.1" + }, + inputDataConfig: [ + { + channelName: "train", + contentType: "image/jpeg", + compressionType: tasks.CompressionType.None, + recordWrapperType: tasks.RecordWrapperType.RecordIO, + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3Prefix, + s3Uri: "s3://mybucket/mytrainpath", + } + } + }, + { + channelName: "test", + contentType: "image/jpeg", + compressionType: tasks.CompressionType.Gzip, + recordWrapperType: tasks.RecordWrapperType.RecordIO, + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3Prefix, + s3Uri: "s3://mybucket/mytestpath", + } + } + } + ], + outputDataConfig: { + s3OutputPath: 's3://mybucket/myoutputpath', + encryptionKey: kmsKey + }, + resourceConfig: { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + volumeSizeInGB: 50, + volumeKmsKeyId: kmsKey, + }, + stoppingCondition: { + maxRuntimeInSeconds: 3600 + }, + tags: { + Project: "MyProject" + }, + vpcConfig: { + vpc, + subnets: vpc.privateSubnets, + securityGroups: [ securityGroup ] + } + })}); + + // THEN + expect(stack.resolve(task.toStateJson())).toEqual({ + Type: 'Task', + Resource: 'arn:aws:states:::sagemaker:createTrainingJob.sync', + End: true, + Parameters: { + TrainingJobName: 'MyTrainJob', + RoleArn: { "Fn::GetAtt": [ "Role1ABCC5F0", "Arn" ] }, + AlgorithmSpecification: { + TrainingInputMode: 'File', + AlgorithmName: 'BlazingText', + MetricDefinitions: [ + { Name: "mymetric", Regex: "regex_pattern" } + ] + }, + HyperParameters: { + lr: "0.1" + }, + InputDataConfig: [ + { + ChannelName: 'train', + CompressionType: 'None', + RecordWrapperType: 'RecordIO', + ContentType: 'image/jpeg', + DataSource: { + S3DataSource: { + S3DataType: 'S3Prefix', + S3Uri: 's3://mybucket/mytrainpath' + } + } + }, + { + ChannelName: 'test', + CompressionType: 'Gzip', + RecordWrapperType: 'RecordIO', + ContentType: 'image/jpeg', + DataSource: { + S3DataSource: { + S3DataType: 'S3Prefix', + S3Uri: 's3://mybucket/mytestpath' + } + } + } + ], + OutputDataConfig: { + S3OutputPath: 's3://mybucket/myoutputpath', + KmsKeyId: { "Fn::GetAtt": [ "Key961B73FD", "Arn" ] }, + }, + ResourceConfig: { + InstanceCount: 1, + InstanceType: 'ml.p3.2xlarge', + VolumeSizeInGB: 50, + VolumeKmsKeyId: { "Fn::GetAtt": [ "Key961B73FD", "Arn" ] }, + }, + StoppingCondition: { + MaxRuntimeInSeconds: 3600 + }, + Tags: [ + { Key: "Project", Value: "MyProject" } + ], + VpcConfig: { + SecurityGroupIds: [ { "Fn::GetAtt": [ "SecurityGroupDD263621", "GroupId" ] } ], + Subnets: [ + { Ref: "VPCPrivateSubnet1Subnet8BCA10E0" }, + { Ref: "VPCPrivateSubnet2SubnetCFCDAA7A" }, + { Ref: "VPCPrivateSubnet3Subnet3EDCD457" } + ] + } + }, + }); +}); + +test('pass param to training job', () => { + // WHEN + const role = new iam.Role(stack, 'Role', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicyArns: [ + new iam.AwsManagedPolicy('AmazonSageMakerFullAccess', stack).policyArn + ], + }); + + const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask(stack, { + trainingJobName: sfn.Data.stringAt('$.JobName'), + role, + algorithmSpecification: { + algorithmName: "BlazingText", + trainingInputMode: tasks.InputMode.File + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3Prefix, + s3Uri: sfn.Data.stringAt('$.S3Bucket') + } + } + } + ], + outputDataConfig: { + s3OutputPath: 's3://mybucket/myoutputpath' + }, + resourceConfig: { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + volumeSizeInGB: 50 + }, + stoppingCondition: { + maxRuntimeInSeconds: 3600 + } + })}); + + // THEN + expect(stack.resolve(task.toStateJson())).toEqual({ + Type: 'Task', + Resource: 'arn:aws:states:::sagemaker:createTrainingJob', + End: true, + Parameters: { + 'TrainingJobName.$': '$.JobName', + 'RoleArn': { "Fn::GetAtt": [ "Role1ABCC5F0", "Arn" ] }, + 'AlgorithmSpecification': { + TrainingInputMode: 'File', + AlgorithmName: 'BlazingText', + }, + 'InputDataConfig': [ + { + ChannelName: 'train', + DataSource: { + S3DataSource: { + 'S3DataType': 'S3Prefix', + 'S3Uri.$': '$.S3Bucket' + } + } + } + ], + 'OutputDataConfig': { + S3OutputPath: 's3://mybucket/myoutputpath' + }, + 'ResourceConfig': { + InstanceCount: 1, + InstanceType: 'ml.p3.2xlarge', + VolumeSizeInGB: 50 + }, + 'StoppingCondition': { + MaxRuntimeInSeconds: 3600 + } + }, + }); +}); \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts new file mode 100644 index 0000000000000..e6a25b0f490dc --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts @@ -0,0 +1,190 @@ +import '@aws-cdk/assert/jest'; +import ec2 = require('@aws-cdk/aws-ec2'); +import iam = require('@aws-cdk/aws-iam'); +import kms = require('@aws-cdk/aws-kms'); +import sfn = require('@aws-cdk/aws-stepfunctions'); +import cdk = require('@aws-cdk/cdk'); +import tasks = require('../lib'); +import { BatchStrategy, S3DataType } from '../lib'; + +let stack: cdk.Stack; +let role: iam.Role; + +beforeEach(() => { + // GIVEN + stack = new cdk.Stack(); + role = new iam.Role(stack, 'Role', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicyArns: [ + new iam.AwsManagedPolicy('AmazonSageMakerFullAccess', stack).policyArn + ], + }); + }); + +test('create basic transform job', () => { + // WHEN + const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask(stack, { + transformJobName: "MyTransformJob", + modelName: "MyModelName", + transformInput: { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/prefix', + } + } + }, + transformOutput: { + s3OutputPath: 's3://outputbucket/prefix', + }, + }) }); + + // THEN + expect(stack.resolve(task.toStateJson())).toEqual({ + Type: 'Task', + Resource: 'arn:aws:states:::sagemaker:createTransformJob', + End: true, + Parameters: { + TransformJobName: 'MyTransformJob', + ModelName: 'MyModelName', + TransformInput: { + DataSource: { + S3DataSource: { + S3Uri: 's3://inputbucket/prefix', + S3DataType: 'S3Prefix', + } + } + }, + TransformOutput: { + S3OutputPath: 's3://outputbucket/prefix', + }, + TransformResources: { + InstanceCount: 1, + InstanceType: 'ml.m4.xlarge', + } + }, + }); +}); + +test('create complex transform job', () => { + // WHEN + const kmsKey = new kms.Key(stack, 'Key'); + const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask(stack, { + transformJobName: "MyTransformJob", + modelName: "MyModelName", + synchronous: true, + role, + transformInput: { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/prefix', + s3DataType: S3DataType.S3Prefix, + } + } + }, + transformOutput: { + s3OutputPath: 's3://outputbucket/prefix', + encryptionKey: kmsKey, + }, + transformResources: { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + volumeKmsKeyId: kmsKey, + }, + tags: { + Project: 'MyProject', + }, + batchStrategy: BatchStrategy.MultiRecord, + environment: { + SOMEVAR: 'myvalue' + }, + maxConcurrentTransforms: 3, + maxPayloadInMB: 100, + }) }); + + // THEN + expect(stack.resolve(task.toStateJson())).toEqual({ + Type: 'Task', + Resource: 'arn:aws:states:::sagemaker:createTransformJob.sync', + End: true, + Parameters: { + TransformJobName: 'MyTransformJob', + ModelName: 'MyModelName', + TransformInput: { + DataSource: { + S3DataSource: { + S3Uri: 's3://inputbucket/prefix', + S3DataType: 'S3Prefix', + } + } + }, + TransformOutput: { + S3OutputPath: 's3://outputbucket/prefix', + KmsKeyId: { "Fn::GetAtt": [ "Key961B73FD", "Arn" ] }, + }, + TransformResources: { + InstanceCount: 1, + InstanceType: 'ml.p3.2xlarge', + VolumeKmsKeyId: { "Fn::GetAtt": [ "Key961B73FD", "Arn" ] }, + }, + Tags: [ + { Key: 'Project', Value: 'MyProject' } + ], + MaxConcurrentTransforms: 3, + MaxPayloadInMB: 100, + Environment: { + SOMEVAR: 'myvalue' + }, + BatchStrategy: 'MultiRecord' + }, + }); +}); + +test('pass param to transform job', () => { + // WHEN + const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask(stack, { + transformJobName: sfn.Data.stringAt('$.TransformJobName'), + modelName: sfn.Data.stringAt('$.ModelName'), + role, + transformInput: { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/prefix', + s3DataType: S3DataType.S3Prefix, + } + } + }, + transformOutput: { + s3OutputPath: 's3://outputbucket/prefix', + }, + transformResources: { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + } + }) }); + + // THEN + expect(stack.resolve(task.toStateJson())).toEqual({ + Type: 'Task', + Resource: 'arn:aws:states:::sagemaker:createTransformJob', + End: true, + Parameters: { + 'TransformJobName.$': '$.TransformJobName', + 'ModelName.$': '$.ModelName', + 'TransformInput': { + DataSource: { + S3DataSource: { + S3Uri: 's3://inputbucket/prefix', + S3DataType: 'S3Prefix', + } + } + }, + 'TransformOutput': { + S3OutputPath: 's3://outputbucket/prefix', + }, + 'TransformResources': { + InstanceCount: 1, + InstanceType: 'ml.p3.2xlarge', + } + }, + }); +}); \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions/README.md b/packages/@aws-cdk/aws-stepfunctions/README.md index 0bb4278c64659..1aa984588a4e9 100644 --- a/packages/@aws-cdk/aws-stepfunctions/README.md +++ b/packages/@aws-cdk/aws-stepfunctions/README.md @@ -126,6 +126,8 @@ couple of the tasks available are: * `tasks.SendToQueue` -- send a message to an SQS queue * `tasks.RunEcsFargateTask`/`ecs.RunEcsEc2Task` -- run a container task, depending on the type of capacity. +* `tasks.SagemakerTrainTask` -- run a SageMaker training job +* `tasks.SagemakerTransformTask` -- run a SageMaker transform job #### Task parameters from the state json @@ -249,6 +251,34 @@ const task = new sfn.Task(this, 'CallFargate', { }); ``` +#### SageMaker Transform example + +```ts +const transformJob = new tasks.SagemakerTransformTask( + transformJobName: "MyTransformJob", + modelName: "MyModelName", + role, + transformInput: { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/train', + s3DataType: S3DataType.S3Prefix, + } + } + }, + transformOutput: { + s3OutputPath: 's3://outputbucket/TransformJobOutputPath', + }, + transformResources: { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.M4, ec2.InstanceSize.XLarge), +}); + +const task = new sfn.Task(this, 'Batch Inference', { + task: transformJob +}); +``` + ### Pass A `Pass` state does no work, but it can optionally transform the execution's