Skip to content

Commit

Permalink
Enable concat() string function to support multiple string arguments (
Browse files Browse the repository at this point in the history
#200)

* Refactor concat() to support multiple string arguments

Signed-off-by: Margarit Hakobyan <[email protected]>
  • Loading branch information
margarit-h committed Jan 13, 2023
1 parent c6a59f7 commit 1f924f5
Show file tree
Hide file tree
Showing 15 changed files with 399 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ private FunctionBuilder getFunctionBuilder(
if (isCastFunction(functionName) || sourceTypes.equals(targetTypes)) {
return funcBuilder;
}
// Functions that take a variable number of arguments (e.g. concat()) are
// registered without a fixed signature, so targetTypes is always empty for them
// and casting sourceTypes to targetTypes would fail.
// For such functions, pass sourceTypes as the cast target instead of targetTypes.
if (functionResolverMap.get(functionName) instanceof VarargsFunctionResolver) {
return castArguments(sourceTypes, sourceTypes, funcBuilder);
}
return castArguments(sourceTypes,
targetTypes, funcBuilder);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.experimental.UtilityClass;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
Expand Down Expand Up @@ -58,6 +60,39 @@ public static DefaultFunctionResolver define(FunctionName functionName, List<
return builder.build();
}

/**
 * Builds a {@link VarargsFunctionResolver} from implementations passed as varargs.
 * The implementations are copied into a list and handed to the list-based overload.
 *
 * @param functionName name the resolver is registered under.
 * @param functions    function implementations to bundle into the resolver.
 * @return VarargsFunctionResolver holding the given implementations.
 */
public static VarargsFunctionResolver defineVarargsFunction(FunctionName functionName,
    SerializableFunction<FunctionName,
        Pair<FunctionSignature, FunctionBuilder>>... functions) {
  List<SerializableFunction<FunctionName, Pair<FunctionSignature, FunctionBuilder>>>
      implementations = List.of(functions);
  return defineVarargsFunction(functionName, implementations);
}

/**
 * Builds a {@link VarargsFunctionResolver} from a list of implementations.
 * Each implementation is applied to {@code functionName} to obtain its
 * signature/builder pair, which is then added to the resolver's function bundle.
 *
 * @param functionName name the resolver is registered under.
 * @param functions    function implementations to bundle into the resolver.
 * @return VarargsFunctionResolver holding the given implementations.
 */
public static VarargsFunctionResolver defineVarargsFunction(FunctionName functionName, List<
    SerializableFunction<FunctionName, Pair<FunctionSignature, FunctionBuilder>>> functions) {

  VarargsFunctionResolver.VarargsFunctionResolverBuilder resolverBuilder =
      VarargsFunctionResolver.builder().functionName(functionName);
  functions.forEach(implementation -> {
    Pair<FunctionSignature, FunctionBuilder> bundle = implementation.apply(functionName);
    resolverBuilder.functionBundle(bundle.getKey(), bundle.getValue());
  });
  return resolverBuilder.build();
}

/**
* Implementation of no args function that uses FunctionProperties.
Expand Down Expand Up @@ -212,6 +247,56 @@ public static SerializableFunction<FunctionName, Pair<FunctionSignature, Functio
return implWithProperties((fp, arg) -> function.apply(arg), returnType, argsType);
}

/**
 * Varargs Function Implementation.
 * The produced {@link FunctionBuilder} accepts any number of argument expressions,
 * evaluates each of them and passes the resulting values to {@code function}.
 *
 * <p>The produced {@link FunctionSignature} always has an empty parameter list: the
 * number of arguments is only known once the builder is invoked, which happens after
 * the signature has been created.
 *
 * @param function    {@link ExprValue} based varargs function.
 * @param returnType  return type reported by the built expression.
 * @param argsType    argument type; not used when building the (empty) signature.
 * @param withVarargs not read by this method; it appears to exist only to distinguish
 *                    this overload from the other {@code impl} overloads.
 * @return Varargs Function Implementation.
 */
public static SerializableFunction<FunctionName, Pair<FunctionSignature, FunctionBuilder>> impl(
    SerializableVarargsFunction<ExprValue, ExprValue> function,
    ExprType returnType,
    ExprType argsType,
    boolean withVarargs) {

  return functionName -> {
    FunctionBuilder functionBuilder =
        (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) {
          @Override
          public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
            // Evaluate every argument expression, preserving argument order.
            ExprValue[] args = arguments.stream()
                .map(arg -> arg.valueOf(valueEnv))
                .toArray(ExprValue[]::new);

            return function.apply(args);
          }

          @Override
          public ExprType type() {
            return returnType;
          }

          @Override
          public String toString() {
            return String.format("%s(%s)", functionName, arguments.stream()
                .map(Object::toString)
                .collect(Collectors.joining(", ")));
          }
        };
    // No argument count is available at definition time, so the signature is
    // registered without parameter types.
    FunctionSignature functionSignature = new FunctionSignature(functionName, List.of());
    return Pair.of(functionSignature, functionBuilder);
  };
}

/**
* Binary Function Implementation.
*
Expand Down Expand Up @@ -323,13 +408,29 @@ public SerializableTriFunction<ExprValue, ExprValue, ExprValue, ExprValue> nullM
};
}

/**
 * Wraps a varargs ExprValue function with default NULL and MISSING handling.
 * If any argument is MISSING the result is MISSING; otherwise, if any argument
 * is NULL the result is NULL; otherwise the wrapped function is applied.
 *
 * @param function    varargs function to wrap.
 * @param withVarargs not read by this method.
 * @return wrapped function with NULL and MISSING handling.
 */
public SerializableVarargsFunction<ExprValue, ExprValue> nullMissingHandling(
    SerializableVarargsFunction<ExprValue, ExprValue> function, boolean withVarargs) {
  return (args) -> {
    boolean sawNull = false;
    for (ExprValue arg : args) {
      // MISSING wins over NULL, so it can be returned as soon as it is seen.
      if (arg.isMissing()) {
        return ExprValueUtils.missingValue();
      }
      sawNull = sawNull || arg.isNull();
    }
    return sawNull ? ExprValueUtils.nullValue() : function.apply(args);
  };
}

/**
* Wrapper the unary ExprValue function that is aware of FunctionProperties,
* with default NULL and MISSING handling.
*/
public static SerializableBiFunction<FunctionProperties, ExprValue, ExprValue>
nullMissingHandlingWithProperties(
SerializableBiFunction<FunctionProperties, ExprValue, ExprValue> implementation) {
SerializableBiFunction<FunctionProperties, ExprValue, ExprValue> implementation) {
return (functionProperties, v1) -> {
if (v1.isMissing()) {
return ExprValueUtils.missingValue();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.expression.function;

import java.io.Serializable;

/**
 * Serializable function that accepts a variable number of arguments of one type.
 *
 * @param <T> the type of every argument
 * @param <R> the type of the result
 */
@FunctionalInterface
public interface SerializableVarargsFunction<T, R> extends Serializable {
  /**
   * Applies this function to the given arguments.
   *
   * @param t the function arguments; may be empty
   * @return the function result
   */
  R apply(T... t);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.expression.function;

import java.util.AbstractMap;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Builder;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Singular;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.exception.ExpressionEvaluationException;

/**
* The Function Resolver hold the overload {@link FunctionBuilder} implementation.
* is composed by {@link FunctionName} which identified the function name
* and a map of {@link FunctionSignature} and {@link FunctionBuilder}
* to represent the overloaded implementation
*/
@Builder
@RequiredArgsConstructor
public class VarargsFunctionResolver implements FunctionResolver {
@Getter
private final FunctionName functionName;
@Singular("functionBundle")
private final Map<FunctionSignature, FunctionBuilder> functionBundle;

/**
* Resolve the {@link FunctionBuilder} by using input {@link FunctionSignature}.
* If the {@link FunctionBuilder} exactly match the input {@link FunctionSignature}, return it.
* If applying the widening rule, found the most match one, return it.
* If nothing found, throw {@link ExpressionEvaluationException}
*
* @return function signature and its builder
*/
@Override
public Pair<FunctionSignature, FunctionBuilder> resolve(FunctionSignature unresolvedSignature) {
PriorityQueue<Map.Entry<Integer, FunctionSignature>> functionMatchQueue = new PriorityQueue<>(
Map.Entry.comparingByKey());

for (FunctionSignature functionSignature : functionBundle.keySet()) {
functionMatchQueue.add(
new AbstractMap.SimpleEntry<>(unresolvedSignature.match(functionSignature),
functionSignature));
}
Map.Entry<Integer, FunctionSignature> bestMatchEntry = functionMatchQueue.peek();
if (unresolvedSignature.getParamTypeList().isEmpty()) {
throw new ExpressionEvaluationException(
String.format("%s function expected %s, but get %s", functionName,
formatFunctions(functionBundle.keySet()),
unresolvedSignature.formatTypes()
));
} else {
FunctionSignature resolvedSignature = bestMatchEntry.getValue();
return Pair.of(resolvedSignature, functionBundle.get(resolvedSignature));
}
}

private String formatFunctions(Set<FunctionSignature> functionSignatures) {
return functionSignatures.stream().map(FunctionSignature::formatTypes)
.collect(Collectors.joining(",", "{", "}"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.expression.function.FunctionDSL.define;
import static org.opensearch.sql.expression.function.FunctionDSL.defineVarargsFunction;
import static org.opensearch.sql.expression.function.FunctionDSL.impl;
import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling;

import java.util.Arrays;
import java.util.stream.Collectors;
import lombok.experimental.UtilityClass;
import org.opensearch.sql.data.model.ExprIntegerValue;
import org.opensearch.sql.data.model.ExprStringValue;
Expand All @@ -22,7 +25,7 @@
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.SerializableBiFunction;
import org.opensearch.sql.expression.function.SerializableTriFunction;

import org.opensearch.sql.expression.function.VarargsFunctionResolver;

/**
* The definition of text functions.
Expand Down Expand Up @@ -141,16 +144,16 @@ private DefaultFunctionResolver upper() {
}

/**
* TODO: https://github.com/opendistro-for-elasticsearch/sql/issues/710
* Extend to accept variable argument amounts.
* Concatenates a list of Strings.
* Supports following signatures:
* (STRING, STRING) -> STRING
* (STRING, STRING, ...., STRING) -> STRING
*/
private DefaultFunctionResolver concat() {
return define(BuiltinFunctionName.CONCAT.getName(),
impl(nullMissingHandling((str1, str2) ->
new ExprStringValue(str1.stringValue() + str2.stringValue())), STRING, STRING, STRING));
private VarargsFunctionResolver concat() {
  return defineVarargsFunction(
      BuiltinFunctionName.CONCAT.getName(),
      impl(
          nullMissingHandling(
              values -> {
                // Append the string form of every argument in call order.
                StringBuilder joined = new StringBuilder();
                for (ExprValue value : values) {
                  joined.append(value.stringValue());
                }
                return new ExprStringValue(joined.toString());
              },
              true),
          STRING,
          STRING,
          true));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,37 @@ void resolve_unregistered() {
assertEquals("unsupported function name: unknown", exception.getMessage());
}

@Test
void resolve_should_cast_arguments_for_varargs_function() {
  // The call site passes three STRING arguments, while the signature handed back
  // by the resolver has no parameter types.
  FunctionSignature callSignature = new FunctionSignature(
      mockFunctionName, ImmutableList.of(STRING, STRING, STRING));
  FunctionSignature registeredSignature = new FunctionSignature(
      mockFunctionName, Collections.emptyList());

  VarargsFunctionResolver resolver = mock(VarargsFunctionResolver.class);
  FunctionBuilder builder = mock(FunctionBuilder.class);

  // Names used when rendering the resolved function.
  when(mockFunctionName.getFunctionName()).thenReturn("mockFunction");
  when(mockExpression.toString()).thenReturn("string");

  // Namespace and function lookups lead to the varargs resolver.
  when(mockNamespaceMap.get(DEFAULT_NAMESPACE)).thenReturn(mockMap);
  when(mockNamespaceMap.containsKey(DEFAULT_NAMESPACE)).thenReturn(true);
  when(mockMap.containsKey(eq(mockFunctionName))).thenReturn(true);
  when(mockMap.get(eq(mockFunctionName))).thenReturn(resolver);
  when(resolver.resolve(eq(callSignature)))
      .thenReturn(Pair.of(registeredSignature, builder));
  repo.register(resolver);

  // Lenient stubbing: the builder answers with a fake expression wrapping the
  // arguments it receives.
  lenient().doAnswer(invocation ->
      new FakeFunctionExpression(mockFunctionName, invocation.getArgument(1))
  ).when(builder).apply(eq(functionProperties), any());

  FunctionImplementation resolved =
      repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), callSignature)
          .apply(functionProperties,
              ImmutableList.of(mockExpression, mockExpression, mockExpression));

  assertEquals("mockFunction(string, string, string)", resolved.toString());
}

private FunctionSignature registerFunctionResolver(FunctionName funcName,
ExprType sourceType,
ExprType targetType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ public int compareTo(ExprValue o) {
twoArgs = (v1, v2) -> ANY;
static final SerializableTriFunction<ExprValue, ExprValue, ExprValue, ExprValue>
threeArgs = (v1, v2, v3) -> ANY;
static final SerializableVarargsFunction<ExprValue, ExprValue>
varrgs = (v1) -> ANY;
@Mock
FunctionProperties mockProperties;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.expression.function;

import static org.opensearch.sql.expression.function.FunctionDSL.impl;

import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;

/**
 * Supplies the varargs {@code impl(...)} overload, a single sample argument and
 * the expected {@code toString()} output to {@link FunctionDSLimplTestBase}.
 */
class FunctionDSLimplVarargsTest extends FunctionDSLimplTestBase {

  @Override
  String getExpected_toString() {
    return "sample(ANY)";
  }

  @Override
  List<Expression> getSampleArguments() {
    final Expression onlyArgument = DSL.literal(ANY);
    return List.of(onlyArgument);
  }

  @Override
  SerializableFunction<FunctionName, Pair<FunctionSignature, FunctionBuilder>>
      getImplementationGenerator() {
    final boolean withVarargs = true;
    return impl(varrgs, ANY_TYPE, ANY_TYPE, withVarargs);
  }
}
Loading

0 comments on commit 1f924f5

Please sign in to comment.