Commit

fix tutorials for cross encoder models on Amazon Bedrock based on review comments on #3278

Signed-off-by: tkykenmt <[email protected]>
tkykenmt committed Dec 25, 2024
1 parent 7bf5526 commit 9482b47
Showing 2 changed files with 171 additions and 91 deletions.
@@ -1,17 +1,15 @@
# Topic

[Reranking pipeline](https://opensearch.org/docs/latest/search-plugins/search-relevance/reranking-search-results/) is a feature released in OpenSearch 2.12.
It can rerank search results, providing a relevance score for each document in the search results with respect to the search query.
The relevance score is calculated by a cross-encoder model.
[Reranking pipeline](https://opensearch.org/docs/latest/search-plugins/search-relevance/reranking-search-results/) is a feature released in OpenSearch 2.12. It can rerank search results, providing a relevance score with respect to the search query for each matching document. The relevance score is calculated by a cross-encoder model.

This tutorial explains how to use the [Amazon Rerank 1.0 model in Amazon Bedrock](https://docs.aws.amazon.com/bedrock/latest/userguide/rerank-supported.html) in a reranking pipeline.
This tutorial illustrates using the [Amazon Rerank 1.0 model in Amazon Bedrock](https://docs.aws.amazon.com/bedrock/latest/userguide/rerank-supported.html) in a reranking pipeline.

Note: Replace the placeholders that start with `your_` with your own values.

# Steps

## 0. Test the model on Amazon Bedrock
You can perform a reranking test with the following code.
You can perform a reranking test using the following code.

```python
import json
@@ -40,11 +38,39 @@ response = bedrock_runtime_client.invoke_model(
body=body
)
results = json.loads(response.get('body').read())["results"]
print(json.dumps(sorted(results, key=lambda x: x['index']),indent=2))
print(json.dumps(results, indent=2))
```

The reranking results are ordered by score, highest first:
```
[
{
"index": 2,
"relevance_score": 0.7711548724998493
},
{
"index": 0,
"relevance_score": 0.0025114635138098534
},
{
"index": 1,
"relevance_score": 2.4876490010363496e-05
},
{
"index": 3,
"relevance_score": 6.339210403977635e-06
}
]
```
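To read the scores next to the text they belong to, each `index` can be mapped back to the position of the document in the request. The following is a minimal sketch only: the `documents` list is a hypothetical placeholder (the tutorial's real input documents are not shown in this hunk), `attach_documents` is an illustrative helper name, and the scores are copied from the output above.

```python
# Hypothetical placeholder documents; the real inputs are not shown in this hunk.
documents = ["doc 0", "doc 1", "doc 2", "doc 3"]

# Scores copied from the output above (highest score first).
results = [
    {"index": 2, "relevance_score": 0.7711548724998493},
    {"index": 0, "relevance_score": 0.0025114635138098534},
    {"index": 1, "relevance_score": 2.4876490010363496e-05},
    {"index": 3, "relevance_score": 6.339210403977635e-06},
]


def attach_documents(results, documents):
    """Pair each result with the document that its index points to."""
    return [
        {
            "document": documents[item["index"]],
            "relevance_score": item["relevance_score"],
        }
        for item in results
    ]


ranked = attach_documents(results, documents)
for item in ranked:
    print(f"{item['relevance_score']:.6f}  {item['document']}")
```

Because the helper keeps the order of `results`, the output stays sorted by relevance.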

The reranking results are as follows:

You can sort the results by index number:

```python
print(json.dumps(sorted(results, key=lambda x: x['index']),indent=2))
```

The results are as follows:
```
[
{
@@ -82,14 +108,16 @@ POST /_plugins/_ml/connectors/_create
"session_token": "your_session_token"
},
"parameters": {
"region": "your_bedrock_model_region_like_us-west-2",
"service_name": "bedrock"
"service_name": "bedrock",
"service_code": "bedrock-runtime",
"region": "your_bedrock_model_region_like_us",
"model_name": "amazon.rerank-v1:0"
},
"actions": [
{
"action_type": "predict",
"action_type": "PREDICT",
"method": "POST",
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/amazon.rerank-v1:0/invoke",
"url": "https://${parameters.service_code}.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke",
"headers": {
"x-amz-content-sha256": "required",
"content-type": "application/json"
@@ -99,20 +127,25 @@ POST /_plugins/_ml/connectors/_create
def text_docs = params.text_docs;
def textDocsBuilder = new StringBuilder('[');
for (int i=0; i<text_docs.length; i++) {
textDocsBuilder.append('\"');
textDocsBuilder.append('"');
textDocsBuilder.append(text_docs[i]);
textDocsBuilder.append('\"');
textDocsBuilder.append('"');
if (i<text_docs.length - 1) {
textDocsBuilder.append(',');
}
}
textDocsBuilder.append(']');
def parameters = '{ \"query\": \"' + query_text + '\", \"documents\": ' + textDocsBuilder.toString() + ' }';
return '{\"parameters\": ' + parameters + '}';
""",
"request_body": "{ \"query\": \"${parameters.query}\", \"documents\": ${parameters.documents} }",
def parameters = '{ "query": "' + query_text + '", "documents": ' + textDocsBuilder.toString() + ' }';
return '{"parameters": ' + parameters + '}';
""",
"request_body": """
{
"documents": ${parameters.documents},
"query": "${parameters.query}"
}
""",
"post_process_function": """
if (params.result == null || params.result.length > 0) {
if (params.results == null || params.results.length == 0) {
throw new IllegalArgumentException("Post process function input is empty.");
}
def outputs = params.results;
@@ -123,8 +156,8 @@ POST /_plugins/_ml/connectors/_create
}
def resultBuilder = new StringBuilder('[');
for (int i=0; i<relevance_scores.length; i++) {
resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');
resultBuilder.append('\"data\": [');
resultBuilder.append(' {"name": "similarity", "data_type": "FLOAT32", "shape": [1],');
resultBuilder.append('"data": [');
resultBuilder.append(relevance_scores[i]);
resultBuilder.append(']}');
if (i<outputs.length - 1) {
@@ -139,7 +172,7 @@ POST /_plugins/_ml/connectors/_create
}
```
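The `pre_process_function` and `request_body` in the connector together turn a `query_text` and a `text_docs` array into the JSON body sent to the model. The same shape can be sketched in Python as a way to check a payload by hand. This is an illustration under stated assumptions, not the connector's implementation: `build_request_body` is a hypothetical helper name and the sample strings are made up. Note that the Painless function shown above concatenates strings directly, while `json.dumps` escapes embedded quotes; whether the connector framework escapes them elsewhere is not covered here.

```python
import json


def build_request_body(query_text, text_docs):
    """Build a body with the same keys as the connector's request_body template."""
    # json.dumps escapes quotes and backslashes inside the strings.
    return json.dumps({"documents": list(text_docs), "query": query_text})


# Made-up sample input, including a document that contains double quotes.
body = build_request_body(
    "sample query",
    ['first document with "quotes"', "second document"],
)
print(body)
```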

If using the Amazon Opensearch Service, you can provide an IAM role arn that allows access to the bedrock service. Refer to this [AWS doc](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/ml-amazon-connector.html)
If using Amazon OpenSearch Service, you can provide an IAM role ARN that allows access to the Amazon Bedrock service. For more information, see the [AWS documentation](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/ml-amazon-connector.html):

```json
POST /_plugins/_ml/connectors/_create
@@ -152,14 +185,16 @@ POST /_plugins/_ml/connectors/_create
"roleArn": "your_role_arn_which_allows_access_to_bedrock_model"
},
"parameters": {
"region": "your_bedrock_model_region_like_us-west-2",
"service_name": "bedrock"
"service_name": "bedrock",
"service_code": "bedrock-runtime",
"region": "your_bedrock_model_region_like_us",
"model_name": "amazon.rerank-v1:0"
},
"actions": [
{
"action_type": "predict",
"action_type": "PREDICT",
"method": "POST",
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/amazon.rerank-v1:0/invoke",
"url": "https://${parameters.service_code}.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke",
"headers": {
"x-amz-content-sha256": "required",
"content-type": "application/json"
@@ -169,20 +204,25 @@ POST /_plugins/_ml/connectors/_create
def text_docs = params.text_docs;
def textDocsBuilder = new StringBuilder('[');
for (int i=0; i<text_docs.length; i++) {
textDocsBuilder.append('\"');
textDocsBuilder.append('"');
textDocsBuilder.append(text_docs[i]);
textDocsBuilder.append('\"');
textDocsBuilder.append('"');
if (i<text_docs.length - 1) {
textDocsBuilder.append(',');
}
}
textDocsBuilder.append(']');
def parameters = '{ \"query\": \"' + query_text + '\", \"documents\": ' + textDocsBuilder.toString() + ' }';
return '{\"parameters\": ' + parameters + '}';
""",
"request_body": "{ \"query\": \"${parameters.query}\", \"documents\": ${parameters.documents} }",
def parameters = '{ "query": "' + query_text + '", "documents": ' + textDocsBuilder.toString() + ' }';
return '{"parameters": ' + parameters + '}';
""",
"request_body": """
{
"documents": ${parameters.documents},
"query": "${parameters.query}"
}
""",
"post_process_function": """
if (params.result == null || params.result.length > 0) {
if (params.results == null || params.results.length == 0) {
throw new IllegalArgumentException("Post process function input is empty.");
}
def outputs = params.results;
@@ -193,8 +233,8 @@ POST /_plugins/_ml/connectors/_create
}
def resultBuilder = new StringBuilder('[');
for (int i=0; i<relevance_scores.length; i++) {
resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');
resultBuilder.append('\"data\": [');
resultBuilder.append(' {"name": "similarity", "data_type": "FLOAT32", "shape": [1],');
resultBuilder.append('"data": [');
resultBuilder.append(relevance_scores[i]);
resultBuilder.append(']}');
if (i<outputs.length - 1) {
@@ -237,7 +277,7 @@ POST _plugins/_ml/models/your_model_id/_predict
}
```

Each item in the array comprises a query_text and a text_docs string, separated by a .
Each item in the array comprises a `query_text` and a `text_docs` string, separated by a comma (`,`).

Alternatively, you can test the model as follows:
```json
@@ -253,11 +293,15 @@ POST _plugins/_ml/_predict/text_similarity/your_model_id
}
```

The connector `pre_process_function` transforms the input into the format required by parameters shown previously.
The connector `pre_process_function` transforms the input into the format required by the previously shown parameters.

By default, Amazon Bedrock Rerank API output has the following format:
By default, the Amazon Bedrock Rerank API output has the following format:
```json
[
{
"index": 2,
"relevance_score": 0.7711548724998493
},
{
"index": 0,
"relevance_score": 0.0025114635138098534
@@ -266,18 +310,14 @@ By default, Amazon Bedrock Rerank API output has the following format:
"index": 1,
"relevance_score": 2.4876490010363496e-05
},
{
"index": 2,
"relevance_score": 0.7711548724998493
},
{
"index": 3,
"relevance_score": 6.339210403977635e-06
}
]
```
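A score-ordered list like this has to be reshaped into one `similarity` output per document, in input order, before the reranker can use it. The Python sketch below mirrors that reshaping for illustration only. It is not the Painless code from the connector: `to_similarity_outputs` is a hypothetical helper name, and it assumes every index from 0 to n-1 appears exactly once.

```python
def to_similarity_outputs(results):
    """Reorder scores by document index and wrap each one as a similarity output."""
    if not results:
        raise ValueError("Post process function input is empty.")
    # Place each score at the position of its document (assumes indexes 0..n-1).
    scores = [0.0] * len(results)
    for item in results:
        scores[item["index"]] = item["relevance_score"]
    return [
        {"name": "similarity", "data_type": "FLOAT32", "shape": [1], "data": [score]}
        for score in scores
    ]


# Scores copied from the Rerank API output above.
results = [
    {"index": 2, "relevance_score": 0.7711548724998493},
    {"index": 0, "relevance_score": 0.0025114635138098534},
    {"index": 1, "relevance_score": 2.4876490010363496e-05},
    {"index": 3, "relevance_score": 6.339210403977635e-06},
]
outputs = to_similarity_outputs(results)
for position, output in enumerate(outputs):
    print(position, output["data"])
```

Indexing by position rather than sorting makes the input-order guarantee explicit: the first output always belongs to the first document.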

The connector `post_process_function` transforms the model's output into a format that the [Reranker processor](https://opensearch.org/docs/latest/search-plugins/search-pipelines/rerank-processor/) can interpret, and orders result by index. This adapted format is as follows:
The connector `post_process_function` transforms the model's output into a format that the [Reranker processor](https://opensearch.org/docs/latest/search-plugins/search-pipelines/rerank-processor/) can interpret, and orders the results by index. This adapted format is as follows:
```json
{
"inference_results": [
@@ -332,7 +372,7 @@ The connector `post_process_function` transforms the model's output into a forma

Explanation of the response:
1. The response contains two `similarity` outputs. For each `similarity` output, the `data` array contains a relevance score of each document against the query.
2. The `similarity` outputs are provided in the order of the input documents; the first result of similarity pertains to the first document.
2. The `similarity` outputs are provided in the order of the input documents; the first similarity result pertains to the first document.

## 2. Reranking pipeline
### 2.1 Ingest test data
@@ -368,7 +408,7 @@ PUT /_search/pipeline/rerank_pipeline_sagemaker
}
```

Note: if you provide multiple filed names in `document_fields`, the values of all fields are first concatenated and then reranking is performed.
Note: If you provide multiple field names in `document_fields`, the values of all fields are concatenated first, and then reranking is performed.

### 2.2 Test reranking

@@ -595,7 +635,7 @@ The first document in the response is `"Washington, D.C. (also known as simply W
}
```

Note: You can avoid writing the query twice by using query_text_path instead of query_text, as follows:
Note: You can avoid writing the query twice by using `query_text_path` instead of `query_text`, as follows:
```json
POST my-test-data/_search?search_pipeline=rerank_pipeline_sagemaker
{