
Commit e18f249

Fix prompt passing for Bedrock by passing a single string prompt for Bedrock models (#1490)

* Fix prompt passing for Bedrock by passing a single string prompt for Bedrock models. (#1476)

Signed-off-by: Austin Lee <[email protected]>

* Add unit tests, apply Spotless.

Signed-off-by: Austin Lee <[email protected]>

* Check if systemPrompt is null.

Signed-off-by: Austin Lee <[email protected]>

* Address review comments.

Signed-off-by: Austin Lee <[email protected]>

---------

Signed-off-by: Austin Lee <[email protected]>
austintlee authored Oct 11, 2023
1 parent f4446cb commit e18f249
Showing 10 changed files with 281 additions and 30 deletions.
4 changes: 4 additions & 0 deletions search-processors/README.md
@@ -49,6 +49,10 @@ GET /<index>/_search\?search_pipeline\=<search pipeline name>
}
```

+To use this with Bedrock models, use "bedrock/" as a prefix for the "llm_model" parameter, e.g. "bedrock/anthropic".
+
+The RAG processor has so far been tested only with OpenAI's GPT 3.5 and GPT 4 models and with Bedrock's Anthropic Claude (v2) model.

## Retrieval Augmented Generation response
```
{
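For illustration, a request that routes the RAG pipeline to a Bedrock model only needs the prefixed model name. This is a sketch: the `generative_qa_parameters` block follows the request format shown earlier in this README, and the index, pipeline name, and query values are hypothetical.

```
GET /my_index/_search?search_pipeline=my_rag_pipeline
{
  "query": { "match": { "text": "capital of france" } },
  "ext": {
    "generative_qa_parameters": {
      "llm_model": "bedrock/anthropic",
      "llm_question": "What is the capital of France?"
    }
  }
}
```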
ChatCompletionInput.java
@@ -42,4 +42,5 @@ public class ChatCompletionInput {
private int timeoutInSeconds;
private String systemPrompt;
private String userInstructions;
+    private Llm.ModelProvider modelProvider;
}
DefaultLlmImpl.java
@@ -63,7 +63,7 @@ public DefaultLlmImpl(String openSearchModelId, Client client) {
}

@VisibleForTesting
-    void setMlClient(MachineLearningInternalClient mlClient) {
+    protected void setMlClient(MachineLearningInternalClient mlClient) {
this.mlClient = mlClient;
}

@@ -76,19 +76,7 @@ void setMlClient(MachineLearningInternalClient mlClient) {
@Override
public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) {

-        Map<String, String> inputParameters = new HashMap<>();
-        inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
-        String messages = PromptUtil
-            .getChatCompletionPrompt(
-                chatCompletionInput.getSystemPrompt(),
-                chatCompletionInput.getUserInstructions(),
-                chatCompletionInput.getQuestion(),
-                chatCompletionInput.getChatHistory(),
-                chatCompletionInput.getContexts()
-            );
-        inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
-        log.info("Messages to LLM: {}", messages);
-        MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(inputParameters).build();
+        MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
ActionFuture<MLOutput> future = mlClient.predict(this.openSearchModelId, mlInput);
ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(chatCompletionInput.getTimeoutInSeconds() * 1000);
@@ -99,19 +87,83 @@ public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) {

        // TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases.

-        List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES);
-
-        List<Object> answers = null;
-        List<String> errors = null;
-        if (choices == null) {
-            Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR);
-            errors = List.of((String) error.get(CONNECTOR_OUTPUT_MESSAGE));
-        } else {
-            Map firstChoiceMap = (Map) choices.get(0);
-            log.info("Choices: {}", firstChoiceMap.toString());
-            Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE);
-            log.info("role: {}, content: {}", message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
-            answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
-        }
-
-        return new ChatCompletionOutput(answers, errors);
+        return buildChatCompletionOutput(chatCompletionInput.getModelProvider(), dataAsMap);
    }

+    protected Map<String, String> getInputParameters(ChatCompletionInput chatCompletionInput) {
+        Map<String, String> inputParameters = new HashMap<>();
+
+        if (chatCompletionInput.getModelProvider() == ModelProvider.OPENAI) {
+            inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
+            String messages = PromptUtil
+                .getChatCompletionPrompt(
+                    chatCompletionInput.getSystemPrompt(),
+                    chatCompletionInput.getUserInstructions(),
+                    chatCompletionInput.getQuestion(),
+                    chatCompletionInput.getChatHistory(),
+                    chatCompletionInput.getContexts()
+                );
+            inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
+            log.info("Messages to LLM: {}", messages);
+        } else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK) {
+            inputParameters
+                .put(
+                    "inputs",
+                    PromptUtil
+                        .buildSingleStringPrompt(
+                            chatCompletionInput.getSystemPrompt(),
+                            chatCompletionInput.getUserInstructions(),
+                            chatCompletionInput.getQuestion(),
+                            chatCompletionInput.getChatHistory(),
+                            chatCompletionInput.getContexts()
+                        )
+                );
+        } else {
+            throw new IllegalArgumentException("Unknown/unsupported model provider: " + chatCompletionInput.getModelProvider());
+        }
+
+        log.info("LLM input parameters: {}", inputParameters.toString());
+        return inputParameters;
+    }
+
+    protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, Map<String, ?> dataAsMap) {
+
+        List<Object> answers = null;
+        List<String> errors = null;
+
+        if (provider == ModelProvider.OPENAI) {
+            List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES);
+            if (choices == null) {
+                Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR);
+                errors = List.of((String) error.get(CONNECTOR_OUTPUT_MESSAGE));
+            } else {
+                Map firstChoiceMap = (Map) choices.get(0);
+                log.info("Choices: {}", firstChoiceMap.toString());
+                Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE);
+                log
+                    .info(
+                        "role: {}, content: {}",
+                        message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE),
+                        message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)
+                    );
+                answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
+            }
+        } else if (provider == ModelProvider.BEDROCK) {
+            String response = (String) dataAsMap.get("completion");
+            if (response != null) {
+                answers = List.of(response);
+            } else {
+                Map error = (Map) dataAsMap.get("error");
+                if (error != null) {
+                    errors = List.of((String) error.get("message"));
+                } else {
+                    errors = List.of("Unknown error or response.");
+                }
+            }
+        } else {
+            throw new IllegalArgumentException("Unknown/unsupported model provider: " + provider);
+        }
+
+        return new ChatCompletionOutput(answers, errors);
+    }
}
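The Bedrock branch above expects Claude-style connector output: the generated text arrives under a "completion" key, and failures under "error"/"message". A minimal standalone sketch of that parsing rule (sample maps only; real connector payloads depend on the Bedrock connector configuration):

```
import java.util.List;
import java.util.Map;

public class BedrockOutputParsingSketch {

    // Mirrors the BEDROCK branch of buildChatCompletionOutput above:
    // a "completion" string is the answer; otherwise fall back to error.message.
    static List<String> parse(Map<String, ?> dataAsMap) {
        String response = (String) dataAsMap.get("completion");
        if (response != null) {
            return List.of("answer", response);
        }
        Map<?, ?> error = (Map<?, ?>) dataAsMap.get("error");
        if (error != null) {
            return List.of("error", (String) error.get("message"));
        }
        return List.of("error", "Unknown error or response.");
    }

    public static void main(String[] args) {
        System.out.println(parse(Map.of("completion", " Paris is the capital of France.")));
        System.out.println(parse(Map.of("error", Map.of("message", "Throttled"))));
        System.out.println(parse(Map.of())); // neither key present
    }
}
```

The OpenAI branch keeps the existing choices/message parsing unchanged, so existing OpenAI pipelines are unaffected by this refactor.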
Llm.java
@@ -22,5 +22,11 @@
*/
public interface Llm {

+    // TODO Ensure the current implementation works with all models supported by Bedrock.
+    enum ModelProvider {
+        OPENAI,
+        BEDROCK
+    }

ChatCompletionOutput doChatCompletion(ChatCompletionInput input);
}
LlmIOUtil.java
@@ -27,6 +27,8 @@
*/
public class LlmIOUtil {

+    private static final String BEDROCK_PROVIDER_PREFIX = "bedrock/";

public static ChatCompletionInput createChatCompletionInput(
String llmModel,
String question,
@@ -57,7 +59,19 @@ public static ChatCompletionInput createChatCompletionInput(
List<String> contexts,
int timeoutInSeconds
) {

-        return new ChatCompletionInput(llmModel, question, chatHistory, contexts, timeoutInSeconds, systemPrompt, userInstructions);
+        Llm.ModelProvider provider = Llm.ModelProvider.OPENAI;
+        if (llmModel != null && llmModel.startsWith(BEDROCK_PROVIDER_PREFIX)) {
+            provider = Llm.ModelProvider.BEDROCK;
+        }
+        return new ChatCompletionInput(
+            llmModel,
+            question,
+            chatHistory,
+            contexts,
+            timeoutInSeconds,
+            systemPrompt,
+            userInstructions,
+            provider
+        );
}
}
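Provider selection is driven entirely by the model-name prefix, so existing callers need no API change: any `llmModel` beginning with "bedrock/" is dispatched to Bedrock, and everything else (including a null model) keeps the OpenAI default. A self-contained illustration of that dispatch rule (all names local to this sketch):

```
public class ProviderPrefixSketch {
    enum ModelProvider { OPENAI, BEDROCK }

    private static final String BEDROCK_PROVIDER_PREFIX = "bedrock/";

    // Same rule as LlmIOUtil.createChatCompletionInput in this diff.
    static ModelProvider providerFor(String llmModel) {
        if (llmModel != null && llmModel.startsWith(BEDROCK_PROVIDER_PREFIX)) {
            return ModelProvider.BEDROCK;
        }
        return ModelProvider.OPENAI;
    }

    public static void main(String[] args) {
        System.out.println(providerFor("bedrock/anthropic")); // BEDROCK
        System.out.println(providerFor("gpt-3.5-turbo"));     // OPENAI
        System.out.println(providerFor(null));                // OPENAI (default)
    }
}
```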
PromptUtil.java
@@ -20,6 +20,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
+import java.util.Locale;

import org.apache.commons.text.StringEscapeUtils;
import org.opensearch.core.common.Strings;
@@ -54,6 +55,8 @@ public class PromptUtil {

private static final String roleUser = "user";

+    private static final String NEWLINE = "\\n";

public static String getQuestionRephrasingPrompt(String originalQuestion, List<Interaction> chatHistory) {
return null;
}
@@ -62,6 +65,8 @@ public static String getChatCompletionPrompt(String question, List<Interaction>
return getChatCompletionPrompt(DEFAULT_SYSTEM_PROMPT, null, question, chatHistory, contexts);
}

+    // TODO Currently, this is OpenAI specific. Change this to indicate as such or address it as part of
+    // future prompt template management work.
public static String getChatCompletionPrompt(
String systemPrompt,
String userInstructions,
@@ -87,6 +92,48 @@ enum ChatRole {
}
}

+    public static String buildSingleStringPrompt(
+        String systemPrompt,
+        String userInstructions,
+        String question,
+        List<Interaction> chatHistory,
+        List<String> contexts
+    ) {
+        if (Strings.isNullOrEmpty(systemPrompt) && Strings.isNullOrEmpty(userInstructions)) {
+            systemPrompt = DEFAULT_SYSTEM_PROMPT;
+        }
+
+        StringBuilder bldr = new StringBuilder();
+
+        if (!Strings.isNullOrEmpty(systemPrompt)) {
+            bldr.append(systemPrompt);
+            bldr.append(NEWLINE);
+        }
+        if (!Strings.isNullOrEmpty(userInstructions)) {
+            bldr.append(userInstructions);
+            bldr.append(NEWLINE);
+        }
+
+        for (int i = 0; i < contexts.size(); i++) {
+            bldr.append("SEARCH RESULT " + (i + 1) + ": " + contexts.get(i));
+            bldr.append(NEWLINE);
+        }
+        if (!chatHistory.isEmpty()) {
+            // The oldest interaction first
+            List<Message> messages = Messages.fromInteractions(chatHistory).getMessages();
+            Collections.reverse(messages);
+            messages.forEach(m -> {
+                bldr.append(m.toString());
+                bldr.append(NEWLINE);
+            });
+
+        }
+        bldr.append("QUESTION: " + question);
+        bldr.append(NEWLINE);
+
+        return bldr.toString();
+    }

@VisibleForTesting
static String buildMessageParameter(
String systemPrompt,
@@ -110,7 +157,6 @@ static String buildMessageParameter(
}
if (!chatHistory.isEmpty()) {
// The oldest interaction first
-        // Collections.reverse(chatHistory);
List<Message> messages = Messages.fromInteractions(chatHistory).getMessages();
Collections.reverse(messages);
messages.forEach(m -> messageArray.add(m.toJson()));
@@ -163,6 +209,8 @@ public static Messages fromInteractions(final List<Interaction> interactions) {
}
}

+    // TODO This is OpenAI specific. Either change this to OpenAiMessage or have it handle
+    // vendor specific messages.
static class Message {

private final static String MESSAGE_FIELD_ROLE = "role";
@@ -186,6 +234,7 @@ public Message(ChatRole chatRole, String content) {
}

public void setChatRole(ChatRole chatRole) {
+            this.chatRole = chatRole;
json.remove(MESSAGE_FIELD_ROLE);
json.add(MESSAGE_FIELD_ROLE, new JsonPrimitive(chatRole.getName()));
}
@@ -199,5 +248,10 @@ public void setContent(String content) {
public JsonObject toJson() {
return json;
}

+        @Override
+        public String toString() {
+            return String.format(Locale.ROOT, "%s: %s", chatRole.getName(), content);
+        }
}
}
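For Bedrock, `buildSingleStringPrompt` flattens everything into one string: the system prompt (defaulted when both it and the user instructions are empty), the user instructions, numbered `SEARCH RESULT` lines, the chat history oldest-first (each interaction rendered as `role: content` by the new `Message.toString()`), and finally `QUESTION: ...`, all joined by the `NEWLINE` constant. Below is a rough standalone rendering of that layout; `build` and its plain role/content pairs are hypothetical stand-ins for the plugin's `Interaction`/`Messages` types:

```
import java.util.List;

public class SingleStringPromptSketch {
    // Note: the diff defines NEWLINE as "\\n", i.e. a literal backslash + n,
    // and this sketch keeps that behavior rather than a real line break.
    private static final String NEWLINE = "\\n";

    static String build(String systemPrompt, List<String> contexts, List<String[]> history, String question) {
        StringBuilder bldr = new StringBuilder();
        bldr.append(systemPrompt).append(NEWLINE);
        for (int i = 0; i < contexts.size(); i++) {
            bldr.append("SEARCH RESULT " + (i + 1) + ": " + contexts.get(i)).append(NEWLINE);
        }
        for (String[] m : history) {
            bldr.append(m[0] + ": " + m[1]).append(NEWLINE); // role: content, oldest first
        }
        bldr.append("QUESTION: " + question).append(NEWLINE);
        return bldr.toString();
    }

    public static void main(String[] args) {
        System.out.println(build(
            "Generate a concise and informative answer based on the search results.",
            List.of("Paris is the capital of France."),
            List.of(new String[] { "user", "Hi" }, new String[] { "assistant", "Hello" }),
            "What is the capital of France?"
        ));
    }
}
```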
ChatCompletionInputTests.java
@@ -41,7 +41,8 @@ public void testCtor() {
Collections.emptyList(),
0,
systemPrompt,
-            userInstructions
+            userInstructions,
+            Llm.ModelProvider.OPENAI
);

assertNotNull(input);
@@ -70,7 +71,16 @@ public void testGettersSetters() {
)
);
List<String> contexts = List.of("result1", "result2");
-        ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts, 0, systemPrompt, userInstructions);
+        ChatCompletionInput input = new ChatCompletionInput(
+            model,
+            question,
+            history,
+            contexts,
+            0,
+            systemPrompt,
+            userInstructions,
+            Llm.ModelProvider.OPENAI
+        );
assertEquals(model, input.getModel());
assertEquals(question, input.getQuestion());
assertEquals(history.get(0).getConversationId(), input.getChatHistory().get(0).getConversationId());
