From 91582daedc74387047e08ab5eef2e00e80aa0417 Mon Sep 17 00:00:00 2001
From: Austin Lee
Date: Wed, 11 Oct 2023 13:39:01 -0700
Subject: [PATCH] Fix prompt passing for Bedrock by passing a single string
 prompt for Bedrock models.
 (https://github.com/opensearch-project/ml-commons/issues/1476)

Signed-off-by: Austin Lee
---
 .../generative/llm/ChatCompletionInput.java   |  3 +
 .../generative/llm/DefaultLlmImpl.java        | 80 +++++++++++++------
 .../questionanswering/generative/llm/Llm.java |  5 ++
 .../generative/llm/LlmIOUtil.java             | 17 +++-
 .../generative/prompt/PromptUtil.java         | 50 ++++++++++++
 .../llm/ChatCompletionInputTests.java         |  5 +-
 .../generative/llm/DefaultLlmImplTests.java   |  6 +-
 7 files changed, 137 insertions(+), 29 deletions(-)

diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java
index 85e1173875..61e7ecae76 100644
--- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java
+++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java
@@ -18,7 +18,9 @@
 package org.opensearch.searchpipelines.questionanswering.generative.llm;
 
 import java.util.List;
+import java.util.Map;
 
+import lombok.Builder;
 import org.opensearch.ml.common.conversation.Interaction;
 
 import lombok.AllArgsConstructor;
@@ -42,4 +44,5 @@ public class ChatCompletionInput {
     private int timeoutInSeconds;
     private String systemPrompt;
     private String userInstructions;
+    private Llm.ModelProvider modelProvider;
 }
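Callers build ChatCompletionInput through the Lombok-generated all-args constructor, so the new provider field rides along with the rest of the request. A minimal sketch of a call site after this change; argument order follows the updated tests below, and the model id is a placeholder, not something from this patch:

    import java.util.Collections;

    // Hypothetical call site; field order matches the test updates in this patch.
    ChatCompletionInput input = new ChatCompletionInput(
        "bedrock/placeholder-model",  // model id; the "bedrock/" prefix matters (see LlmIOUtil below)
        "What is OpenSearch?",        // question
        Collections.emptyList(),      // chatHistory
        Collections.emptyList(),      // contexts
        30,                           // timeoutInSeconds
        null,                         // systemPrompt
        null,                         // userInstructions
        Llm.ModelProvider.BEDROCK     // new field: selects the provider-specific wire format
    );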
diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java
index beef67b9e9..45b46f2a1a 100644
--- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java
+++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java
@@ -63,7 +63,7 @@ public DefaultLlmImpl(String openSearchModelId, Client client) {
     }
 
     @VisibleForTesting
-    void setMlClient(MachineLearningInternalClient mlClient) {
+    protected void setMlClient(MachineLearningInternalClient mlClient) {
         this.mlClient = mlClient;
     }
 
@@ -76,19 +76,7 @@ void setMlClient(MachineLearningInternalClient mlClient) {
 
     @Override
     public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) {
-        Map<String, String> inputParameters = new HashMap<>();
-        inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
-        String messages = PromptUtil
-            .getChatCompletionPrompt(
-                chatCompletionInput.getSystemPrompt(),
-                chatCompletionInput.getUserInstructions(),
-                chatCompletionInput.getQuestion(),
-                chatCompletionInput.getChatHistory(),
-                chatCompletionInput.getContexts()
-            );
-        inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
-        log.info("Messages to LLM: {}", messages);
-        MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(inputParameters).build();
+        MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build();
         MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
         ActionFuture<MLOutput> future = mlClient.predict(this.openSearchModelId, mlInput);
         ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(chatCompletionInput.getTimeoutInSeconds() * 1000);
@@ -99,19 +87,65 @@ public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionI
 
         // TODO dataAsMap can be null or can contain information such as throttling.  Handle non-happy cases.
 
-        List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES);
+        return buildChatCompletionOutput(chatCompletionInput.getModelProvider(), dataAsMap);
+    }
+
+    protected Map<String, String> getInputParameters(ChatCompletionInput chatCompletionInput) {
+        Map<String, String> inputParameters = new HashMap<>();
+
+        if (chatCompletionInput.getModelProvider() == ModelProvider.OPENAI) {
+            inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
+            String messages = PromptUtil.getChatCompletionPrompt(
+                chatCompletionInput.getSystemPrompt(),
+                chatCompletionInput.getUserInstructions(),
+                chatCompletionInput.getQuestion(),
+                chatCompletionInput.getChatHistory(),
+                chatCompletionInput.getContexts()
+            );
+            inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
+            log.info("Messages to LLM: {}", messages);
+        } else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK) {
+            inputParameters.put("inputs", PromptUtil.buildSingleStringPrompt(chatCompletionInput.getSystemPrompt(),
+                chatCompletionInput.getUserInstructions(),
+                chatCompletionInput.getQuestion(),
+                chatCompletionInput.getChatHistory(),
+                chatCompletionInput.getContexts()));
+        } else {
+            throw new IllegalArgumentException("Unknown/unsupported model provider: " + chatCompletionInput.getModelProvider());
+        }
+
+        log.info("LLM input parameters: {}", inputParameters.toString());
+        return inputParameters;
+    }
+
+    protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, Map<String, ?> dataAsMap) {
+
         List answers = null;
         List errors = null;
-        if (choices == null) {
-            Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR);
-            errors = List.of((String) error.get(CONNECTOR_OUTPUT_MESSAGE));
+
+        if (provider == ModelProvider.OPENAI) {
+            List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES);
+            if (choices == null) {
+                Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR);
+                errors = List.of((String) error.get(CONNECTOR_OUTPUT_MESSAGE));
+            } else {
+                Map firstChoiceMap = (Map) choices.get(0);
+                log.info("Choices: {}", firstChoiceMap.toString());
+                Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE);
+                log.info("role: {}, content: {}", message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
+                answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
+            }
+        } else if (provider == ModelProvider.BEDROCK) {
+            String response = (String) dataAsMap.get("completion");
+            if (response != null) {
+                answers = List.of(response);
+            } else {
+                // Error
+            }
         } else {
-            Map firstChoiceMap = (Map) choices.get(0);
-            log.info("Choices: {}", firstChoiceMap.toString());
-            Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE);
-            log.info("role: {}, content: {}", message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
-            answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
+            throw new IllegalArgumentException("Unknown/unsupported model provider: " + provider);
         }
+
         return new ChatCompletionOutput(answers, errors);
     }
 }
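The branching above makes both the request parameters and the response shape provider-specific. A sketch of the two response payloads buildChatCompletionOutput expects, using illustrative map literals rather than real connector output:

    import java.util.List;
    import java.util.Map;

    // OpenAI-style payload: the answer is nested under choices[0].message.content.
    Map<String, Object> openAi = Map.of(
        "choices", List.of(Map.of("message", Map.of("role", "assistant", "content", "answer text")))
    );

    // Bedrock-style payload: the whole answer arrives as a single "completion" string.
    Map<String, Object> bedrock = Map.of("completion", "answer text");

Note that the Bedrock branch currently leaves errors null when "completion" is absent (the "// Error" placeholder above), so non-happy responses fall through silently; the TODO on dataAsMap covers the same gap.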
diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java
index e850561066..faf136d550 100644
--- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java
+++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java
@@ -22,5 +22,10 @@
  */
 public interface Llm {
 
+    enum ModelProvider {
+        OPENAI,
+        BEDROCK
+    }
+
     ChatCompletionOutput doChatCompletion(ChatCompletionInput input);
 }
diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java
index fb95ed63bf..b8fcf48096 100644
--- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java
+++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java
@@ -17,7 +17,9 @@
  */
 package org.opensearch.searchpipelines.questionanswering.generative.llm;
 
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 import org.opensearch.ml.common.conversation.Interaction;
 import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil;
@@ -27,6 +29,14 @@
  */
 public class LlmIOUtil {
 
+    private static final String CONNECTOR_INPUT_PARAMETER_MODEL = "model";
+    private static final String CONNECTOR_INPUT_PARAMETER_MESSAGES = "messages";
+    private static final String CONNECTOR_OUTPUT_CHOICES = "choices";
+    private static final String CONNECTOR_OUTPUT_MESSAGE = "message";
+    private static final String CONNECTOR_OUTPUT_MESSAGE_ROLE = "role";
+    private static final String CONNECTOR_OUTPUT_MESSAGE_CONTENT = "content";
+    private static final String CONNECTOR_OUTPUT_ERROR = "error";
+
     public static ChatCompletionInput createChatCompletionInput(
         String llmModel,
         String question,
@@ -57,7 +67,10 @@ public static ChatCompletionInput createChatCompletionInput(
         List<Interaction> chatHistory,
         List<String> contexts,
         int timeoutInSeconds
     ) {
-
-        return new ChatCompletionInput(llmModel, question, chatHistory, contexts, timeoutInSeconds, systemPrompt, userInstructions);
+        Llm.ModelProvider provider = Llm.ModelProvider.OPENAI;
+        if (llmModel != null && llmModel.startsWith("bedrock/")) {
+            provider = Llm.ModelProvider.BEDROCK;
+        }
+        return new ChatCompletionInput(llmModel, question, chatHistory, contexts, timeoutInSeconds, systemPrompt, userInstructions, provider);
     }
 }
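The prefix check above is the only place a provider is inferred; everything downstream keys off the enum. A minimal sketch of the behavior, assuming the five-argument overload shown above takes (llmModel, question, chatHistory, contexts, timeoutInSeconds) and using a placeholder model id:

    import java.util.Collections;

    // "bedrock/..." ids route to BEDROCK; anything else defaults to OPENAI.
    ChatCompletionInput input = LlmIOUtil
        .createChatCompletionInput("bedrock/placeholder-model", "What is OpenSearch?",
            Collections.emptyList(), Collections.emptyList(), 30);
    // input.getModelProvider() == Llm.ModelProvider.BEDROCK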
diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java
index 9c57ffbf0f..f38e694c78 100644
--- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java
+++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java
@@ -20,6 +20,7 @@
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Locale;
 
 import org.apache.commons.text.StringEscapeUtils;
 import org.opensearch.core.common.Strings;
@@ -62,6 +63,8 @@ public static String getChatCompletionPrompt(String question, List<Interaction> chatHistory, List<String> contexts) {
         return getChatCompletionPrompt(DEFAULT_SYSTEM_PROMPT, null, question, chatHistory, contexts);
     }
 
+    // TODO Currently, this is OpenAI specific.  Change this to indicate as such or address it as part of
+    // future prompt template management work.
     public static String getChatCompletionPrompt(
         String systemPrompt,
         String userInstructions,
@@ -87,6 +90,46 @@ enum ChatRole {
         }
     }
 
+    static final String NEWLINE = "\\n";
+
+    public static String buildSingleStringPrompt(
+        String systemPrompt,
+        String userInstructions,
+        String question,
+        List<Interaction> chatHistory,
+        List<String> contexts
+    ) {
+        if (Strings.isNullOrEmpty(systemPrompt) && Strings.isNullOrEmpty(userInstructions)) {
+            systemPrompt = DEFAULT_SYSTEM_PROMPT;
+        }
+
+        StringBuilder bldr = new StringBuilder();
+        bldr.append(systemPrompt);
+        bldr.append(NEWLINE);
+        bldr.append(userInstructions);
+        bldr.append(NEWLINE);
+
+        for (int i = 0; i < contexts.size(); i++) {
+            bldr.append("SEARCH RESULT " + (i + 1) + ": " + contexts.get(i));
+            bldr.append(NEWLINE);
+        }
+        if (!chatHistory.isEmpty()) {
+            // The oldest interaction first
+            // Collections.reverse(chatHistory);
+            List<Message> messages = Messages.fromInteractions(chatHistory).getMessages();
+            Collections.reverse(messages);
+            messages.forEach(m -> {
+                bldr.append(m.toString());
+                bldr.append(NEWLINE);
+            });
+
+        }
+        bldr.append("QUESTION: " + question);
+        bldr.append(NEWLINE);
+
+        return bldr.toString();
+    }
+
     @VisibleForTesting
     static String buildMessageParameter(
         String systemPrompt,
@@ -163,6 +206,8 @@ public static Messages fromInteractions(final List<Interaction> interactions) {
         }
     }
 
+    // TODO This is OpenAI specific.  Either change this to OpenAiMessage or have it handle
+    // vendor specific messages.
     static class Message {
 
         private final static String MESSAGE_FIELD_ROLE = "role";
@@ -199,5 +244,10 @@ public void setContent(String content) {
         public JsonObject toJson() {
             return json;
         }
+
+        @Override
+        public String toString() {
+            return String.format(Locale.ROOT, "%s: %s", chatRole.getName(), content);
+        }
     }
 }
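To make the single-string format concrete, here is roughly what buildSingleStringPrompt assembles. The separator is the literal two-character sequence backslash-n as NEWLINE is written above, presumably unescaped later by the connector template (an assumption, not something this patch states):

    import java.util.Collections;
    import java.util.List;

    String prompt = PromptUtil.buildSingleStringPrompt(
        "You are a helpful assistant.",                // systemPrompt
        "Answer in one sentence.",                     // userInstructions
        "What is OpenSearch?",                         // question
        Collections.emptyList(),                       // chatHistory
        List.of("OpenSearch is a community project.")  // contexts
    );
    // prompt ==
    // "You are a helpful assistant.\nAnswer in one sentence.\n"
    //     + "SEARCH RESULT 1: OpenSearch is a community project.\n"
    //     + "QUESTION: What is OpenSearch?\n"

One quirk worth flagging: a null userInstructions is appended as the string "null", since only the case where both systemPrompt and userInstructions are empty falls back to DEFAULT_SYSTEM_PROMPT.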
diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java
index 0e34dd0bf1..403291f27c 100644
--- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java
+++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java
@@ -41,7 +41,8 @@ public void testCtor() {
             Collections.emptyList(),
             0,
             systemPrompt,
-            userInstructions
+            userInstructions,
+            Llm.ModelProvider.OPENAI
         );
 
         assertNotNull(input);
@@ -70,7 +71,7 @@ public void testGettersSetters() {
             )
         );
         List<String> contexts = List.of("result1", "result2");
-        ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts, 0, systemPrompt, userInstructions);
+        ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts, 0, systemPrompt, userInstructions, Llm.ModelProvider.OPENAI);
         assertEquals(model, input.getModel());
         assertEquals(question, input.getQuestion());
         assertEquals(history.get(0).getConversationId(), input.getChatHistory().get(0).getConversationId());
diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java
index 218bd65ec9..551a0e68bc 100644
--- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java
+++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java
@@ -111,7 +111,8 @@ public void testChatCompletionApi() throws Exception {
             Collections.emptyList(),
             0,
             "prompt",
-            "instructions"
+            "instructions",
+            Llm.ModelProvider.OPENAI
         );
         ChatCompletionOutput output = connector.doChatCompletion(input);
         verify(mlClient, times(1)).predict(any(), captor.capture());
@@ -141,7 +142,8 @@ public void testChatCompletionThrowingError() throws Exception {
             Collections.emptyList(),
             0,
             "prompt",
-            "instructions"
+            "instructions",
+            Llm.ModelProvider.OPENAI
         );
         ChatCompletionOutput output = connector.doChatCompletion(input);
         verify(mlClient, times(1)).predict(any(), captor.capture());
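The updated tests only exercise the OPENAI path. A Bedrock-path case might look like the following sketch, reusing the mock wiring from testChatCompletionApi; the stubbed response and the assertions are assumptions, not part of this patch:

    // Sketch only: assumes the same connector, mlClient mock, and captor as above.
    ChatCompletionInput input = new ChatCompletionInput(
        "bedrock/placeholder-model",
        "question",
        Collections.emptyList(),
        Collections.emptyList(),
        0,
        "prompt",
        "instructions",
        Llm.ModelProvider.BEDROCK
    );
    // Stub the ML client so the returned dataAsMap contains a "completion" entry, then:
    ChatCompletionOutput output = connector.doChatCompletion(input);
    verify(mlClient, times(1)).predict(any(), captor.capture());
    // The captured MLInput should carry the single-string "inputs" parameter
    // rather than OpenAI's "messages", and output should hold the "completion" text.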