Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Enable concat() string function to support multiple string arguments #1297

Merged
merged 1 commit into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ private FunctionBuilder getFunctionBuilder(
List<ExprType> sourceTypes = functionSignature.getParamTypeList();
List<ExprType> targetTypes = resolvedSignature.getKey().getParamTypeList();
FunctionBuilder funcBuilder = resolvedSignature.getValue();
if (isCastFunction(functionName) || sourceTypes.equals(targetTypes)) {
if (isCastFunction(functionName)
|| FunctionSignature.isVarArgFunction(targetTypes)
|| sourceTypes.equals(targetTypes)) {
return funcBuilder;
}
return castArguments(sourceTypes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,15 @@ public Pair<FunctionSignature, FunctionBuilder> resolve(FunctionSignature unreso
functionSignature));
}
Map.Entry<Integer, FunctionSignature> bestMatchEntry = functionMatchQueue.peek();
if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey())) {
if (FunctionSignature.isVarArgFunction(bestMatchEntry.getValue().getParamTypeList())
&& (unresolvedSignature.getParamTypeList().isEmpty()
|| unresolvedSignature.getParamTypeList().size() > 9)) {
throw new ExpressionEvaluationException(
String.format("%s function expected 1-9 arguments, but got %d",
functionName, unresolvedSignature.getParamTypeList().size()));
}
if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey())
&& !FunctionSignature.isVarArgFunction(bestMatchEntry.getValue().getParamTypeList())) {
throw new ExpressionEvaluationException(
String.format("%s function expected %s, but get %s", functionName,
formatFunctions(functionBundle.keySet()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.sql.expression.function;

import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;

import java.util.List;
import java.util.stream.Collectors;
import lombok.EqualsAndHashCode;
Expand Down Expand Up @@ -39,6 +41,10 @@ public int match(FunctionSignature functionSignature) {
|| paramTypeList.size() != functionTypeList.size()) {
return NOT_MATCH;
}
// TODO: improve to support regular and array type mixed, ex. func(int,string,array)
if (isVarArgFunction(functionTypeList)) {
return EXACTLY_MATCH;
}

int matchDegree = EXACTLY_MATCH;
for (int i = 0; i < paramTypeList.size(); i++) {
Expand All @@ -62,4 +68,11 @@ public String formatTypes() {
.map(ExprType::typeName)
.collect(Collectors.joining(",", "[", "]"));
}

/**
* util function - returns true if function has variable arguments.
*/
protected static boolean isVarArgFunction(List<ExprType> argTypes) {
return argTypes.size() == 1 && argTypes.get(0) == ARRAY;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,34 @@

package org.opensearch.sql.expression.text;

import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.expression.function.FunctionDSL.define;
import static org.opensearch.sql.expression.function.FunctionDSL.impl;
import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import lombok.experimental.UtilityClass;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.data.model.ExprIntegerValue;
import org.opensearch.sql.data.model.ExprStringValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.env.Environment;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.DefaultFunctionResolver;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.FunctionSignature;
import org.opensearch.sql.expression.function.SerializableBiFunction;
import org.opensearch.sql.expression.function.SerializableTriFunction;


/**
* The definition of text functions.
* 1) have the clear interface for function define.
Expand Down Expand Up @@ -141,16 +151,37 @@ private DefaultFunctionResolver upper() {
}

/**
* TODO: https://github.com/opendistro-for-elasticsearch/sql/issues/710
* Extend to accept variable argument amounts.
* Concatenates a list of Strings.
* Supports following signatures:
* (STRING, STRING) -> STRING
* (STRING, STRING, ...., STRING) -> STRING
*/
private DefaultFunctionResolver concat() {
return define(BuiltinFunctionName.CONCAT.getName(),
impl(nullMissingHandling((str1, str2) ->
new ExprStringValue(str1.stringValue() + str2.stringValue())), STRING, STRING, STRING));
FunctionName concatFuncName = BuiltinFunctionName.CONCAT.getName();
return define(concatFuncName, funcName ->
Pair.of(
new FunctionSignature(concatFuncName, Collections.singletonList(ARRAY)),
(funcProp, args) -> new FunctionExpression(funcName, args) {
@Override
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
List<ExprValue> exprValues = args.stream()
.map(arg -> arg.valueOf(valueEnv)).collect(Collectors.toList());
if (exprValues.stream().anyMatch(ExprValue::isMissing)) {
return ExprValueUtils.missingValue();
}
if (exprValues.stream().anyMatch(ExprValue::isNull)) {
return ExprValueUtils.nullValue();
}
return new ExprStringValue(exprValues.stream()
.map(ExprValue::stringValue)
.collect(Collectors.joining()));
}

@Override
public ExprType type() {
return STRING;
}
}
));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.when;
import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Collections;
import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -76,4 +80,53 @@ void resolve_function_not_match() {
assertEquals("add function expected {[INTEGER,INTEGER]}, but get [BOOLEAN,BOOLEAN]",
exception.getMessage());
}

@Test
void resolve_varargs_function_signature_match() {
functionName = FunctionName.of("concat");
when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL);
when(functionSignature.getParamTypeList()).thenReturn(ImmutableList.of(STRING));
when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY));

DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName,
ImmutableMap.of(bestMatchFS, bestMatchBuilder));

assertEquals(bestMatchBuilder, resolver.resolve(functionSignature).getValue());
}

@Test
void resolve_varargs_no_args_function_signature_not_match() {
functionName = FunctionName.of("concat");
when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL);
when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY));
// Concat function with no arguments
when(functionSignature.getParamTypeList()).thenReturn(Collections.emptyList());

DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName,
ImmutableMap.of(bestMatchFS, bestMatchBuilder));

ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class,
() -> resolver.resolve(functionSignature));
assertEquals("concat function expected 1-9 arguments, but got 0",
exception.getMessage());
}

@Test
void resolve_varargs_too_many_args_function_signature_not_match() {
functionName = FunctionName.of("concat");
when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL);
when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY));
// Concat function with more than 9 arguments
when(functionSignature.getParamTypeList()).thenReturn(ImmutableList
.of(STRING, STRING, STRING, STRING, STRING,
STRING, STRING, STRING, STRING, STRING));

DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName,
ImmutableMap.of(bestMatchFS, bestMatchBuilder));

ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class,
() -> resolver.resolve(functionSignature));
assertEquals("concat function expected 1-9 arguments, but got 10",
exception.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ public class TextFunctionTest extends ExpressionTestBase {
private static List<List<String>> CONCAT_STRING_LISTS = ImmutableList.of(
ImmutableList.of("hello", "world"),
ImmutableList.of("123", "5325"));
private static List<List<String>> CONCAT_STRING_LISTS_WITH_MANY_STRINGS = ImmutableList.of(
ImmutableList.of("he", "llo", "wo", "rld", "!"),
ImmutableList.of("0", "123", "53", "25", "7"));

interface SubstrSubstring {
FunctionExpression getFunction(SubstringInfo strInfo);
Expand Down Expand Up @@ -228,11 +231,13 @@ public void upper() {
@Test
void concat() {
CONCAT_STRING_LISTS.forEach(this::testConcatString);
CONCAT_STRING_LISTS_WITH_MANY_STRINGS.forEach(this::testConcatMultipleString);

when(nullRef.type()).thenReturn(STRING);
when(missingRef.type()).thenReturn(STRING);
assertEquals(missingValue(), eval(
DSL.concat(missingRef, DSL.literal("1"))));
// If any of the expressions is a NULL value, it returns NULL.
assertEquals(nullValue(), eval(
DSL.concat(nullRef, DSL.literal("1"))));
assertEquals(missingValue(), eval(
Expand Down Expand Up @@ -446,6 +451,22 @@ void testConcatString(List<String> strings, String delim) {
assertEquals(expected, eval(expression).stringValue());
}

void testConcatMultipleString(List<String> strings) {
String expected = null;
if (strings.stream().noneMatch(Objects::isNull)) {
expected = String.join("", strings);
}

FunctionExpression expression = DSL.concat(
DSL.literal(strings.get(0)),
DSL.literal(strings.get(1)),
DSL.literal(strings.get(2)),
DSL.literal(strings.get(3)),
DSL.literal(strings.get(4)));
assertEquals(STRING, expression.type());
assertEquals(expected, eval(expression).stringValue());
}

void testLengthString(String str) {
FunctionExpression expression = DSL.length(DSL.literal(new ExprStringValue(str)));
assertEquals(INTEGER, expression.type());
Expand Down
16 changes: 8 additions & 8 deletions docs/user/dql/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2614,21 +2614,21 @@ CONCAT
Description
>>>>>>>>>>>

Usage: CONCAT(str1, str2) returns str1 and str strings concatenated together.
Usage: CONCAT(str1, str2, ...., str_9) adds up to 9 strings together. If any of the expressions is a NULL value, it returns NULL.

Argument type: STRING, STRING
Argument type: STRING, STRING, ...., STRING

Return type: STRING

Example::

os> SELECT CONCAT('hello', 'world')
os> SELECT CONCAT('hello ', 'whole ', 'world', '!'), CONCAT('hello', 'world'), CONCAT('hello', null)
fetched rows / total rows = 1/1
+----------------------------+
| CONCAT('hello', 'world') |
|----------------------------|
| helloworld |
+----------------------------+
+--------------------------------------------+----------------------------+-------------------------+
| CONCAT('hello ', 'whole ', 'world', '!') | CONCAT('hello', 'world') | CONCAT('hello', null) |
|--------------------------------------------+----------------------------+-------------------------|
| hello whole world! | helloworld | null |
+--------------------------------------------+----------------------------+-------------------------+


CONCAT_WS
Expand Down
16 changes: 8 additions & 8 deletions docs/user/ppl/functions/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@ CONCAT
Description
>>>>>>>>>>>

Usage: CONCAT(str1, str2) returns str1 and str strings concatenated together.
Usage: CONCAT(str1, str2, ...., str_9) adds up to 9 strings together.

Argument type: STRING, STRING
Argument type: STRING, STRING, ...., STRING

Return type: STRING

Example::

os> source=people | eval `CONCAT('hello', 'world')` = CONCAT('hello', 'world') | fields `CONCAT('hello', 'world')`
os> source=people | eval `CONCAT('hello', 'world')` = CONCAT('hello', 'world'), `CONCAT('hello ', 'whole ', 'world', '!')` = CONCAT('hello ', 'whole ', 'world', '!') | fields `CONCAT('hello', 'world')`, `CONCAT('hello ', 'whole ', 'world', '!')`
fetched rows / total rows = 1/1
+----------------------------+
| CONCAT('hello', 'world') |
|----------------------------|
| helloworld |
+----------------------------+
+----------------------------+--------------------------------------------+
| CONCAT('hello', 'world') | CONCAT('hello ', 'whole ', 'world', '!') |
|----------------------------+--------------------------------------------|
| helloworld | hello whole world! |
+----------------------------+--------------------------------------------+


CONCAT_WS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ public void testLtrim() throws IOException {

@Test
public void testConcat() throws IOException {
verifyQuery("concat", "", ", 'there'",
"hellothere", "worldthere", "helloworldthere");
verifyQuery("concat", "", ", 'there', 'all', '!'",
"hellothereall!", "worldthereall!", "helloworldthereall!");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public void testLtrim() throws IOException {

@Test
public void testConcat() throws IOException {
verifyQuery("concat('hello', 'whole', 'world', '!', '!')", "keyword", "hellowholeworld!!");
verifyQuery("concat('hello', 'world')", "keyword", "helloworld");
verifyQuery("concat('', 'hello')", "keyword", "hello");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ LOCATE('world', 'helloworld') as column
LOCATE('world', 'hello') as column
LOCATE('world', 'helloworld', 7) as column
REPLACE('helloworld', 'world', 'opensearch') as column
REPLACE('hello', 'world', 'opensearch') as column
REPLACE('hello', 'world', 'opensearch') as column
CONCAT('hello', 'world') as column
CONCAT('hello ', 'whole ', 'world', '!') as column