Fix prompt passing for Bedrock by passing a single string prompt for Bedrock models. (opensearch-project#1476)

Signed-off-by: Austin Lee <[email protected]>
austintlee committed Oct 11, 2023
1 parent cea1cd6 commit 91582da
Showing 7 changed files with 137 additions and 29 deletions.
@@ -18,7 +18,9 @@
package org.opensearch.searchpipelines.questionanswering.generative.llm;

import java.util.List;
import java.util.Map;

import lombok.Builder;
import org.opensearch.ml.common.conversation.Interaction;

import lombok.AllArgsConstructor;
@@ -42,4 +44,5 @@ public class ChatCompletionInput {
private int timeoutInSeconds;
private String systemPrompt;
private String userInstructions;
private Llm.ModelProvider modelProvider;
}
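
Note on the ChatCompletionInput change above: the class gains a modelProvider field, and callers are expected to go through LlmIOUtil.createChatCompletionInput, which picks the provider from the model name (see the LlmIOUtil diff below). A minimal sketch of constructing the input directly, assuming the all-args constructor order used in the tests at the end of this commit; the model id and prompt strings are made-up values:

import java.util.Collections;
import java.util.List;

import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;

// Illustrative sketch only; not part of this commit.
public class ChatCompletionInputSketch {
    public static void main(String[] args) {
        ChatCompletionInput input = new ChatCompletionInput(
            "bedrock/anthropic.claude-v2",                 // hypothetical model id; the "bedrock/" prefix drives routing
            "How do I reset my password?",                 // question
            Collections.emptyList(),                       // chat history (List<Interaction>)
            List.of("Passwords can be reset under Settings > Security."), // contexts (search results)
            30,                                            // timeout in seconds
            "You are a helpful assistant.",                // system prompt
            "Answer using only the search results.",       // user instructions
            Llm.ModelProvider.BEDROCK                      // field added by this change
        );
        System.out.println(input.getModelProvider());
    }
}
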
@@ -63,7 +63,7 @@ public DefaultLlmImpl(String openSearchModelId, Client client) {
}

@VisibleForTesting
void setMlClient(MachineLearningInternalClient mlClient) {
protected void setMlClient(MachineLearningInternalClient mlClient) {
this.mlClient = mlClient;
}

@@ -76,19 +76,7 @@ void setMlClient(MachineLearningInternalClient mlClient) {
@Override
public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) {

Map<String, String> inputParameters = new HashMap<>();
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
String messages = PromptUtil
.getChatCompletionPrompt(
chatCompletionInput.getSystemPrompt(),
chatCompletionInput.getUserInstructions(),
chatCompletionInput.getQuestion(),
chatCompletionInput.getChatHistory(),
chatCompletionInput.getContexts()
);
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
log.info("Messages to LLM: {}", messages);
MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(inputParameters).build();
MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
ActionFuture<MLOutput> future = mlClient.predict(this.openSearchModelId, mlInput);
ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(chatCompletionInput.getTimeoutInSeconds() * 1000);
@@ -99,19 +87,65 @@ public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionI

// TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases.

List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES);
return buildChatCompletionOutput(chatCompletionInput.getModelProvider(), dataAsMap);
}

protected Map<String, String> getInputParameters(ChatCompletionInput chatCompletionInput) {
Map<String, String> inputParameters = new HashMap<>();

if (chatCompletionInput.getModelProvider() == ModelProvider.OPENAI) {
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
String messages = PromptUtil.getChatCompletionPrompt(
chatCompletionInput.getSystemPrompt(),
chatCompletionInput.getUserInstructions(),
chatCompletionInput.getQuestion(),
chatCompletionInput.getChatHistory(),
chatCompletionInput.getContexts()
);
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
log.info("Messages to LLM: {}", messages);
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK) {
inputParameters.put("inputs", PromptUtil.buildSingleStringPrompt(chatCompletionInput.getSystemPrompt(),
chatCompletionInput.getUserInstructions(),
chatCompletionInput.getQuestion(),
chatCompletionInput.getChatHistory(),
chatCompletionInput.getContexts()));
} else {
throw new IllegalArgumentException("Unknown/unsupported model provider: " + chatCompletionInput.getModelProvider());
}

log.info("LLM input parameters: {}", inputParameters.toString());
return inputParameters;
}

protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, Map<String, ?> dataAsMap) {

List<Object> answers = null;
List<String> errors = null;
if (choices == null) {
Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR);
errors = List.of((String) error.get(CONNECTOR_OUTPUT_MESSAGE));

if (provider == ModelProvider.OPENAI) {
List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES);
if (choices == null) {
Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR);
errors = List.of((String) error.get(CONNECTOR_OUTPUT_MESSAGE));
} else {
Map firstChoiceMap = (Map) choices.get(0);
log.info("Choices: {}", firstChoiceMap.toString());
Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE);
log.info("role: {}, content: {}", message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
}
} else if (provider == ModelProvider.BEDROCK) {
String response = (String) dataAsMap.get("completion");
if (response != null) {
answers = List.of(response);
} else {
// Error
}
} else {
Map firstChoiceMap = (Map) choices.get(0);
log.info("Choices: {}", firstChoiceMap.toString());
Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE);
log.info("role: {}, content: {}", message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
throw new IllegalArgumentException("Unknown/unsupported model provider: " + provider);
}

return new ChatCompletionOutput(answers, errors);
}
}
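
Note on the DefaultLlmImpl changes above: request building and response parsing now branch on the provider. OpenAI-style connectors still receive "model" and "messages" and are parsed from choices[0].message.content, while Bedrock connectors receive a single "inputs" string and are parsed from the top-level "completion" field. A minimal sketch of the two output shapes buildChatCompletionOutput expects; the answer text is made up, and the wrapper class exists only to make the snippet compile:

import java.util.List;
import java.util.Map;

// Illustrative sketch only; not part of this commit.
public class ConnectorOutputShapes {
    public static void main(String[] args) {
        // OpenAI chat completion: the answer sits at choices[0].message.content.
        Map<String, Object> openAiOutput = Map.of(
            "choices", List.of(
                Map.of("message", Map.of(
                    "role", "assistant",
                    "content", "Paris is the capital of France."))));

        // Bedrock (e.g. an Anthropic Claude text model): the answer is the top-level "completion" string.
        Map<String, Object> bedrockOutput = Map.of("completion", "Paris is the capital of France.");

        // buildChatCompletionOutput(ModelProvider.OPENAI, openAiOutput)   -> answers = ["Paris is the capital of France."]
        // buildChatCompletionOutput(ModelProvider.BEDROCK, bedrockOutput) -> answers = ["Paris is the capital of France."]
        System.out.println(openAiOutput);
        System.out.println(bedrockOutput);
    }
}
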
@@ -22,5 +22,10 @@
*/
public interface Llm {

enum ModelProvider {
OPENAI,
BEDROCK
}

ChatCompletionOutput doChatCompletion(ChatCompletionInput input);
}
@@ -17,7 +17,9 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative.llm;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil;
@@ -27,6 +29,14 @@
*/
public class LlmIOUtil {

private static final String CONNECTOR_INPUT_PARAMETER_MODEL = "model";
private static final String CONNECTOR_INPUT_PARAMETER_MESSAGES = "messages";
private static final String CONNECTOR_OUTPUT_CHOICES = "choices";
private static final String CONNECTOR_OUTPUT_MESSAGE = "message";
private static final String CONNECTOR_OUTPUT_MESSAGE_ROLE = "role";
private static final String CONNECTOR_OUTPUT_MESSAGE_CONTENT = "content";
private static final String CONNECTOR_OUTPUT_ERROR = "error";

public static ChatCompletionInput createChatCompletionInput(
String llmModel,
String question,
@@ -57,7 +67,10 @@ public static ChatCompletionInput createChatCompletionInput(
List<String> contexts,
int timeoutInSeconds
) {

return new ChatCompletionInput(llmModel, question, chatHistory, contexts, timeoutInSeconds, systemPrompt, userInstructions);
Llm.ModelProvider provider = Llm.ModelProvider.OPENAI;
if (llmModel != null && llmModel.startsWith("bedrock/")) {
provider = Llm.ModelProvider.BEDROCK;
}
return new ChatCompletionInput(llmModel, question, chatHistory, contexts, timeoutInSeconds, systemPrompt, userInstructions, provider);
}
}
@@ -20,6 +20,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;

import org.apache.commons.text.StringEscapeUtils;
import org.opensearch.core.common.Strings;
@@ -62,6 +63,8 @@ public static String getChatCompletionPrompt(String question, List<Interaction>
return getChatCompletionPrompt(DEFAULT_SYSTEM_PROMPT, null, question, chatHistory, contexts);
}

// TODO Currently, this is OpenAI specific. Change this to indicate as such or address it as part of
// future prompt template management work.
public static String getChatCompletionPrompt(
String systemPrompt,
String userInstructions,
@@ -87,6 +90,46 @@ enum ChatRole {
}
}

static final String NEWLINE = "\\n";

public static String buildSingleStringPrompt (
String systemPrompt,
String userInstructions,
String question,
List<Interaction> chatHistory,
List<String> contexts
) {
if (Strings.isNullOrEmpty(systemPrompt) && Strings.isNullOrEmpty(userInstructions)) {
systemPrompt = DEFAULT_SYSTEM_PROMPT;
}

StringBuilder bldr = new StringBuilder();
bldr.append(systemPrompt);
bldr.append(NEWLINE);
bldr.append(userInstructions);
bldr.append(NEWLINE);

for (int i = 0; i < contexts.size(); i++) {
bldr.append("SEARCH RESULT " + (i + 1) + ": " + contexts.get(i));
bldr.append(NEWLINE);
}
if (!chatHistory.isEmpty()) {
// The oldest interaction first
// Collections.reverse(chatHistory);
List<Message> messages = Messages.fromInteractions(chatHistory).getMessages();
Collections.reverse(messages);
messages.forEach(m -> {
bldr.append(m.toString());
bldr.append(NEWLINE);
});

}
bldr.append("QUESTION: " + question);
bldr.append(NEWLINE);

return bldr.toString();
}

@VisibleForTesting
static String buildMessageParameter(
String systemPrompt,
@@ -163,6 +206,8 @@ public static Messages fromInteractions(final List<Interaction> interactions) {
}
}

// TODO This is OpenAI specific. Either change this to OpenAiMessage or have it handle
// vendor specific messages.
static class Message {

private final static String MESSAGE_FIELD_ROLE = "role";
@@ -199,5 +244,10 @@ public void setContent(String content) {
public JsonObject toJson() {
return json;
}

@Override
public String toString() {
return String.format(Locale.ROOT, "%s: %s", chatRole.getName(), content);
}
}
}
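
Note on the PromptUtil changes above: buildSingleStringPrompt flattens the system prompt, user instructions, search results, chat history, and question into one string for Bedrock. A self-contained sketch of calling it, assuming the plugin classes are on the classpath; all prompt values are made up. Since NEWLINE is defined as the literal two-character sequence \n in this version, the separators in the flattened prompt are escaped rather than real line breaks:

import java.util.Collections;
import java.util.List;

import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil;

// Illustrative sketch only; not part of this commit.
public class SingleStringPromptSketch {
    public static void main(String[] args) {
        String prompt = PromptUtil.buildSingleStringPrompt(
            "You are a helpful assistant.",                 // system prompt
            "Answer using only the search results.",        // user instructions
            "How do I reset my password?",                  // question
            Collections.emptyList(),                        // chat history (List<Interaction>)
            List.of("Passwords can be reset under Settings > Security.") // contexts (search results)
        );
        System.out.println(prompt);
        // Roughly:
        //   You are a helpful assistant.\nAnswer using only the search results.\n
        //   SEARCH RESULT 1: Passwords can be reset under Settings > Security.\n
        //   QUESTION: How do I reset my password?\n
    }
}
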
@@ -41,7 +41,8 @@ public void testCtor() {
Collections.emptyList(),
0,
systemPrompt,
userInstructions
userInstructions,
Llm.ModelProvider.OPENAI
);

assertNotNull(input);
@@ -70,7 +71,7 @@ public void testGettersSetters() {
)
);
List<String> contexts = List.of("result1", "result2");
ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts, 0, systemPrompt, userInstructions);
ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts, 0, systemPrompt, userInstructions, Llm.ModelProvider.OPENAI);
assertEquals(model, input.getModel());
assertEquals(question, input.getQuestion());
assertEquals(history.get(0).getConversationId(), input.getChatHistory().get(0).getConversationId());
@@ -111,7 +111,8 @@ public void testChatCompletionApi() throws Exception {
Collections.emptyList(),
0,
"prompt",
"instructions"
"instructions",
Llm.ModelProvider.OPENAI
);
ChatCompletionOutput output = connector.doChatCompletion(input);
verify(mlClient, times(1)).predict(any(), captor.capture());
@@ -141,7 +142,8 @@ public void testChatCompletionThrowingError() throws Exception {
Collections.emptyList(),
0,
"prompt",
"instructions"
"instructions",
Llm.ModelProvider.OPENAI
);
ChatCompletionOutput output = connector.doChatCompletion(input);
verify(mlClient, times(1)).predict(any(), captor.capture());