Skip to content

Commit

Permalink
Enhance Interaction handling: validate fields, optimize storage, and …
Browse files Browse the repository at this point in the history
…improve error handling

- Throw an error when an unknown field is provided in `CreateConversation` or `CreateInteraction`.
- Skip saving empty fields in interactions to optimize storage and improve efficiency.
- Throw an exception if all fields in an interaction are empty or null.
- Add unit test cases to cover the new validation, error handling, and storage optimization logic.
  • Loading branch information
rithin-pullela-aws committed Dec 9, 2024
1 parent 1a659c8 commit 468300c
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,27 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest)
}
try (XContentParser parser = restRequest.contentParser()) {
Map<String, Object> body = parser.map();
String name = null;
String applicationType = null;
Map<String, String> additionalInfo = null;

for(String key : body.keySet()){
switch(key){
case ActionConstants.REQUEST_CONVERSATION_NAME_FIELD:
name = (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD);
break;
case APPLICATION_TYPE_FIELD:
applicationType = (String) body.get(APPLICATION_TYPE_FIELD);
break;
case META_ADDITIONAL_INFO_FIELD:
additionalInfo = (Map<String, String>) body.get(META_ADDITIONAL_INFO_FIELD);
break;
default:
throw new IllegalArgumentException("Invalid field [" + key + "] found in request body");
}
}
if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) {
return new CreateConversationRequest(
(String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD),
body.get(APPLICATION_TYPE_FIELD) == null ? null : (String) body.get(APPLICATION_TYPE_FIELD),
body.get(META_ADDITIONAL_INFO_FIELD) == null ? null : (Map<String, String>) body.get(META_ADDITIONAL_INFO_FIELD)
);
return new CreateConversationRequest(name, applicationType, additionalInfo);
} else {
return new CreateConversationRequest();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,20 @@ public static CreateInteractionRequest fromRestRequest(RestRequest request) thro
tracenum = parser.intValue(false);
break;
default:
parser.skipChildren();
break;
throw new IllegalArgumentException("Invalid field [" + fieldName + "] found in request body");
}
}

boolean allFieldsEmpty = (input == null || input.trim().isEmpty()) &&
(prompt == null || prompt.trim().isEmpty()) &&
(response == null || response.trim().isEmpty()) &&
(origin == null || origin.trim().isEmpty()) &&
(addinf == null || addinf.isEmpty());
if (allFieldsEmpty) {
throw new IllegalArgumentException("At least one of the following parameters must be non-empty: " +
"input, prompt_template, response, origin, additional_info");
}

return new CreateInteractionRequest(cid, input, prompt, response, origin, addinf, parintid, tracenum);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import java.io.IOException;
import java.time.Instant;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -161,27 +162,31 @@ public void createInteraction(
if (indexExists) {
this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> {
if (access) {
Map<String, Object> sourceMap = new HashMap<>();
sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId);
sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, timestamp);
sourceMap.put(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parintid);
sourceMap.put(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, traceNumber);

if (input != null && !input.trim().isEmpty()) {
sourceMap.put(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input);
}
if (promptTemplate != null && !promptTemplate.trim().isEmpty()) {
sourceMap.put(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate);
}
if (response != null && !response.trim().isEmpty()) {
sourceMap.put(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response);
}
if (origin != null && !origin.trim().isEmpty()) {
sourceMap.put(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin);
}
if (additionalInfo != null && !additionalInfo.isEmpty()) {
sourceMap.put(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo);
}
IndexRequest request = Requests
.indexRequest(INTERACTIONS_INDEX_NAME)
.source(
ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD,
origin,
ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD,
conversationId,
ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD,
input,
ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD,
promptTemplate,
ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD,
response,
ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD,
additionalInfo,
ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD,
timestamp,
ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD,
parintid,
ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD,
traceNumber
sourceMap
);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<String> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.opensearch.OpenSearchParseException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
Expand Down Expand Up @@ -132,4 +133,28 @@ public void testRestRequest_WithAdditionalInfo() throws IOException {
Assert.assertEquals("value1", request.getAdditionalInfos().get("key1"));
Assert.assertEquals(123, request.getAdditionalInfos().get("key2"));
}

public void testRestRequest_UnknownFields_ThenFail() throws IOException {
String name = "test-name";
RestRequest req = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withContent(
new BytesArray(
gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, name, "unknown_field", "some value"))
),
MediaTypeRegistry.JSON
)
.build();

try{
CreateConversationRequest request = CreateConversationRequest.fromRestRequest(req);
fail("Expected IllegalArgumentException due to unknown field");
}
catch(OpenSearchParseException e){
assertEquals(e.getMessage(), "Invalid field [unknown_field] found in request body");
}
catch( Exception e){
fail("Expected OpenSearchParseException due to unknown field, got " + e.getClass().getName());
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.junit.Before;
Expand Down Expand Up @@ -153,4 +154,67 @@ public void testFromRestRequest_Trace() throws IOException {
assert (request.getParentIid().equals("parentId"));
assert (request.getTraceNumber().equals(1));
}

public void testFromRestRequest_UnknownFields_ThenFail() throws IOException {
Map<String, Object> params = Map
.of(
ActionConstants.INPUT_FIELD,
"input",
ActionConstants.PROMPT_TEMPLATE_FIELD,
"pt",
ActionConstants.AI_RESPONSE_FIELD,
"response",
ActionConstants.RESPONSE_ORIGIN_FIELD,
"origin",
ActionConstants.ADDITIONAL_INFO_FIELD,
Collections.singletonMap("metadata", "some meta"),
"unknown_field",
"some value"
);

RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withParams(Map.of(ActionConstants.MEMORY_ID, "cid"))
.withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON)
.build();

try{
CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest);
fail("Expected IllegalArgumentException due to unknown field");
}
catch (IllegalArgumentException e){
assertEquals(e.getMessage(), "Invalid field [unknown_field] found in request body");
}
catch (Exception e){
fail("Expected IllegalArgumentException due to unknown field, got " + e.getClass().getName());
}
}

public void testFromRestRequest_AllFieldsEmpty_ThenFail() throws IOException {
Map<String, Object> params = new HashMap<>();

params.put(ActionConstants.INPUT_FIELD, "");
params.put(ActionConstants.PROMPT_TEMPLATE_FIELD, null);
params.put(ActionConstants.AI_RESPONSE_FIELD, " ");
params.put(ActionConstants.RESPONSE_ORIGIN_FIELD, null);
params.put(ActionConstants.ADDITIONAL_INFO_FIELD, Collections.emptyMap());

RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withParams(Map.of(ActionConstants.MEMORY_ID, "cid"))
.withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON)
.build();

try{
CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest);
fail("Expected IllegalArgumentException due to all fields empty");
}
catch (IllegalArgumentException e){
assertEquals(e.getMessage(), "At least one of the following parameters must be non-empty: input, prompt_template, response, origin, additional_info");
}
catch (Exception e){
fail("Expected IllegalArgumentException due to all fields empty, got " + e.getClass().getName());
}
}



}

0 comments on commit 468300c

Please sign in to comment.