diff --git a/src/main/java/org/opensearch/sdk/ExtensionRestHandler.java b/src/main/java/org/opensearch/sdk/ExtensionRestHandler.java index 9b2035a9..e02203eb 100644 --- a/src/main/java/org/opensearch/sdk/ExtensionRestHandler.java +++ b/src/main/java/org/opensearch/sdk/ExtensionRestHandler.java @@ -14,7 +14,6 @@ import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; -import org.opensearch.rest.RestResponse; /** * This interface defines methods which an extension REST handler (action) must provide. @@ -31,10 +30,11 @@ public interface ExtensionRestHandler { * Handles REST Requests forwarded from OpenSearch for a configured route on an extension. * Parameters are components of the {@link RestRequest} received from a user. * This method corresponds to the {@link BaseRestHandler#prepareRequest} method. + * As in that method, consumed parameters must be tracked and returned in the response. * * @param method A REST method. * @param uri The URI to handle. - * @return A {@link RestResponse} to the request. + * @return An {@link ExtensionRestResponse} to the request. */ - RestResponse handleRequest(Method method, String uri); + ExtensionRestResponse handleRequest(Method method, String uri); } diff --git a/src/main/java/org/opensearch/sdk/ExtensionRestResponse.java b/src/main/java/org/opensearch/sdk/ExtensionRestResponse.java new file mode 100644 index 00000000..c623c3ec --- /dev/null +++ b/src/main/java/org/opensearch/sdk/ExtensionRestResponse.java @@ -0,0 +1,93 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.sdk; + +import java.util.List; + +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestStatus; + +/** + * A subclass of {@link BytesRestResponse} which processes the consumed parameters into a custom header. + */ +public class ExtensionRestResponse extends BytesRestResponse { + + /** + * Key passed in {@link BytesRestResponse} headers to identify parameters consumed by the handler. For internal use. + */ + static final String CONSUMED_PARAMS_KEY = "extension.consumed.parameters"; + + /** + * Creates a new response based on {@link XContentBuilder}. + * + * @param status The REST status. + * @param builder The builder for the response. + * @param consumedParams Parameters consumed by the handler. + */ + public ExtensionRestResponse(RestStatus status, XContentBuilder builder, List consumedParams) { + super(status, builder); + addConsumedParamHeader(consumedParams); + } + + /** + * Creates a new plain text response. + * + * @param status The REST status. + * @param content A plain text response string. + * @param consumedParams Parameters consumed by the handler. + */ + public ExtensionRestResponse(RestStatus status, String content, List consumedParams) { + super(status, content); + addConsumedParamHeader(consumedParams); + } + + /** + * Creates a new plain text response. + * + * @param status The REST status. + * @param contentType The content type of the response string. + * @param content A response string. + * @param consumedParams Parameters consumed by the handler. + */ + public ExtensionRestResponse(RestStatus status, String contentType, String content, List consumedParams) { + super(status, contentType, content); + addConsumedParamHeader(consumedParams); + } + + /** + * Creates a binary response. + * + * @param status The REST status. + * @param contentType The content type of the response bytes. + * @param content Response bytes. + * @param consumedParams Parameters consumed by the handler. + */ + public ExtensionRestResponse(RestStatus status, String contentType, byte[] content, List consumedParams) { + super(status, contentType, content); + addConsumedParamHeader(consumedParams); + } + + /** + * Creates a binary response. + * + * @param status The REST status. + * @param contentType The content type of the response bytes. + * @param content Response bytes. + * @param consumedParams Parameters consumed by the handler. + */ + public ExtensionRestResponse(RestStatus status, String contentType, BytesReference content, List consumedParams) { + super(status, contentType, content); + addConsumedParamHeader(consumedParams); + } + + private void addConsumedParamHeader(List consumedParams) { + consumedParams.stream().forEach(p -> addHeader(CONSUMED_PARAMS_KEY, p)); + } +} diff --git a/src/main/java/org/opensearch/sdk/sample/helloworld/rest/RestHelloAction.java b/src/main/java/org/opensearch/sdk/sample/helloworld/rest/RestHelloAction.java index 320864a3..cc10ee23 100644 --- a/src/main/java/org/opensearch/sdk/sample/helloworld/rest/RestHelloAction.java +++ b/src/main/java/org/opensearch/sdk/sample/helloworld/rest/RestHelloAction.java @@ -7,14 +7,14 @@ */ package org.opensearch.sdk.sample.helloworld.rest; -import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest.Method; -import org.opensearch.rest.RestResponse; import org.opensearch.sdk.ExtensionRestHandler; +import org.opensearch.sdk.ExtensionRestResponse; import java.net.URLDecoder; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.List; import static org.opensearch.rest.RestRequest.Method.GET; @@ -37,19 +37,29 @@ public List routes() { } @Override - public RestResponse handleRequest(Method method, String uri) { + public ExtensionRestResponse handleRequest(Method method, String uri) { + // We need to track which parameters are consumed to pass back to OpenSearch + List consumedParams = new ArrayList<>(); if (Method.GET.equals(method) && "/hello".equals(uri)) { - return new BytesRestResponse(OK, String.format(GREETING, worldName)); + return new ExtensionRestResponse(OK, String.format(GREETING, worldName), consumedParams); } else if (Method.PUT.equals(method) && uri.startsWith("/hello/")) { + // Placeholder code here for parameters in named wildcard paths + // Full implementation based on params() will be implemented as part of + // https://github.com/opensearch-project/opensearch-sdk-java/issues/111 String name = uri.substring("/hello/".length()); + consumedParams.add("name"); try { worldName = URLDecoder.decode(name, StandardCharsets.UTF_8); } catch (IllegalArgumentException e) { - return new BytesRestResponse(BAD_REQUEST, e.getMessage()); + return new ExtensionRestResponse(BAD_REQUEST, e.getMessage(), consumedParams); } - return new BytesRestResponse(OK, "Updated the world's name to " + worldName); + return new ExtensionRestResponse(OK, "Updated the world's name to " + worldName, consumedParams); } - return new BytesRestResponse(NOT_FOUND, "Extension REST action improperly configured to handle " + method.name() + " " + uri); + return new ExtensionRestResponse( + NOT_FOUND, + "Extension REST action improperly configured to handle " + method.name() + " " + uri, + consumedParams + ); } } diff --git a/src/test/java/org/opensearch/sdk/TestExtensionRestPathRegistry.java b/src/test/java/org/opensearch/sdk/TestExtensionRestPathRegistry.java index 8939ff2e..17e2a7e1 100644 --- a/src/test/java/org/opensearch/sdk/TestExtensionRestPathRegistry.java +++ b/src/test/java/org/opensearch/sdk/TestExtensionRestPathRegistry.java @@ -6,7 +6,6 @@ import org.junit.jupiter.api.Test; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest.Method; -import org.opensearch.rest.RestResponse; import org.opensearch.test.OpenSearchTestCase; public class TestExtensionRestPathRegistry extends OpenSearchTestCase { @@ -20,7 +19,7 @@ public List routes() { } @Override - public RestResponse handleRequest(Method method, String uri) { + public ExtensionRestResponse handleRequest(Method method, String uri) { return null; } }; @@ -31,7 +30,7 @@ public List routes() { } @Override - public RestResponse handleRequest(Method method, String uri) { + public ExtensionRestResponse handleRequest(Method method, String uri) { return null; } }; @@ -42,7 +41,7 @@ public List routes() { } @Override - public RestResponse handleRequest(Method method, String uri) { + public ExtensionRestResponse handleRequest(Method method, String uri) { return null; } }; diff --git a/src/test/java/org/opensearch/sdk/TestExtensionRestResponse.java b/src/test/java/org/opensearch/sdk/TestExtensionRestResponse.java new file mode 100644 index 00000000..e11aed5e --- /dev/null +++ b/src/test/java/org/opensearch/sdk/TestExtensionRestResponse.java @@ -0,0 +1,111 @@ +package org.opensearch.sdk; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.test.OpenSearchTestCase; + +import static org.opensearch.rest.BytesRestResponse.TEXT_CONTENT_TYPE; +import static org.opensearch.rest.RestStatus.ACCEPTED; +import static org.opensearch.rest.RestStatus.OK; +import static org.opensearch.sdk.ExtensionRestResponse.CONSUMED_PARAMS_KEY; + +public class TestExtensionRestResponse extends OpenSearchTestCase { + + private static final String OCTET_CONTENT_TYPE = "application/octet-stream"; + private static final String JSON_CONTENT_TYPE = "application/json; charset=UTF-8"; + + private String testText; + private byte[] testBytes; + private List testConsumedParams; + + @Override + @BeforeEach + public void setUp() throws Exception { + super.setUp(); + testText = "plain text"; + testBytes = new byte[] { 1, 2 }; + testConsumedParams = List.of("foo", "bar"); + } + + @Test + public void testConstructorWithBuilder() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + builder.startObject(); + builder.field("status", ACCEPTED); + builder.endObject(); + ExtensionRestResponse response = new ExtensionRestResponse(OK, builder, testConsumedParams); + + assertEquals(OK, response.status()); + assertEquals(JSON_CONTENT_TYPE, response.contentType()); + assertEquals("{\"status\":\"ACCEPTED\"}", response.content().utf8ToString()); + List consumedParams = response.getHeaders().get(CONSUMED_PARAMS_KEY); + for (String param : consumedParams) { + assertTrue(testConsumedParams.contains(param)); + } + } + + @Test + public void testConstructorWithPlainText() { + ExtensionRestResponse response = new ExtensionRestResponse(OK, testText, testConsumedParams); + + assertEquals(OK, response.status()); + assertEquals(TEXT_CONTENT_TYPE, response.contentType()); + assertEquals(testText, response.content().utf8ToString()); + List consumedParams = response.getHeaders().get(CONSUMED_PARAMS_KEY); + for (String param : consumedParams) { + assertTrue(testConsumedParams.contains(param)); + } + } + + @Test + public void testConstructorWithText() { + ExtensionRestResponse response = new ExtensionRestResponse(OK, TEXT_CONTENT_TYPE, testText, testConsumedParams); + + assertEquals(OK, response.status()); + assertEquals(TEXT_CONTENT_TYPE, response.contentType()); + assertEquals(testText, response.content().utf8ToString()); + + List consumedParams = response.getHeaders().get(CONSUMED_PARAMS_KEY); + for (String param : consumedParams) { + assertTrue(testConsumedParams.contains(param)); + } + } + + @Test + public void testConstructorWithByteArray() { + ExtensionRestResponse response = new ExtensionRestResponse(OK, OCTET_CONTENT_TYPE, testBytes, testConsumedParams); + + assertEquals(OK, response.status()); + assertEquals(OCTET_CONTENT_TYPE, response.contentType()); + assertArrayEquals(testBytes, BytesReference.toBytes(response.content())); + List consumedParams = response.getHeaders().get(CONSUMED_PARAMS_KEY); + for (String param : consumedParams) { + assertTrue(testConsumedParams.contains(param)); + } + } + + @Test + public void testConstructorWithBytesReference() { + ExtensionRestResponse response = new ExtensionRestResponse( + OK, + OCTET_CONTENT_TYPE, + BytesReference.fromByteBuffer(ByteBuffer.wrap(testBytes, 0, 2)), + testConsumedParams + ); + + assertEquals(OK, response.status()); + assertEquals(OCTET_CONTENT_TYPE, response.contentType()); + assertArrayEquals(testBytes, BytesReference.toBytes(response.content())); + List consumedParams = response.getHeaders().get(CONSUMED_PARAMS_KEY); + for (String param : consumedParams) { + assertTrue(testConsumedParams.contains(param)); + } + } +}