Skip to content

Commit

Permalink
feat(client-sagemaker): This change allows customers to provide a cus…
Browse files Browse the repository at this point in the history
…tom entrypoint script for the docker container to be run while executing training jobs, and provide custom arguments to the entrypoint script.
  • Loading branch information
awstools committed Oct 27, 2022
1 parent 2d06b27 commit 4978352
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 134 deletions.
10 changes: 5 additions & 5 deletions clients/client-sagemaker/src/SageMaker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2313,10 +2313,10 @@ export class SageMaker extends SageMakerClient {
* Amazon SageMaker Studio. For more information, see <a href="https://docs.aws.amazon.com/sagemaker/latest/dg/experiments-view-compare.html#experiments-view">View
* Experiments, Trials, and Trial Components</a>.</p>
* <important>
* <p>Do not include any security-sensitive information including account access
* IDs, secrets or tokens in any hyperparameter field. If the use of
* security-sensitive credentials are detected, SageMaker will reject your training
* job request and return an exception error.</p>
* <p>Do not include any security-sensitive information including account access IDs,
* secrets or tokens in any hyperparameter field. If the use of security-sensitive
* credentials are detected, SageMaker will reject your training job request and return an
* exception error.</p>
* </important>
*/
public createHyperParameterTuningJob(
Expand Down Expand Up @@ -3143,7 +3143,7 @@ export class SageMaker extends SageMakerClient {
* </li>
* <li>
* <p>
* <code>InputDataConfig</code> - Describes the training dataset and the Amazon S3,
* <code>InputDataConfig</code> - Describes the input required by the training job and the Amazon S3,
* EFS, or FSx location where it is stored.</p>
* </li>
* <li>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ export interface CreateHyperParameterTuningJobCommandOutput
* Amazon SageMaker Studio. For more information, see <a href="https://docs.aws.amazon.com/sagemaker/latest/dg/experiments-view-compare.html#experiments-view">View
* Experiments, Trials, and Trial Components</a>.</p>
* <important>
* <p>Do not include any security-sensitive information including account access
* IDs, secrets or tokens in any hyperparameter field. If the use of
* security-sensitive credentials are detected, SageMaker will reject your training
* job request and return an exception error.</p>
* <p>Do not include any security-sensitive information including account access IDs,
* secrets or tokens in any hyperparameter field. If the use of security-sensitive
* credentials are detected, SageMaker will reject your training job request and return an
* exception error.</p>
* </important>
* @example
* Use a bare-bones client and the command you need to make an API call.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ export interface CreateTrainingJobCommandOutput extends CreateTrainingJobRespons
* </li>
* <li>
* <p>
* <code>InputDataConfig</code> - Describes the training dataset and the Amazon S3,
* <code>InputDataConfig</code> - Describes the input required by the training job and the Amazon S3,
* EFS, or FSx location where it is stored.</p>
* </li>
* <li>
Expand Down
14 changes: 14 additions & 0 deletions clients/client-sagemaker/src/models/models_0.ts
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,20 @@ export interface AlgorithmSpecification {
* </ul>
*/
EnableSageMakerMetricsTimeSeries?: boolean;

/**
* <p>The <a href="https://docs.docker.com/engine/reference/builder/">entrypoint script
* for a Docker container</a> used to run a training job. This script takes
* precedence over the default train processing instructions. See <a href="https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo-dockerfile.html">How Amazon SageMaker
* Runs Your Training Image</a> for more information.</p>
*/
ContainerEntrypoint?: string[];

/**
* <p>The arguments for a container used to run a training job. See <a href="https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo-dockerfile.html">How Amazon SageMaker
* Runs Your Training Image</a> for additional information.</p>
*/
ContainerArguments?: string[];
}

export enum AlgorithmStatus {
Expand Down
4 changes: 2 additions & 2 deletions clients/client-sagemaker/src/models/models_1.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6855,8 +6855,8 @@ export interface CreateTrainingJobRequest {
* <important>
* <p>Do not include any security-sensitive information including account access IDs,
* secrets or tokens in any hyperparameter field. If the use of security-sensitive
* credentials are detected, SageMaker will reject your training job request and return
* an exception error.</p>
* credentials are detected, SageMaker will reject your training job request and return an
* exception error.</p>
* </important>
*/
HyperParameters?: Record<string, string>;
Expand Down
54 changes: 54 additions & 0 deletions clients/client-sagemaker/src/protocols/Aws_json1_1.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15499,6 +15499,12 @@ const serializeAws_json1_1AlarmList = (input: Alarm[], context: __SerdeContext):
const serializeAws_json1_1AlgorithmSpecification = (input: AlgorithmSpecification, context: __SerdeContext): any => {
return {
...(input.AlgorithmName != null && { AlgorithmName: input.AlgorithmName }),
...(input.ContainerArguments != null && {
ContainerArguments: serializeAws_json1_1TrainingContainerArguments(input.ContainerArguments, context),
}),
...(input.ContainerEntrypoint != null && {
ContainerEntrypoint: serializeAws_json1_1TrainingContainerEntrypoint(input.ContainerEntrypoint, context),
}),
...(input.EnableSageMakerMetricsTimeSeries != null && {
EnableSageMakerMetricsTimeSeries: input.EnableSageMakerMetricsTimeSeries,
}),
Expand Down Expand Up @@ -22432,6 +22438,22 @@ const serializeAws_json1_1TrafficRoutingConfig = (input: TrafficRoutingConfig, c
};
};

const serializeAws_json1_1TrainingContainerArguments = (input: string[], context: __SerdeContext): any => {
return input
.filter((e: any) => e != null)
.map((entry) => {
return entry;
});
};

const serializeAws_json1_1TrainingContainerEntrypoint = (input: string[], context: __SerdeContext): any => {
return input
.filter((e: any) => e != null)
.map((entry) => {
return entry;
});
};

const serializeAws_json1_1TrainingEnvironmentMap = (input: Record<string, string>, context: __SerdeContext): any => {
return Object.entries(input).reduce((acc: Record<string, any>, [key, value]: [string, any]) => {
if (value === null) {
Expand Down Expand Up @@ -23344,6 +23366,14 @@ const deserializeAws_json1_1AlarmList = (output: any, context: __SerdeContext):
const deserializeAws_json1_1AlgorithmSpecification = (output: any, context: __SerdeContext): AlgorithmSpecification => {
return {
AlgorithmName: __expectString(output.AlgorithmName),
ContainerArguments:
output.ContainerArguments != null
? deserializeAws_json1_1TrainingContainerArguments(output.ContainerArguments, context)
: undefined,
ContainerEntrypoint:
output.ContainerEntrypoint != null
? deserializeAws_json1_1TrainingContainerEntrypoint(output.ContainerEntrypoint, context)
: undefined,
EnableSageMakerMetricsTimeSeries: __expectBoolean(output.EnableSageMakerMetricsTimeSeries),
MetricDefinitions:
output.MetricDefinitions != null
Expand Down Expand Up @@ -33427,6 +33457,30 @@ const deserializeAws_json1_1TrafficRoutingConfig = (output: any, context: __Serd
} as any;
};

const deserializeAws_json1_1TrainingContainerArguments = (output: any, context: __SerdeContext): string[] => {
const retVal = (output || [])
.filter((e: any) => e != null)
.map((entry: any) => {
if (entry === null) {
return null as any;
}
return __expectString(entry) as any;
});
return retVal;
};

const deserializeAws_json1_1TrainingContainerEntrypoint = (output: any, context: __SerdeContext): string[] => {
const retVal = (output || [])
.filter((e: any) => e != null)
.map((entry: any) => {
if (entry === null) {
return null as any;
}
return __expectString(entry) as any;
});
return retVal;
};

const deserializeAws_json1_1TrainingEnvironmentMap = (output: any, context: __SerdeContext): Record<string, string> => {
return Object.entries(output).reduce((acc: Record<string, string>, [key, value]: [string, any]) => {
if (value === null) {
Expand Down
Loading

0 comments on commit 4978352

Please sign in to comment.