forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update Conversation and Update Interaction APIs
Signed-off-by: Xun Zhang <[email protected]>
- Loading branch information
1 parent
274759a
commit 814cd77
Showing
15 changed files
with
536 additions
and
267 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
18 changes: 18 additions & 0 deletions
18
.../src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import org.opensearch.action.ActionType; | ||
import org.opensearch.action.update.UpdateResponse; | ||
|
||
public class UpdateConversationAction extends ActionType<UpdateResponse> { | ||
public static final UpdateConversationAction INSTANCE = new UpdateConversationAction(); | ||
public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/update"; | ||
|
||
private UpdateConversationAction() { | ||
super(NAME, UpdateResponse::new); | ||
} | ||
} |
83 changes: 83 additions & 0 deletions
83
...src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
import java.util.Map; | ||
|
||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.core.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
|
||
@Getter | ||
public class UpdateConversationRequest extends ActionRequest { | ||
String conversationId; | ||
Map<String, Object> updateContent; | ||
|
||
@Builder | ||
public UpdateConversationRequest(String conversationId, Map<String, Object> updateContent) { | ||
this.conversationId = conversationId; | ||
this.updateContent = updateContent; | ||
} | ||
|
||
public UpdateConversationRequest(StreamInput in) throws IOException { | ||
super(in); | ||
this.conversationId = in.readString(); | ||
this.updateContent = in.readMap(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(this.conversationId); | ||
out.writeMap(this.getUpdateContent()); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
|
||
if (this.conversationId == null) { | ||
exception = addValidationError("conversation id can't be null", exception); | ||
} | ||
|
||
return exception; | ||
} | ||
|
||
public static UpdateConversationRequest parse(XContentParser parser, String conversationId) throws IOException { | ||
Map<String, Object> dataAsMap = null; | ||
dataAsMap = parser.map(); | ||
|
||
return UpdateConversationRequest.builder().conversationId(conversationId).updateContent(dataAsMap).build(); | ||
} | ||
|
||
public static UpdateConversationRequest fromActionRequest(ActionRequest actionRequest) { | ||
if (actionRequest instanceof UpdateConversationRequest) { | ||
return (UpdateConversationRequest) actionRequest; | ||
} | ||
|
||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { | ||
actionRequest.writeTo(osso); | ||
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { | ||
return new UpdateConversationRequest(input); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException("failed to parse ActionRequest into UpdateConversationRequest", e); | ||
} | ||
} | ||
} |
68 changes: 68 additions & 0 deletions
68
.../java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.DocWriteResponse; | ||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.action.support.HandledTransportAction; | ||
import org.opensearch.action.update.UpdateRequest; | ||
import org.opensearch.action.update.UpdateResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.inject.Inject; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.ml.common.conversation.ConversationalIndexConstants; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.transport.TransportService; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
public class UpdateConversationTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> { | ||
Client client; | ||
|
||
@Inject | ||
public UpdateConversationTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { | ||
super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new); | ||
this.client = client; | ||
} | ||
|
||
@Override | ||
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) { | ||
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request); | ||
String conversationId = updateConversationRequest.getConversationId(); | ||
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId); | ||
updateRequest.doc(updateConversationRequest.getUpdateContent()); | ||
updateRequest.docAsUpsert(true); | ||
|
||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { | ||
client.update(updateRequest, getUpdateResponseListener(conversationId, listener, context)); | ||
} catch (Exception e) { | ||
log.error("Failed to update Conversation for conversation id {}. Details {}:", conversationId, e); | ||
listener.onFailure(e); | ||
} | ||
} | ||
|
||
private ActionListener<UpdateResponse> getUpdateResponseListener( | ||
String conversationId, | ||
ActionListener<UpdateResponse> actionListener, | ||
ThreadContext.StoredContext context | ||
) { | ||
return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { | ||
if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { | ||
log.info("Failed to update the Conversation with ID: {}", conversationId); | ||
actionListener.onResponse(updateResponse); | ||
return; | ||
} | ||
log.info("Successfully updated the Conversation with ID: {}", conversationId); | ||
actionListener.onResponse(updateResponse); | ||
}, exception -> { | ||
log.error("Failed to update ML Conversation with ID {}. Details: {}", conversationId, exception); | ||
actionListener.onFailure(exception); | ||
}), context::restore); | ||
} | ||
} |
19 changes: 19 additions & 0 deletions
19
...y/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import org.opensearch.action.ActionType; | ||
import org.opensearch.action.update.UpdateResponse; | ||
|
||
public class UpdateInteractionAction extends ActionType<UpdateResponse> { | ||
public static final UpdateInteractionAction INSTANCE = new UpdateInteractionAction(); | ||
public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/update"; | ||
|
||
private UpdateInteractionAction() { | ||
super(NAME, UpdateResponse::new); | ||
} | ||
|
||
} |
84 changes: 84 additions & 0 deletions
84
.../src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
import java.util.Map; | ||
|
||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.core.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
|
||
@Getter | ||
public class UpdateInteractionRequest extends ActionRequest { | ||
String interactionId; | ||
Map<String, Object> updateContent; | ||
|
||
@Builder | ||
public UpdateInteractionRequest(String interactionId, Map<String, Object> updateContent) { | ||
this.interactionId = interactionId; | ||
this.updateContent = updateContent; | ||
} | ||
|
||
public UpdateInteractionRequest(StreamInput in) throws IOException { | ||
super(in); | ||
this.interactionId = in.readString(); | ||
this.updateContent = in.readMap(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(this.interactionId); | ||
out.writeMap(this.getUpdateContent()); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
|
||
if (this.interactionId == null) { | ||
exception = addValidationError("interaction id can't be null", exception); | ||
} | ||
|
||
return exception; | ||
} | ||
|
||
public static UpdateInteractionRequest parse(XContentParser parser, String interactionId) throws IOException { | ||
Map<String, Object> dataAsMap = null; | ||
dataAsMap = parser.map(); | ||
|
||
return UpdateInteractionRequest.builder().interactionId(interactionId).updateContent(dataAsMap).build(); | ||
} | ||
|
||
public static UpdateInteractionRequest fromActionRequest(ActionRequest actionRequest) { | ||
if (actionRequest instanceof UpdateInteractionRequest) { | ||
return (UpdateInteractionRequest) actionRequest; | ||
} | ||
|
||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { | ||
actionRequest.writeTo(osso); | ||
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { | ||
return new UpdateInteractionRequest(input); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException("failed to parse ActionRequest into UpdateInteractionRequest", e); | ||
} | ||
} | ||
|
||
} |
68 changes: 68 additions & 0 deletions
68
...n/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.DocWriteResponse; | ||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.action.support.HandledTransportAction; | ||
import org.opensearch.action.update.UpdateRequest; | ||
import org.opensearch.action.update.UpdateResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.inject.Inject; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.ml.common.conversation.ConversationalIndexConstants; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.transport.TransportService; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
public class UpdateInteractionTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> { | ||
Client client; | ||
|
||
@Inject | ||
public UpdateInteractionTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { | ||
super(UpdateInteractionAction.NAME, transportService, actionFilters, UpdateInteractionRequest::new); | ||
this.client = client; | ||
} | ||
|
||
@Override | ||
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) { | ||
UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.fromActionRequest(request); | ||
String interactionId = updateInteractionRequest.getInteractionId(); | ||
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME, interactionId); | ||
updateRequest.doc(updateInteractionRequest.getUpdateContent()); | ||
updateRequest.docAsUpsert(true); | ||
|
||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { | ||
client.update(updateRequest, getUpdateResponseListener(interactionId, listener, context)); | ||
} catch (Exception e) { | ||
log.error("Failed to update Interaction for interaction id {}. Details {}:", interactionId, e); | ||
listener.onFailure(e); | ||
} | ||
} | ||
|
||
private ActionListener<UpdateResponse> getUpdateResponseListener( | ||
String interactionId, | ||
ActionListener<UpdateResponse> actionListener, | ||
ThreadContext.StoredContext context | ||
) { | ||
return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { | ||
if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { | ||
log.info("Failed to update the interaction with ID: {}", interactionId); | ||
actionListener.onResponse(updateResponse); | ||
return; | ||
} | ||
log.info("Successfully updated the interaction with ID: {}", interactionId); | ||
actionListener.onResponse(updateResponse); | ||
}, exception -> { | ||
log.error("Failed to update ML interaction with ID {}. Details: {}", interactionId, exception); | ||
actionListener.onFailure(exception); | ||
}), context::restore); | ||
} | ||
} |
Oops, something went wrong.