Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bug-fix] Handle BWC for bedrock converse API #3173

Merged
merged 5 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@
import java.util.List;
import java.util.Objects;

import org.opensearch.Version;
b4sjoo marked this conversation as resolved.
Show resolved Hide resolved
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;

import com.google.common.base.Preconditions;
Expand Down Expand Up @@ -86,6 +88,8 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {

public static final int SIZE_NULL_VALUE = -1;

static final Version MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES = CommonValue.VERSION_2_18_0;

@Setter
@Getter
private String conversationId;
Expand Down Expand Up @@ -185,16 +189,27 @@ public GenerativeQAParameters(
}

public GenerativeQAParameters(StreamInput input) throws IOException {
Version version = input.getVersion();
this.conversationId = input.readOptionalString();
this.llmModel = input.readOptionalString();
this.llmQuestion = input.readOptionalString();

// this string was made optional in 2.18
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)) {
this.llmQuestion = input.readOptionalString();
} else {
this.llmQuestion = input.readString();
}

this.systemPrompt = input.readOptionalString();
this.userInstructions = input.readOptionalString();
this.contextSize = input.readInt();
this.interactionSize = input.readInt();
this.timeout = input.readInt();
this.llmResponseField = input.readOptionalString();
this.llmMessages.addAll(input.readList(MessageBlock::new));

if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)) {
this.llmMessages.addAll(input.readList(MessageBlock::new));
}
}

@Override
Expand Down Expand Up @@ -246,16 +261,27 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params

@Override
public void writeTo(StreamOutput out) throws IOException {
Version version = out.getVersion();
out.writeOptionalString(conversationId);
out.writeOptionalString(llmModel);
out.writeOptionalString(llmQuestion);

// this string was made optional in 2.18
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)) {
out.writeOptionalString(llmQuestion);
} else {
out.writeString(llmQuestion);
}

out.writeOptionalString(systemPrompt);
out.writeOptionalString(userInstructions);
out.writeInt(contextSize);
out.writeInt(interactionSize);
out.writeInt(timeout);
out.writeOptionalString(llmResponseField);
out.writeList(llmMessages);

if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)) {
out.writeList(llmMessages);
}
}

public static GenerativeQAParameters parse(XContentParser parser) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;

import java.io.EOFException;
Expand All @@ -30,6 +31,7 @@
import java.util.Map;

import org.junit.Assert;
import org.opensearch.Version;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
Expand Down Expand Up @@ -119,9 +121,15 @@ public void testMiscMethods() throws IOException {
assertNotEquals(builder1, builder2);
assertNotEquals(builder1.hashCode(), builder2.hashCode());

StreamOutput so = mock(StreamOutput.class);
builder1.writeTo(so);
verify(so, times(6)).writeOptionalString(any());
StreamOutput so1 = mock(StreamOutput.class);
when(so1.getVersion()).thenReturn(GenerativeQAParameters.MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES);
builder1.writeTo(so1);
verify(so1, times(6)).writeOptionalString(any());

StreamOutput so2 = mock(StreamOutput.class);
when(so2.getVersion()).thenReturn(Version.V_2_17_0);
builder1.writeTo(so2);
verify(so2, times(5)).writeOptionalString(any());
}

public void testParse() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
import java.util.List;
import java.util.Map;

import org.opensearch.Version;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContent;
import org.opensearch.core.xcontent.XContentBuilder;
Expand Down Expand Up @@ -179,6 +182,48 @@ public void testWriteTo() throws IOException {
assertTrue(timeout == intValues.get(2));
}

public void testWriteToBwcBedrockConverse() throws IOException {
String conversationId = "a";
String llmModel = "b";
String llmQuestion = "c";
String systemPrompt = "s";
String userInstructions = "u";
int contextSize = 1;
int interactionSize = 2;
int timeout = 10;
String llmResponseField = "text";
GenerativeQAParameters expected = new GenerativeQAParameters(
conversationId,
llmModel,
llmQuestion,
systemPrompt,
userInstructions,
contextSize,
interactionSize,
timeout,
llmResponseField,
messageList
);

// Version.2_18_0 (MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)
BytesStreamOutput output = new BytesStreamOutput();
output.setVersion(GenerativeQAParameters.MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES);
expected.writeTo(output);
StreamInput input = output.bytes().streamInput();
input.setVersion(GenerativeQAParameters.MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES);
GenerativeQAParameters actual = new GenerativeQAParameters(input);
assertEquals(expected, actual);

// Version.2_17_0 (LlmMessages should be empty list)
output = new BytesStreamOutput();
output.setVersion(Version.V_2_17_0);
expected.writeTo(output);
input = output.bytes().streamInput();
input.setVersion(Version.V_2_17_0);
actual = new GenerativeQAParameters(input);
assertTrue(actual.getLlmMessages().isEmpty());
}

public void testMisc() {
String conversationId = "a";
String llmModel = "b";
Expand Down
Loading