Skip to content

Commit

Permalink
Update Conversation and Update Interaction APIs
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Nov 21, 2023
1 parent 274759a commit 814cd77
Show file tree
Hide file tree
Showing 15 changed files with 536 additions and 267 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,13 @@ public class ActionConstants {
public final static String SUCCESS_FIELD = "success";

private final static String BASE_REST_PATH = "/_plugins/_ml/memory/conversation";
private final static String BASE_REST_INTERACTION_PATH = "/_plugins/_ml/memory/interaction";
/** path for create conversation */
public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/_create";
/** path for list conversations */
public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_list";
/** path for update conversations */
public final static String UPDATE_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_update";
/** path for put interaction */
public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create";
/** path for get interactions */
Expand All @@ -70,6 +73,8 @@ public class ActionConstants {
public final static String SEARCH_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_search";
/** path for search interactions */
public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_search";
/** path for update interactions */
public final static String UPDATE_INTERACTIONS_REST_PATH = BASE_REST_INTERACTION_PATH + "/{interaction_id}/_update";

/** default max results returned by get operations */
public final static int DEFAULT_MAX_RESULTS = 10;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import org.opensearch.action.ActionType;
import org.opensearch.action.update.UpdateResponse;

public class UpdateConversationAction extends ActionType<UpdateResponse> {
public static final UpdateConversationAction INSTANCE = new UpdateConversationAction();
public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/update";

private UpdateConversationAction() {
super(NAME, UpdateResponse::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;

import lombok.Builder;
import lombok.Getter;

@Getter
public class UpdateConversationRequest extends ActionRequest {
String conversationId;
Map<String, Object> updateContent;

@Builder
public UpdateConversationRequest(String conversationId, Map<String, Object> updateContent) {
this.conversationId = conversationId;
this.updateContent = updateContent;
}

public UpdateConversationRequest(StreamInput in) throws IOException {
super(in);
this.conversationId = in.readString();
this.updateContent = in.readMap();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.conversationId);
out.writeMap(this.getUpdateContent());
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;

if (this.conversationId == null) {
exception = addValidationError("conversation id can't be null", exception);
}

return exception;
}

public static UpdateConversationRequest parse(XContentParser parser, String conversationId) throws IOException {
Map<String, Object> dataAsMap = null;
dataAsMap = parser.map();

return UpdateConversationRequest.builder().conversationId(conversationId).updateContent(dataAsMap).build();
}

public static UpdateConversationRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof UpdateConversationRequest) {
return (UpdateConversationRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new UpdateConversationRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into UpdateConversationRequest", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class UpdateConversationTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
Client client;

@Inject
public UpdateConversationTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new);
this.client = client;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request);
String conversationId = updateConversationRequest.getConversationId();
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId);
updateRequest.doc(updateConversationRequest.getUpdateContent());
updateRequest.docAsUpsert(true);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.update(updateRequest, getUpdateResponseListener(conversationId, listener, context));
} catch (Exception e) {
log.error("Failed to update Conversation for conversation id {}. Details {}:", conversationId, e);
listener.onFailure(e);
}
}

private ActionListener<UpdateResponse> getUpdateResponseListener(
String conversationId,
ActionListener<UpdateResponse> actionListener,
ThreadContext.StoredContext context
) {
return ActionListener.runBefore(ActionListener.wrap(updateResponse -> {
if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) {
log.info("Failed to update the Conversation with ID: {}", conversationId);
actionListener.onResponse(updateResponse);
return;
}
log.info("Successfully updated the Conversation with ID: {}", conversationId);
actionListener.onResponse(updateResponse);
}, exception -> {
log.error("Failed to update ML Conversation with ID {}. Details: {}", conversationId, exception);
actionListener.onFailure(exception);
}), context::restore);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import org.opensearch.action.ActionType;
import org.opensearch.action.update.UpdateResponse;

public class UpdateInteractionAction extends ActionType<UpdateResponse> {
public static final UpdateInteractionAction INSTANCE = new UpdateInteractionAction();
public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/update";

private UpdateInteractionAction() {
super(NAME, UpdateResponse::new);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;

import lombok.Builder;
import lombok.Getter;

@Getter
public class UpdateInteractionRequest extends ActionRequest {
String interactionId;
Map<String, Object> updateContent;

@Builder
public UpdateInteractionRequest(String interactionId, Map<String, Object> updateContent) {
this.interactionId = interactionId;
this.updateContent = updateContent;
}

public UpdateInteractionRequest(StreamInput in) throws IOException {
super(in);
this.interactionId = in.readString();
this.updateContent = in.readMap();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.interactionId);
out.writeMap(this.getUpdateContent());
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;

if (this.interactionId == null) {
exception = addValidationError("interaction id can't be null", exception);
}

return exception;
}

public static UpdateInteractionRequest parse(XContentParser parser, String interactionId) throws IOException {
Map<String, Object> dataAsMap = null;
dataAsMap = parser.map();

return UpdateInteractionRequest.builder().interactionId(interactionId).updateContent(dataAsMap).build();
}

public static UpdateInteractionRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof UpdateInteractionRequest) {
return (UpdateInteractionRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new UpdateInteractionRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into UpdateInteractionRequest", e);
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class UpdateInteractionTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
Client client;

@Inject
public UpdateInteractionTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
super(UpdateInteractionAction.NAME, transportService, actionFilters, UpdateInteractionRequest::new);
this.client = client;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.fromActionRequest(request);
String interactionId = updateInteractionRequest.getInteractionId();
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME, interactionId);
updateRequest.doc(updateInteractionRequest.getUpdateContent());
updateRequest.docAsUpsert(true);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.update(updateRequest, getUpdateResponseListener(interactionId, listener, context));
} catch (Exception e) {
log.error("Failed to update Interaction for interaction id {}. Details {}:", interactionId, e);
listener.onFailure(e);
}
}

private ActionListener<UpdateResponse> getUpdateResponseListener(
String interactionId,
ActionListener<UpdateResponse> actionListener,
ThreadContext.StoredContext context
) {
return ActionListener.runBefore(ActionListener.wrap(updateResponse -> {
if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) {
log.info("Failed to update the interaction with ID: {}", interactionId);
actionListener.onResponse(updateResponse);
return;
}
log.info("Successfully updated the interaction with ID: {}", interactionId);
actionListener.onResponse(updateResponse);
}, exception -> {
log.error("Failed to update ML interaction with ID {}. Details: {}", interactionId, exception);
actionListener.onFailure(exception);
}), context::restore);
}
}
Loading

0 comments on commit 814cd77

Please sign in to comment.