From 468300cf0c2f65bd270c0db06399419630425110 Mon Sep 17 00:00:00 2001 From: Rithin Pullela Date: Mon, 9 Dec 2024 13:10:43 -0800 Subject: [PATCH] Enhance Interaction handling: validate fields, optimize storage, and improve error handling - Throw an error when an unknown field is provided in `CreateConversation` or `CreateInteraction`. - Skip saving empty fields in interactions to optimize storage and improve efficiency. - Throw an exception if all fields in an interaction are empty or null. - Add unit test cases to cover the new validation, error handling, and storage optimization logic. --- .../CreateConversationRequest.java | 25 ++++++-- .../CreateInteractionRequest.java | 13 +++- .../ml/memory/index/InteractionsIndex.java | 41 ++++++------ .../CreateConversationRequestTests.java | 25 ++++++++ .../CreateInteractionRequestTests.java | 64 +++++++++++++++++++ 5 files changed, 143 insertions(+), 25 deletions(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java index 991ddde2a7..668226e8d4 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java @@ -137,12 +137,27 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest) } try (XContentParser parser = restRequest.contentParser()) { Map body = parser.map(); + String name = null; + String applicationType = null; + Map additionalInfo = null; + + for(String key : body.keySet()){ + switch(key){ + case ActionConstants.REQUEST_CONVERSATION_NAME_FIELD: + name = (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD); + break; + case APPLICATION_TYPE_FIELD: + applicationType = (String) body.get(APPLICATION_TYPE_FIELD); + break; + case META_ADDITIONAL_INFO_FIELD: + additionalInfo = (Map) body.get(META_ADDITIONAL_INFO_FIELD); + break; + default: + throw new IllegalArgumentException("Invalid field [" + key + "] found in request body"); + } + } if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) { - return new CreateConversationRequest( - (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD), - body.get(APPLICATION_TYPE_FIELD) == null ? null : (String) body.get(APPLICATION_TYPE_FIELD), - body.get(META_ADDITIONAL_INFO_FIELD) == null ? null : (Map) body.get(META_ADDITIONAL_INFO_FIELD) - ); + return new CreateConversationRequest(name, applicationType, additionalInfo); } else { return new CreateConversationRequest(); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java index fe4a05bc0c..5679b67ca0 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java @@ -166,11 +166,20 @@ public static CreateInteractionRequest fromRestRequest(RestRequest request) thro tracenum = parser.intValue(false); break; default: - parser.skipChildren(); - break; + throw new IllegalArgumentException("Invalid field [" + fieldName + "] found in request body"); } } + boolean allFieldsEmpty = (input == null || input.trim().isEmpty()) && + (prompt == null || prompt.trim().isEmpty()) && + (response == null || response.trim().isEmpty()) && + (origin == null || origin.trim().isEmpty()) && + (addinf == null || addinf.isEmpty()); + if (allFieldsEmpty) { + throw new IllegalArgumentException("At least one of the following parameters must be non-empty: " + + "input, prompt_template, response, origin, additional_info"); + } + return new CreateInteractionRequest(cid, input, prompt, response, origin, addinf, parintid, tracenum); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index e8cf4bd7ae..5d9810659e 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -161,27 +162,31 @@ public void createInteraction( if (indexExists) { this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { + Map sourceMap = new HashMap<>(); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, timestamp); + sourceMap.put(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parintid); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, traceNumber); + + if (input != null && !input.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); + } + if (promptTemplate != null && !promptTemplate.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); + } + if (response != null && !response.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); + } + if (origin != null && !origin.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + } + if (additionalInfo != null && !additionalInfo.isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo); + } IndexRequest request = Requests .indexRequest(INTERACTIONS_INDEX_NAME) .source( - ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, - origin, - ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, - conversationId, - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, - input, - ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, - promptTemplate, - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, - response, - ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, - additionalInfo, - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, - timestamp, - ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, - parintid, - ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, - traceNumber + sourceMap ); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java index 0f2dd2b5ce..53469d9091 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java @@ -27,6 +27,7 @@ import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; +import org.opensearch.OpenSearchParseException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; @@ -132,4 +133,28 @@ public void testRestRequest_WithAdditionalInfo() throws IOException { Assert.assertEquals("value1", request.getAdditionalInfos().get("key1")); Assert.assertEquals(123, request.getAdditionalInfos().get("key2")); } + + public void testRestRequest_UnknownFields_ThenFail() throws IOException { + String name = "test-name"; + RestRequest req = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withContent( + new BytesArray( + gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, name, "unknown_field", "some value")) + ), + MediaTypeRegistry.JSON + ) + .build(); + + try{ + CreateConversationRequest request = CreateConversationRequest.fromRestRequest(req); + fail("Expected IllegalArgumentException due to unknown field"); + } + catch(OpenSearchParseException e){ + assertEquals(e.getMessage(), "Invalid field [unknown_field] found in request body"); + } + catch( Exception e){ + fail("Expected OpenSearchParseException due to unknown field, got " + e.getClass().getName()); + } + + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java index 8068a85dfb..4578c6b0b4 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import org.junit.Before; @@ -153,4 +154,67 @@ public void testFromRestRequest_Trace() throws IOException { assert (request.getParentIid().equals("parentId")); assert (request.getTraceNumber().equals(1)); } + + public void testFromRestRequest_UnknownFields_ThenFail() throws IOException { + Map params = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "pt", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.ADDITIONAL_INFO_FIELD, + Collections.singletonMap("metadata", "some meta"), + "unknown_field", + "some value" + ); + + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.MEMORY_ID, "cid")) + .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) + .build(); + + try{ + CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest); + fail("Expected IllegalArgumentException due to unknown field"); + } + catch (IllegalArgumentException e){ + assertEquals(e.getMessage(), "Invalid field [unknown_field] found in request body"); + } + catch (Exception e){ + fail("Expected IllegalArgumentException due to unknown field, got " + e.getClass().getName()); + } + } + + public void testFromRestRequest_AllFieldsEmpty_ThenFail() throws IOException { + Map params = new HashMap<>(); + + params.put(ActionConstants.INPUT_FIELD, ""); + params.put(ActionConstants.PROMPT_TEMPLATE_FIELD, null); + params.put(ActionConstants.AI_RESPONSE_FIELD, " "); + params.put(ActionConstants.RESPONSE_ORIGIN_FIELD, null); + params.put(ActionConstants.ADDITIONAL_INFO_FIELD, Collections.emptyMap()); + + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.MEMORY_ID, "cid")) + .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) + .build(); + + try{ + CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest); + fail("Expected IllegalArgumentException due to all fields empty"); + } + catch (IllegalArgumentException e){ + assertEquals(e.getMessage(), "At least one of the following parameters must be non-empty: input, prompt_template, response, origin, additional_info"); + } + catch (Exception e){ + fail("Expected IllegalArgumentException due to all fields empty, got " + e.getClass().getName()); + } + } + + + }