diff --git a/integ-test/src/test/java/org/opensearch/sql/jdbc/CursorIT.java b/integ-test/src/test/java/org/opensearch/sql/jdbc/CursorIT.java index 325c81107f..e2b6287191 100644 --- a/integ-test/src/test/java/org/opensearch/sql/jdbc/CursorIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/jdbc/CursorIT.java @@ -22,6 +22,7 @@ import java.sql.ResultSet; import java.sql.Statement; import java.util.List; +import java.util.Map; import javax.annotation.Nullable; import lombok.SneakyThrows; import org.json.JSONObject; @@ -115,6 +116,8 @@ public void select_all_no_cursor() { var restResponse = executeRestQuery(query, null); assertEquals(rows, restResponse.getInt("total")); + var restPrettyResponse = executeRestQuery(query, null, Map.of("pretty", "true")); + assertEquals(rows, restPrettyResponse.getInt("total")); } } @@ -133,6 +136,8 @@ public void select_count_all_no_cursor() { var restResponse = executeRestQuery(query, null); assertEquals(rows, restResponse.getInt("total")); + var restPrettyResponse = executeRestQuery(query, null, Map.of("pretty", "true")); + assertEquals(rows, restPrettyResponse.getInt("total")); } } @@ -151,6 +156,8 @@ public void select_all_small_table_big_cursor() { var restResponse = executeRestQuery(query, null); assertEquals(rows, restResponse.getInt("total")); + var restPrettyResponse = executeRestQuery(query, null, Map.of("pretty", "true")); + assertEquals(rows, restPrettyResponse.getInt("total")); } } @@ -169,6 +176,8 @@ public void select_all_small_table_small_cursor() { var restResponse = executeRestQuery(query, null); assertEquals(rows, restResponse.getInt("total")); + var restPrettyResponse = executeRestQuery(query, null, Map.of("pretty", "true")); + assertEquals(rows, restPrettyResponse.getInt("total")); } } @@ -187,6 +196,8 @@ public void select_all_big_table_small_cursor() { var restResponse = executeRestQuery(query, null); assertEquals(rows, restResponse.getInt("total")); + var restPrettyResponse = executeRestQuery(query, null, Map.of("pretty", "true")); + assertEquals(rows, restPrettyResponse.getInt("total")); } } @@ -205,6 +216,8 @@ public void select_all_big_table_big_cursor() { var restResponse = executeRestQuery(query, null); assertEquals(rows, restResponse.getInt("total")); + var restPrettyResponse = executeRestQuery(query, null, Map.of("pretty", "true")); + assertEquals(rows, restPrettyResponse.getInt("total")); } } @@ -217,6 +230,12 @@ private static String getConnectionString() { @SneakyThrows protected JSONObject executeRestQuery(String query, @Nullable Integer fetch_size) { + return executeRestQuery(query, fetch_size, Map.of()); + } + + @SneakyThrows + protected JSONObject executeRestQuery( + String query, @Nullable Integer fetch_size, Map params) { Request request = new Request("POST", QUERY_API_ENDPOINT); if (fetch_size != null) { request.setJsonEntity( @@ -224,6 +243,7 @@ protected JSONObject executeRestQuery(String query, @Nullable Integer fetch_size } else { request.setJsonEntity(String.format("{ \"query\": \"%s\" }", query)); } + request.addParameters(params); RequestOptions.Builder restOptionsBuilder = RequestOptions.DEFAULT.toBuilder(); restOptionsBuilder.addHeader("Content-Type", "application/json"); diff --git a/sql/src/main/java/org/opensearch/sql/sql/domain/SQLQueryRequest.java b/sql/src/main/java/org/opensearch/sql/sql/domain/SQLQueryRequest.java index 4e902cb67d..1d17610fb4 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/domain/SQLQueryRequest.java +++ b/sql/src/main/java/org/opensearch/sql/sql/domain/SQLQueryRequest.java @@ -10,6 +10,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Predicate; import java.util.stream.Stream; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -29,6 +30,7 @@ public class SQLQueryRequest { Set.of("query", "fetch_size", "parameters", QUERY_FIELD_CURSOR); private static final String QUERY_PARAMS_FORMAT = "format"; private static final String QUERY_PARAMS_SANITIZE = "sanitize"; + private static final String QUERY_PARAMS_PRETTY = "pretty"; /** JSON payload in REST request. */ private final JSONObject jsonContent; @@ -79,19 +81,21 @@ public SQLQueryRequest( * @return true if supported. */ public boolean isSupported() { - var noCursor = !isCursor(); - var noQuery = query == null; - var noUnsupportedParams = - params.isEmpty() || (params.size() == 1 && params.containsKey(QUERY_PARAMS_FORMAT)); - var noContent = jsonContent == null || jsonContent.isEmpty(); - - return ((!noCursor - && noQuery - && noUnsupportedParams - && noContent) // if cursor is given, but other things - || (noCursor && !noQuery)) // or if cursor is not given, but query - && isOnlySupportedFieldInPayload() // and request has supported fields only - && isSupportedFormat(); // and request is in supported format + boolean hasCursor = isCursor(); + boolean hasQuery = query != null; + boolean hasContent = jsonContent != null && !jsonContent.isEmpty(); + + Predicate supportedParams = Set.of(QUERY_PARAMS_FORMAT, QUERY_PARAMS_PRETTY)::contains; + boolean hasUnsupportedParams = + (!params.isEmpty()) + && params.keySet().stream().dropWhile(supportedParams).findAny().isPresent(); + + boolean validCursor = hasCursor && !hasQuery && !hasUnsupportedParams && !hasContent; + boolean validQuery = !hasCursor && hasQuery; + + return (validCursor || validQuery) // It's a valid cursor or a valid query + && isOnlySupportedFieldInPayload() // and request must contain supported fields only + && isSupportedFormat(); // and request must be a supported format } private boolean isCursor() { diff --git a/sql/src/test/java/org/opensearch/sql/sql/domain/SQLQueryRequestTest.java b/sql/src/test/java/org/opensearch/sql/sql/domain/SQLQueryRequestTest.java index 2b64b13b35..b569a89a2e 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/domain/SQLQueryRequestTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/domain/SQLQueryRequestTest.java @@ -106,6 +106,42 @@ public void should_support_cursor_request() { () -> assertTrue(cursorRequest.isSupported())); } + @Test + public void should_support_cursor_request_with_supported_parameters() { + SQLQueryRequest fetchSizeRequest = + SQLQueryRequestBuilder.request("SELECT 1") + .jsonContent("{\"query\": \"SELECT 1\", \"fetch_size\": 5}") + .build(); + + SQLQueryRequest cursorRequest = + SQLQueryRequestBuilder.request(null) + .cursor("abcdefgh...") + .params(Map.of("format", "csv", "pretty", "true")) + .build(); + + assertAll( + () -> assertTrue(fetchSizeRequest.isSupported()), + () -> assertTrue(cursorRequest.isSupported())); + } + + @Test + public void should_not_support_cursor_request_with_unsupported_parameters() { + SQLQueryRequest fetchSizeRequest = + SQLQueryRequestBuilder.request("SELECT 1") + .jsonContent("{\"query\": \"SELECT 1\", \"fetch_size\": 5}") + .build(); + + SQLQueryRequest cursorRequest = + SQLQueryRequestBuilder.request(null) + .cursor("abcdefgh...") + .params(Map.of("one", "two")) + .build(); + + assertAll( + () -> assertTrue(fetchSizeRequest.isSupported()), + () -> assertFalse(cursorRequest.isSupported())); + } + @Test public void should_support_cursor_close_request() { SQLQueryRequest closeRequest =