Skip to content

Commit

Permalink
fix(stepfunctions-tasks): add back BedrockInvokeModel to use JsonPath (
Browse files Browse the repository at this point in the history
…aws#31325)

Reverts aws#31308

After discussion with the team, we are working on the "roll-forward fix" PR for this issue and not going ahead with this revert
aws#31305

Reverting to resolve the existing conflicts.
  • Loading branch information
shikha372 authored Sep 5, 2024
1 parent 03ebca8 commit 5b059b9
Show file tree
Hide file tree
Showing 11 changed files with 255 additions and 19 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,25 @@
]
]
}
},
{
"Action": [
"s3:GetObject",
"s3:PutObject"
],
"Effect": "Allow",
"Resource": {
"Fn::Join": [
"",
[
"arn:",
{
"Ref": "AWS::Partition"
},
":s3:::*"
]
]
}
}
],
"Version": "2012-10-17"
Expand Down Expand Up @@ -72,7 +91,19 @@
{
"Ref": "AWS::Region"
},
"::foundation-model/amazon.titan-text-express-v1\",\"Body\":{\"inputText\":\"Generate a list of five first names.\",\"textGenerationConfig\":{\"maxTokenCount\":100,\"temperature\":1}}}},\"Prompt2\":{\"End\":true,\"Type\":\"Task\",\"ResultPath\":\"$\",\"ResultSelector\":{\"names.$\":\"$.Body.results[0].outputText\"},\"Resource\":\"arn:",
"::foundation-model/amazon.titan-text-express-v1\",\"Body\":{\"inputText\":\"Generate a list of five first names.\",\"textGenerationConfig\":{\"maxTokenCount\":100,\"temperature\":1}}}},\"Prompt2\":{\"Next\":\"Prompt3\",\"Type\":\"Task\",\"ResultPath\":\"$\",\"ResultSelector\":{\"names.$\":\"$.Body.results[0].outputText\"},\"Resource\":\"arn:",
{
"Ref": "AWS::Partition"
},
":states:::bedrock:invokeModel\",\"Parameters\":{\"ModelId\":\"arn:",
{
"Ref": "AWS::Partition"
},
":bedrock:",
{
"Ref": "AWS::Region"
},
"::foundation-model/amazon.titan-text-express-v1\",\"Body\":{\"inputText.$\":\"States.Format('Alphabetize this list of first names:\\n{}', $.names)\",\"textGenerationConfig\":{\"maxTokenCount\":100,\"temperature\":1}}}},\"Prompt3\":{\"End\":true,\"Type\":\"Task\",\"InputPath\":\"$.names\",\"OutputPath\":\"$.names\",\"Resource\":\"arn:",
{
"Ref": "AWS::Partition"
},
Expand All @@ -84,7 +115,7 @@
{
"Ref": "AWS::Region"
},
"::foundation-model/amazon.titan-text-express-v1\",\"Body\":{\"inputText.$\":\"States.Format('Alphabetize this list of first names:\\n{}', $.names)\",\"textGenerationConfig\":{\"maxTokenCount\":100,\"temperature\":1}}}}},\"TimeoutSeconds\":30}"
"::foundation-model/amazon.titan-text-express-v1\",\"Input\":{\"S3Uri.$\":\"$.names\"},\"Output\":{\"S3Uri.$\":\"$.names\"}}}},\"TimeoutSeconds\":30}"
]
]
},
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ const prompt2 = new BedrockInvokeModel(stack, 'Prompt2', {
resultPath: '$',
});

const chain = sfn.Chain.start(prompt1).next(prompt2);
const prompt3 = new BedrockInvokeModel(stack, 'Prompt3', {
model,
inputPath: sfn.JsonPath.stringAt('$.names'),
outputPath: sfn.JsonPath.stringAt('$.names'),
});

const chain = sfn.Chain.start(prompt1).next(prompt2).next(prompt3);

new sfn.StateMachine(stack, 'StateMachine', {
definitionBody: sfn.DefinitionBody.fromChainable(chain),
Expand Down
21 changes: 21 additions & 0 deletions packages/aws-cdk-lib/aws-stepfunctions-tasks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,27 @@ const task = new tasks.BedrockInvokeModel(this, 'Prompt Model', {
names: sfn.JsonPath.stringAt('$.Body.results[0].outputText'),
},
});
```
### Using Input Path

Provide S3 URI as an input or output path to invoke a model

```ts

import * as bedrock from 'aws-cdk-lib/aws-bedrock';

const model = bedrock.FoundationModel.fromFoundationModelId(
this,
'Model',
bedrock.FoundationModelIdentifier.AMAZON_TITAN_TEXT_G1_EXPRESS_V1,
);

const task = new tasks.BedrockInvokeModel(this, 'Prompt Model', {
model,
inputPath: sfn.JsonPath.stringAt('$.prompt'),
outputPath: sfn.JsonPath.stringAt('$.prompt'),
});

```

You can apply a guardrail to the invocation by setting `guardrail`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,14 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {

constructor(scope: Construct, id: string, private readonly props: BedrockInvokeModelProps) {
super(scope, id, props);

this.integrationPattern = props.integrationPattern ?? sfn.IntegrationPattern.REQUEST_RESPONSE;

validatePatternSupported(this.integrationPattern, BedrockInvokeModel.SUPPORTED_INTEGRATION_PATTERNS);

const isBodySpecified = props.body !== undefined;
const isInputSpecified = props.input !== undefined && props.input.s3Location !== undefined;
//Either specific props.input with bucket name and object key or input s3 path
const isInputSpecified = (props.input !== undefined && props.input.s3Location !== undefined) || (props.inputPath !== undefined);

if (isBodySpecified && isInputSpecified) {
throw new Error('Either `body` or `input` must be specified, but not both.');
Expand All @@ -171,7 +173,21 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
}),
];

if (this.props.input !== undefined && this.props.input.s3Location !== undefined) {
if (this.props.inputPath !== undefined) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['s3:GetObject'],
resources: [
Stack.of(this).formatArn({
region: '',
account: '',
service: 's3',
resource: '*',
}),
],
}),
);
} else if (this.props.input !== undefined && this.props.input.s3Location !== undefined) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['s3:GetObject'],
Expand All @@ -188,7 +204,21 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
);
}

if (this.props.output !== undefined && this.props.output.s3Location !== undefined) {
if (this.props.outputPath !== undefined) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['s3:PutObject'],
resources: [
Stack.of(this).formatArn({
region: '',
account: '',
service: 's3',
resource: '*',
}),
],
}),
);
} else if (this.props.output !== undefined && this.props.output.s3Location !== undefined) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['s3:PutObject'],
Expand Down Expand Up @@ -241,10 +271,10 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
Body: this.props.body?.value,
Input: this.props.input?.s3Location ? {
S3Uri: `s3://${this.props.input.s3Location.bucketName}/${this.props.input.s3Location.objectKey}`,
} : undefined,
} : this.props.inputPath ? { S3Uri: this.props.inputPath } : undefined,
Output: this.props.output?.s3Location ? {
S3Uri: `s3://${this.props.output.s3Location.bucketName}/${this.props.output.s3Location.objectKey}`,
} : undefined,
} : this.props.outputPath ? { S3Uri: this.props.outputPath }: undefined,
GuardrailIdentifier: this.props.guardrail?.guardrailIdentifier,
GuardrailVersion: this.props.guardrail?.guardrailVersion,
Trace: this.props.traceEnabled === undefined
Expand All @@ -254,5 +284,6 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
: 'DISABLED',
}),
};
}
};
}

Loading

0 comments on commit 5b059b9

Please sign in to comment.