Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable concat() string function to support multiple string arguments #1279

Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ private FunctionBuilder getFunctionBuilder(
List<ExprType> sourceTypes = functionSignature.getParamTypeList();
List<ExprType> targetTypes = resolvedSignature.getKey().getParamTypeList();
FunctionBuilder funcBuilder = resolvedSignature.getValue();
if (isCastFunction(functionName) || sourceTypes.equals(targetTypes)) {
if (isCastFunction(functionName)
|| FunctionSignature.isVarArgFunction(targetTypes)
|| sourceTypes.equals(targetTypes)) {
return funcBuilder;
}
return castArguments(sourceTypes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,15 @@ public Pair<FunctionSignature, FunctionBuilder> resolve(FunctionSignature unreso
functionSignature));
}
Map.Entry<Integer, FunctionSignature> bestMatchEntry = functionMatchQueue.peek();
if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey())) {
if (FunctionSignature.isVarArgFunction(bestMatchEntry.getValue().getParamTypeList())
&& (unresolvedSignature.getParamTypeList().isEmpty()
|| unresolvedSignature.getParamTypeList().size() > 9)) {
throw new ExpressionEvaluationException(
String.format("%s function expected 1-9 arguments, but got %d",
functionName, unresolvedSignature.getParamTypeList().size()));
}
if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey())
&& !FunctionSignature.isVarArgFunction(bestMatchEntry.getValue().getParamTypeList())) {
throw new ExpressionEvaluationException(
String.format("%s function expected %s, but get %s", functionName,
formatFunctions(functionBundle.keySet()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.sql.expression.function;

import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;

import java.util.List;
import java.util.stream.Collectors;
import lombok.EqualsAndHashCode;
Expand Down Expand Up @@ -39,6 +41,10 @@ public int match(FunctionSignature functionSignature) {
|| paramTypeList.size() != functionTypeList.size()) {
return NOT_MATCH;
}
// TODO: improve to support regular and array type mixed, ex. func(int,string,array)
if (isVarArgFunction(functionTypeList)) {
return EXACTLY_MATCH;
}

int matchDegree = EXACTLY_MATCH;
for (int i = 0; i < paramTypeList.size(); i++) {
Expand All @@ -62,4 +68,11 @@ public String formatTypes() {
.map(ExprType::typeName)
.collect(Collectors.joining(",", "[", "]"));
}

/**
* util function - returns true if function has variable arguments.
*/
protected static boolean isVarArgFunction(List<ExprType> argTypes) {
return argTypes.size() == 1 && argTypes.get(0) == ARRAY;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,34 @@

package org.opensearch.sql.expression.text;

import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.expression.function.FunctionDSL.define;
import static org.opensearch.sql.expression.function.FunctionDSL.impl;
import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import lombok.experimental.UtilityClass;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.data.model.ExprIntegerValue;
import org.opensearch.sql.data.model.ExprStringValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.env.Environment;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.DefaultFunctionResolver;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.FunctionSignature;
import org.opensearch.sql.expression.function.SerializableBiFunction;
import org.opensearch.sql.expression.function.SerializableTriFunction;


/**
* The definition of text functions.
* 1) have the clear interface for function define.
Expand Down Expand Up @@ -141,16 +151,37 @@ private DefaultFunctionResolver upper() {
}

/**
* TODO: https://github.com/opendistro-for-elasticsearch/sql/issues/710
* Extend to accept variable argument amounts.
* Concatenates a list of Strings.
* Supports following signatures:
* (STRING, STRING) -> STRING
* (STRING, STRING, ...., STRING) -> STRING
*/
private DefaultFunctionResolver concat() {
return define(BuiltinFunctionName.CONCAT.getName(),
impl(nullMissingHandling((str1, str2) ->
new ExprStringValue(str1.stringValue() + str2.stringValue())), STRING, STRING, STRING));
FunctionName concatFuncName = BuiltinFunctionName.CONCAT.getName();
return define(concatFuncName, funcName ->
Pair.of(
new FunctionSignature(concatFuncName, Collections.singletonList(ARRAY)),
Comment on lines +160 to +162
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @dai-chen for the idea!

(funcProp, args) -> new FunctionExpression(funcName, args) {
@Override
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
List<ExprValue> exprValues = args.stream()
.map(arg -> arg.valueOf(valueEnv)).collect(Collectors.toList());
if (exprValues.stream().anyMatch(ExprValue::isMissing)) {
return ExprValueUtils.missingValue();
}
if (exprValues.stream().anyMatch(ExprValue::isNull)) {
return ExprValueUtils.nullValue();
}
return new ExprStringValue(exprValues.stream()
.map(ExprValue::stringValue)
.collect(Collectors.joining()));
}

@Override
public ExprType type() {
return STRING;
}
}
));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.when;
import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Collections;
import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -76,4 +80,53 @@ void resolve_function_not_match() {
assertEquals("add function expected {[INTEGER,INTEGER]}, but get [BOOLEAN,BOOLEAN]",
exception.getMessage());
}

@Test
void resolve_varargs_function_signature_match() {
functionName = FunctionName.of("concat");
when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL);
when(functionSignature.getParamTypeList()).thenReturn(ImmutableList.of(STRING));
when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY));

DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName,
ImmutableMap.of(bestMatchFS, bestMatchBuilder));

assertEquals(bestMatchBuilder, resolver.resolve(functionSignature).getValue());
}

@Test
void resolve_varargs_no_args_function_signature_not_match() {
functionName = FunctionName.of("concat");
when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL);
when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY));
// Concat function with no arguments
when(functionSignature.getParamTypeList()).thenReturn(Collections.emptyList());

DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName,
ImmutableMap.of(bestMatchFS, bestMatchBuilder));

ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class,
() -> resolver.resolve(functionSignature));
assertEquals("concat function expected 1-9 arguments, but got 0",
exception.getMessage());
}

@Test
void resolve_varargs_too_many_args_function_signature_not_match() {
functionName = FunctionName.of("concat");
when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL);
when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY));
// Concat function with more than 9 arguments
when(functionSignature.getParamTypeList()).thenReturn(ImmutableList
.of(STRING, STRING, STRING, STRING, STRING,
STRING, STRING, STRING, STRING, STRING));

DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName,
ImmutableMap.of(bestMatchFS, bestMatchBuilder));

ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class,
() -> resolver.resolve(functionSignature));
assertEquals("concat function expected 1-9 arguments, but got 10",
exception.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ public class TextFunctionTest extends ExpressionTestBase {
private static List<List<String>> CONCAT_STRING_LISTS = ImmutableList.of(
ImmutableList.of("hello", "world"),
ImmutableList.of("123", "5325"));
private static List<List<String>> CONCAT_STRING_LISTS_WITH_MANY_STRINGS = ImmutableList.of(
ImmutableList.of("he", "llo", "wo", "rld", "!"),
ImmutableList.of("0", "123", "53", "25", "7"));

interface SubstrSubstring {
FunctionExpression getFunction(SubstringInfo strInfo);
Expand Down Expand Up @@ -228,11 +231,13 @@ public void upper() {
@Test
void concat() {
CONCAT_STRING_LISTS.forEach(this::testConcatString);
CONCAT_STRING_LISTS_WITH_MANY_STRINGS.forEach(this::testConcatMultipleString);

when(nullRef.type()).thenReturn(STRING);
when(missingRef.type()).thenReturn(STRING);
assertEquals(missingValue(), eval(
DSL.concat(missingRef, DSL.literal("1"))));
// If any of the expressions is a NULL value, it returns NULL.
assertEquals(nullValue(), eval(
DSL.concat(nullRef, DSL.literal("1"))));
assertEquals(missingValue(), eval(
Expand Down Expand Up @@ -446,6 +451,22 @@ void testConcatString(List<String> strings, String delim) {
assertEquals(expected, eval(expression).stringValue());
}

void testConcatMultipleString(List<String> strings) {
String expected = null;
if (strings.stream().noneMatch(Objects::isNull)) {
expected = String.join("", strings);
}

FunctionExpression expression = DSL.concat(
DSL.literal(strings.get(0)),
DSL.literal(strings.get(1)),
DSL.literal(strings.get(2)),
DSL.literal(strings.get(3)),
DSL.literal(strings.get(4)));
assertEquals(STRING, expression.type());
assertEquals(expected, eval(expression).stringValue());
}

void testLengthString(String str) {
FunctionExpression expression = DSL.length(DSL.literal(new ExprStringValue(str)));
assertEquals(INTEGER, expression.type());
Expand Down
16 changes: 8 additions & 8 deletions docs/user/dql/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2614,21 +2614,21 @@ CONCAT
Description
>>>>>>>>>>>

Usage: CONCAT(str1, str2) returns str1 and str strings concatenated together.
Usage: CONCAT(str1, str2, ...., str_9) adds up to 9 strings together. If any of the expressions is a NULL value, it returns NULL.

Argument type: STRING, STRING
Argument type: STRING, STRING, ...., STRING

Return type: STRING

Example::

os> SELECT CONCAT('hello', 'world')
os> SELECT CONCAT('hello ', 'whole ', 'world', '!'), CONCAT('hello', 'world'), CONCAT('hello', null)
fetched rows / total rows = 1/1
+----------------------------+
| CONCAT('hello', 'world') |
|----------------------------|
| helloworld |
+----------------------------+
+--------------------------------------------+----------------------------+-------------------------+
| CONCAT('hello ', 'whole ', 'world', '!') | CONCAT('hello', 'world') | CONCAT('hello', null) |
|--------------------------------------------+----------------------------+-------------------------|
| hello whole world! | helloworld | null |
+--------------------------------------------+----------------------------+-------------------------+


CONCAT_WS
Expand Down
16 changes: 8 additions & 8 deletions docs/user/ppl/functions/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@ CONCAT
Description
>>>>>>>>>>>

Usage: CONCAT(str1, str2) returns str1 and str strings concatenated together.
Usage: CONCAT(str1, str2, ...., str_9) adds up to 9 strings together.

Argument type: STRING, STRING
Argument type: STRING, STRING, ...., STRING

Return type: STRING

Example::

os> source=people | eval `CONCAT('hello', 'world')` = CONCAT('hello', 'world') | fields `CONCAT('hello', 'world')`
os> source=people | eval `CONCAT('hello', 'world')` = CONCAT('hello', 'world'), `CONCAT('hello ', 'whole ', 'world', '!')` = CONCAT('hello ', 'whole ', 'world', '!') | fields `CONCAT('hello', 'world')`, `CONCAT('hello ', 'whole ', 'world', '!')`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a NULL test for PPL?

fetched rows / total rows = 1/1
+----------------------------+
| CONCAT('hello', 'world') |
|----------------------------|
| helloworld |
+----------------------------+
+----------------------------+--------------------------------------------+
| CONCAT('hello', 'world') | CONCAT('hello ', 'whole ', 'world', '!') |
|----------------------------+--------------------------------------------|
| helloworld | hello whole world! |
+----------------------------+--------------------------------------------+


CONCAT_WS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ public void testLtrim() throws IOException {

@Test
public void testConcat() throws IOException {
verifyQuery("concat", "", ", 'there'",
"hellothere", "worldthere", "helloworldthere");
verifyQuery("concat", "", ", 'there', 'all', '!'",
"hellothereall!", "worldthereall!", "helloworldthereall!");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public void testLtrim() throws IOException {

@Test
public void testConcat() throws IOException {
verifyQuery("concat('hello', 'whole', 'world', '!', '!')", "keyword", "hellowholeworld!!");
verifyQuery("concat('hello', 'world')", "keyword", "helloworld");
verifyQuery("concat('', 'hello')", "keyword", "hello");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ LOCATE('world', 'helloworld') as column
LOCATE('world', 'hello') as column
LOCATE('world', 'helloworld', 7) as column
REPLACE('helloworld', 'world', 'opensearch') as column
REPLACE('hello', 'world', 'opensearch') as column
REPLACE('hello', 'world', 'opensearch') as column
CONCAT('hello', 'world') as column
CONCAT('hello ', 'whole ', 'world', '!') as column