diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java index 71fd19991e..20f56a21cb 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java @@ -179,7 +179,9 @@ private FunctionBuilder getFunctionBuilder( List sourceTypes = functionSignature.getParamTypeList(); List targetTypes = resolvedSignature.getKey().getParamTypeList(); FunctionBuilder funcBuilder = resolvedSignature.getValue(); - if (isCastFunction(functionName) || sourceTypes.equals(targetTypes)) { + if (isCastFunction(functionName) + || FunctionSignature.isVarArgFunction(targetTypes) + || sourceTypes.equals(targetTypes)) { return funcBuilder; } return castArguments(sourceTypes, diff --git a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java index 7081179162..a28fa7e0ad 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java @@ -50,7 +50,15 @@ public Pair resolve(FunctionSignature unreso functionSignature)); } Map.Entry bestMatchEntry = functionMatchQueue.peek(); - if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey())) { + if (FunctionSignature.isVarArgFunction(bestMatchEntry.getValue().getParamTypeList()) + && (unresolvedSignature.getParamTypeList().isEmpty() + || unresolvedSignature.getParamTypeList().size() > 9)) { + throw new ExpressionEvaluationException( + String.format("%s function expected 1-9 arguments, but got %d", + functionName, unresolvedSignature.getParamTypeList().size())); + } + if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey()) + && !FunctionSignature.isVarArgFunction(bestMatchEntry.getValue().getParamTypeList())) { throw new ExpressionEvaluationException( String.format("%s function expected %s, but get %s", functionName, formatFunctions(functionBundle.keySet()), diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java index adb1698386..0c59d71c25 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java @@ -5,6 +5,8 @@ package org.opensearch.sql.expression.function; +import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; + import java.util.List; import java.util.stream.Collectors; import lombok.EqualsAndHashCode; @@ -39,6 +41,10 @@ public int match(FunctionSignature functionSignature) { || paramTypeList.size() != functionTypeList.size()) { return NOT_MATCH; } + // TODO: improve to support regular and array type mixed, ex. func(int,string,array) + if (isVarArgFunction(functionTypeList)) { + return EXACTLY_MATCH; + } int matchDegree = EXACTLY_MATCH; for (int i = 0; i < paramTypeList.size(); i++) { @@ -62,4 +68,11 @@ public String formatTypes() { .map(ExprType::typeName) .collect(Collectors.joining(",", "[", "]")); } + + /** + * util function - returns true if function has variable arguments. + */ + protected static boolean isVarArgFunction(List argTypes) { + return argTypes.size() == 1 && argTypes.get(0) == ARRAY; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java index 25eb25489c..e56c85a0c8 100644 --- a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java @@ -6,24 +6,34 @@ package org.opensearch.sql.expression.text; +import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.function.FunctionDSL.define; import static org.opensearch.sql.expression.function.FunctionDSL.impl; import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; import lombok.experimental.UtilityClass; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionSignature; import org.opensearch.sql.expression.function.SerializableBiFunction; import org.opensearch.sql.expression.function.SerializableTriFunction; - /** * The definition of text functions. * 1) have the clear interface for function define. @@ -141,16 +151,37 @@ private DefaultFunctionResolver upper() { } /** - * TODO: https://github.com/opendistro-for-elasticsearch/sql/issues/710 - * Extend to accept variable argument amounts. * Concatenates a list of Strings. * Supports following signatures: - * (STRING, STRING) -> STRING + * (STRING, STRING, ...., STRING) -> STRING */ private DefaultFunctionResolver concat() { - return define(BuiltinFunctionName.CONCAT.getName(), - impl(nullMissingHandling((str1, str2) -> - new ExprStringValue(str1.stringValue() + str2.stringValue())), STRING, STRING, STRING)); + FunctionName concatFuncName = BuiltinFunctionName.CONCAT.getName(); + return define(concatFuncName, funcName -> + Pair.of( + new FunctionSignature(concatFuncName, Collections.singletonList(ARRAY)), + (funcProp, args) -> new FunctionExpression(funcName, args) { + @Override + public ExprValue valueOf(Environment valueEnv) { + List exprValues = args.stream() + .map(arg -> arg.valueOf(valueEnv)).collect(Collectors.toList()); + if (exprValues.stream().anyMatch(ExprValue::isMissing)) { + return ExprValueUtils.missingValue(); + } + if (exprValues.stream().anyMatch(ExprValue::isNull)) { + return ExprValueUtils.nullValue(); + } + return new ExprStringValue(exprValues.stream() + .map(ExprValue::stringValue) + .collect(Collectors.joining())); + } + + @Override + public ExprType type() { + return STRING; + } + } + )); } /** diff --git a/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java index baa299b60b..202c1bd0aa 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java @@ -9,8 +9,12 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.Collections; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; @@ -76,4 +80,53 @@ void resolve_function_not_match() { assertEquals("add function expected {[INTEGER,INTEGER]}, but get [BOOLEAN,BOOLEAN]", exception.getMessage()); } + + @Test + void resolve_varargs_function_signature_match() { + functionName = FunctionName.of("concat"); + when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL); + when(functionSignature.getParamTypeList()).thenReturn(ImmutableList.of(STRING)); + when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY)); + + DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, + ImmutableMap.of(bestMatchFS, bestMatchBuilder)); + + assertEquals(bestMatchBuilder, resolver.resolve(functionSignature).getValue()); + } + + @Test + void resolve_varargs_no_args_function_signature_not_match() { + functionName = FunctionName.of("concat"); + when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL); + when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY)); + // Concat function with no arguments + when(functionSignature.getParamTypeList()).thenReturn(Collections.emptyList()); + + DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, + ImmutableMap.of(bestMatchFS, bestMatchBuilder)); + + ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, + () -> resolver.resolve(functionSignature)); + assertEquals("concat function expected 1-9 arguments, but got 0", + exception.getMessage()); + } + + @Test + void resolve_varargs_too_many_args_function_signature_not_match() { + functionName = FunctionName.of("concat"); + when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL); + when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY)); + // Concat function with more than 9 arguments + when(functionSignature.getParamTypeList()).thenReturn(ImmutableList + .of(STRING, STRING, STRING, STRING, STRING, + STRING, STRING, STRING, STRING, STRING)); + + DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, + ImmutableMap.of(bestMatchFS, bestMatchBuilder)); + + ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, + () -> resolver.resolve(functionSignature)); + assertEquals("concat function expected 1-9 arguments, but got 10", + exception.getMessage()); + } } diff --git a/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java index 515b436c82..54d2e5c400 100644 --- a/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java @@ -72,6 +72,9 @@ public class TextFunctionTest extends ExpressionTestBase { private static List> CONCAT_STRING_LISTS = ImmutableList.of( ImmutableList.of("hello", "world"), ImmutableList.of("123", "5325")); + private static List> CONCAT_STRING_LISTS_WITH_MANY_STRINGS = ImmutableList.of( + ImmutableList.of("he", "llo", "wo", "rld", "!"), + ImmutableList.of("0", "123", "53", "25", "7")); interface SubstrSubstring { FunctionExpression getFunction(SubstringInfo strInfo); @@ -228,11 +231,13 @@ public void upper() { @Test void concat() { CONCAT_STRING_LISTS.forEach(this::testConcatString); + CONCAT_STRING_LISTS_WITH_MANY_STRINGS.forEach(this::testConcatMultipleString); when(nullRef.type()).thenReturn(STRING); when(missingRef.type()).thenReturn(STRING); assertEquals(missingValue(), eval( DSL.concat(missingRef, DSL.literal("1")))); + // If any of the expressions is a NULL value, it returns NULL. assertEquals(nullValue(), eval( DSL.concat(nullRef, DSL.literal("1")))); assertEquals(missingValue(), eval( @@ -446,6 +451,22 @@ void testConcatString(List strings, String delim) { assertEquals(expected, eval(expression).stringValue()); } + void testConcatMultipleString(List strings) { + String expected = null; + if (strings.stream().noneMatch(Objects::isNull)) { + expected = String.join("", strings); + } + + FunctionExpression expression = DSL.concat( + DSL.literal(strings.get(0)), + DSL.literal(strings.get(1)), + DSL.literal(strings.get(2)), + DSL.literal(strings.get(3)), + DSL.literal(strings.get(4))); + assertEquals(STRING, expression.type()); + assertEquals(expected, eval(expression).stringValue()); + } + void testLengthString(String str) { FunctionExpression expression = DSL.length(DSL.literal(new ExprStringValue(str))); assertEquals(INTEGER, expression.type()); diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index f433845bb3..ab96075ac3 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -2614,21 +2614,21 @@ CONCAT Description >>>>>>>>>>> -Usage: CONCAT(str1, str2) returns str1 and str strings concatenated together. +Usage: CONCAT(str1, str2, ...., str_9) adds up to 9 strings together. If any of the expressions is a NULL value, it returns NULL. -Argument type: STRING, STRING +Argument type: STRING, STRING, ...., STRING Return type: STRING Example:: - os> SELECT CONCAT('hello', 'world') + os> SELECT CONCAT('hello ', 'whole ', 'world', '!'), CONCAT('hello', 'world'), CONCAT('hello', null) fetched rows / total rows = 1/1 - +----------------------------+ - | CONCAT('hello', 'world') | - |----------------------------| - | helloworld | - +----------------------------+ + +--------------------------------------------+----------------------------+-------------------------+ + | CONCAT('hello ', 'whole ', 'world', '!') | CONCAT('hello', 'world') | CONCAT('hello', null) | + |--------------------------------------------+----------------------------+-------------------------| + | hello whole world! | helloworld | null | + +--------------------------------------------+----------------------------+-------------------------+ CONCAT_WS diff --git a/docs/user/ppl/functions/string.rst b/docs/user/ppl/functions/string.rst index 0503759cbd..9b7e69d985 100644 --- a/docs/user/ppl/functions/string.rst +++ b/docs/user/ppl/functions/string.rst @@ -14,21 +14,21 @@ CONCAT Description >>>>>>>>>>> -Usage: CONCAT(str1, str2) returns str1 and str strings concatenated together. +Usage: CONCAT(str1, str2, ...., str_9) adds up to 9 strings together. -Argument type: STRING, STRING +Argument type: STRING, STRING, ...., STRING Return type: STRING Example:: - os> source=people | eval `CONCAT('hello', 'world')` = CONCAT('hello', 'world') | fields `CONCAT('hello', 'world')` + os> source=people | eval `CONCAT('hello', 'world')` = CONCAT('hello', 'world'), `CONCAT('hello ', 'whole ', 'world', '!')` = CONCAT('hello ', 'whole ', 'world', '!') | fields `CONCAT('hello', 'world')`, `CONCAT('hello ', 'whole ', 'world', '!')` fetched rows / total rows = 1/1 - +----------------------------+ - | CONCAT('hello', 'world') | - |----------------------------| - | helloworld | - +----------------------------+ + +----------------------------+--------------------------------------------+ + | CONCAT('hello', 'world') | CONCAT('hello ', 'whole ', 'world', '!') | + |----------------------------+--------------------------------------------| + | helloworld | hello whole world! | + +----------------------------+--------------------------------------------+ CONCAT_WS diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/TextFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/TextFunctionIT.java index 7c48bceab0..024f190bee 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/TextFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/TextFunctionIT.java @@ -99,8 +99,8 @@ public void testLtrim() throws IOException { @Test public void testConcat() throws IOException { - verifyQuery("concat", "", ", 'there'", - "hellothere", "worldthere", "helloworldthere"); + verifyQuery("concat", "", ", 'there', 'all', '!'", + "hellothereall!", "worldthereall!", "helloworldthereall!"); } @Test diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/TextFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/TextFunctionIT.java index 175cafd31e..94677354e4 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/TextFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/TextFunctionIT.java @@ -108,6 +108,7 @@ public void testLtrim() throws IOException { @Test public void testConcat() throws IOException { + verifyQuery("concat('hello', 'whole', 'world', '!', '!')", "keyword", "hellowholeworld!!"); verifyQuery("concat('hello', 'world')", "keyword", "helloworld"); verifyQuery("concat('', 'hello')", "keyword", "hello"); } diff --git a/integ-test/src/test/resources/correctness/expressions/text_functions.txt b/integ-test/src/test/resources/correctness/expressions/text_functions.txt index c2fd57c330..077cc82084 100644 --- a/integ-test/src/test/resources/correctness/expressions/text_functions.txt +++ b/integ-test/src/test/resources/correctness/expressions/text_functions.txt @@ -11,4 +11,6 @@ LOCATE('world', 'helloworld') as column LOCATE('world', 'hello') as column LOCATE('world', 'helloworld', 7) as column REPLACE('helloworld', 'world', 'opensearch') as column -REPLACE('hello', 'world', 'opensearch') as column \ No newline at end of file +REPLACE('hello', 'world', 'opensearch') as column +CONCAT('hello', 'world') as column +CONCAT('hello ', 'whole ', 'world', '!') as column \ No newline at end of file