Skip to content

Commit

Permalink
Fixing bug setting index when parsing Google Vertex AI results (#117287
Browse files Browse the repository at this point in the history
…) (#117358)

* Using record ID as index value when parsing Google Vertex AI rerank results

* Update docs/changelog/117287.yaml

* PR feedback
  • Loading branch information
ymao1 authored Nov 22, 2024
1 parent d95c003 commit 9552422
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 5 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/117287.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 117287
summary: Fixing bug setting index when parsing Google Vertex AI results
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
public class GoogleVertexAiRerankResponseEntity {

private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Google Vertex AI rerank response";
private static final String INVALID_ID_FIELD_FORMAT_TEMPLATE = "Expected numeric value for record ID field in Google Vertex AI rerank "
+ "response but received [%s]";

/**
* Parses the Google Vertex AI rerank response.
Expand Down Expand Up @@ -109,14 +111,27 @@ private static List<RankedDocsResults.RankedDoc> doParse(XContentParser parser)
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.SCORE.getPreferredName()));
}

return new RankedDocsResults.RankedDoc(index, parsedRankedDoc.score, parsedRankedDoc.content);
if (parsedRankedDoc.id == null) {
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.ID.getPreferredName()));
}

try {
return new RankedDocsResults.RankedDoc(
Integer.parseInt(parsedRankedDoc.id),
parsedRankedDoc.score,
parsedRankedDoc.content
);
} catch (NumberFormatException e) {
throw new IllegalStateException(format(INVALID_ID_FIELD_FORMAT_TEMPLATE, parsedRankedDoc.id));
}
});
}

private record RankedDoc(@Nullable Float score, @Nullable String content) {
private record RankedDoc(@Nullable Float score, @Nullable String content, @Nullable String id) {

private static final ParseField CONTENT = new ParseField("content");
private static final ParseField SCORE = new ParseField("score");
private static final ParseField ID = new ParseField("id");
private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
"google_vertex_ai_rerank_response",
true,
Expand All @@ -126,6 +141,7 @@ private record RankedDoc(@Nullable Float score, @Nullable String content) {
static {
PARSER.declareString(Builder::setContent, CONTENT);
PARSER.declareFloat(Builder::setScore, SCORE);
PARSER.declareString(Builder::setId, ID);
}

public static RankedDoc parse(XContentParser parser) {
Expand All @@ -137,6 +153,7 @@ private static final class Builder {

private String content;
private Float score;
private String id;

private Builder() {}

Expand All @@ -150,8 +167,13 @@ public Builder setContent(String content) {
return this;
}

public Builder setId(String id) {
this.id = id;
return this;
}

public RankedDoc build() {
return new RankedDoc(score, content);
return new RankedDoc(score, content, id);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);

assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(0, 0.97F, "content 2"))));
assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2"))));
}

public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
Expand Down Expand Up @@ -68,7 +68,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException

assertThat(
parsedResults.getRankedDocs(),
is(List.of(new RankedDocsResults.RankedDoc(0, 0.97F, "content 2"), new RankedDocsResults.RankedDoc(1, 0.90F, "content 1")))
is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2"), new RankedDocsResults.RankedDoc(1, 0.90F, "content 1")))
);
}

Expand Down Expand Up @@ -161,4 +161,37 @@ public void testFromResponse_FailsWhenScoreFieldIsNotPresent() {

assertThat(thrownException.getMessage(), is("Failed to find required field [score] in Google Vertex AI rerank response"));
}

public void testFromResponse_FailsWhenIDFieldIsNotInteger() {
String responseJson = """
{
"records": [
{
"id": "abcd",
"title": "title 2",
"content": "content 2",
"score": 0.97
},
{
"id": "1",
"title": "title 1",
"content": "content 1",
"score": 0.96
}
]
}
""";

var thrownException = expectThrows(
IllegalStateException.class,
() -> GoogleVertexAiRerankResponseEntity.fromResponse(
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
)
);

assertThat(
thrownException.getMessage(),
is("Expected numeric value for record ID field in Google Vertex AI rerank response but received [abcd]")
);
}
}

0 comments on commit 9552422

Please sign in to comment.