Skip to content

Commit

Permalink
Improve MvPSeriesWeightedSum edge case and add more tests (elastic#11…
Browse files Browse the repository at this point in the history
…1552)

* Update `MvPSeriesWeightedSum` function to return `null` + warnings instead of Infinite values.
* Add extra tests to `MvPSeriesWeightedSum` for edge case scenarios.

I ran the tests 100 times to ensure they didn't break due to random values.
  • Loading branch information
machadoum authored and cbuescher committed Sep 4, 2024
1 parent 3ca043e commit 532f4c8
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 25 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/111552.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 111552
summary: Siem ea 9521 improve test
area: ES|QL
type: enhancement
issues: []

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@

import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
import static org.elasticsearch.xpack.esql.core.type.DataType.NULL;

/**
* Reduce a multivalued field to a single valued field containing the weighted sum of all element applying the P series function.
Expand Down Expand Up @@ -89,14 +90,18 @@ protected TypeResolution resolveType() {
return resolution;
}

resolution = TypeResolutions.isType(p, dt -> dt == DOUBLE, sourceText(), SECOND, "double")
.and(isNotNullAndFoldable(p, sourceText(), SECOND));

resolution = TypeResolutions.isType(p, dt -> dt == DOUBLE, sourceText(), SECOND, "double");
if (resolution.unresolved()) {
return resolution;
}

return resolution;
if (p.dataType() == NULL) {
// If the type is `null` this parameter doesn't have to be foldable. It's effectively foldable anyway.
// TODO figure out if the tests are wrong here, or if null is really different from foldable null
return resolution;
}

return isFoldable(p, sourceText(), SECOND);
}

@Override
Expand Down Expand Up @@ -130,10 +135,13 @@ protected NodeInfo<? extends Expression> info() {

@Override
public DataType dataType() {
if (p.dataType() == NULL) {
return NULL;
}
return field.dataType();
}

@Evaluator(extraName = "Double")
@Evaluator(extraName = "Double", warnExceptions = ArithmeticException.class)
static void process(
DoubleBlock.Builder builder,
int position,
Expand All @@ -149,7 +157,11 @@ static void process(
double current_score = block.getDouble(i) / Math.pow(i - start + 1, p);
sum.add(current_score);
}
builder.appendDouble(sum.value());
if (Double.isFinite(sum.value())) {
builder.appendDouble(sum.value());
} else {
throw new ArithmeticException("double overflow");
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ public final void testFold() {
assertTypeResolutionFailure(expression);
return;
}
assertFalse(expression.typeResolved().unresolved());
assertFalse("expected resolved", expression.typeResolved().unresolved());
Expression nullOptimized = new FoldNull().rule(expression);
assertThat(nullOptimized.dataType(), equalTo(testCase.expectedType()));
assertTrue(nullOptimized.foldable());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase;
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
import org.hamcrest.Matcher;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;

public class MvPSeriesWeightedSumTests extends AbstractScalarFunctionTestCase {
public MvPSeriesWeightedSumTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
Expand All @@ -31,10 +32,21 @@ public MvPSeriesWeightedSumTests(@Name("TestCase") Supplier<TestCaseSupplier.Tes
@ParametersFactory
public static Iterable<Object[]> parameters() {
List<TestCaseSupplier> cases = new ArrayList<>();

doubles(cases);

// TODO use parameterSuppliersFromTypedDataWithDefaultChecks instead of parameterSuppliersFromTypedData and fix errors
cases = randomizeBytesRefsOffset(cases);
cases = anyNullIsNull(
cases,
(nullPosition, nullValueDataType, original) -> nullValueDataType == DataType.NULL ? DataType.NULL : original.expectedType(),
(nullPosition, nullData, original) -> {
if (nullData.isForceLiteral()) {
return equalTo("LiteralsEvaluator[lit=null]");
}
return nullData.type() == DataType.NULL ? equalTo("LiteralsEvaluator[lit=null]") : original;
}
);
cases = errorsForCasesWithoutExamples(cases, (valid, position) -> "double");

return parameterSuppliersFromTypedData(cases);
}

Expand All @@ -47,22 +59,51 @@ private static void doubles(List<TestCaseSupplier> cases) {
cases.add(new TestCaseSupplier("most common scenario", List.of(DataType.DOUBLE, DataType.DOUBLE), () -> {
List<Double> field = randomList(1, 10, () -> randomDoubleBetween(1, 10, false));
double p = randomDoubleBetween(-10, 10, true);
double expectedResult = calcPSeriesWeightedSum(field, p);

return new TestCaseSupplier.TestCase(
List.of(
new TestCaseSupplier.TypedData(field, DataType.DOUBLE, "field"),
new TestCaseSupplier.TypedData(p, DataType.DOUBLE, "p").forceLiteral()
),
"MvPSeriesWeightedSumDoubleEvaluator[block=Attribute[channel=0], p=" + p + "]",
DataType.DOUBLE,
match(expectedResult)
);
return testCase(field, p);
}));

cases.add(new TestCaseSupplier("values between 0 and 1", List.of(DataType.DOUBLE, DataType.DOUBLE), () -> {
List<Double> field = randomList(1, 10, () -> randomDoubleBetween(0, 1, true));
double p = randomDoubleBetween(-10, 10, true);
return testCase(field, p);
}));

cases.add(new TestCaseSupplier("values between -1 and 0", List.of(DataType.DOUBLE, DataType.DOUBLE), () -> {
List<Double> field = randomList(1, 10, () -> randomDoubleBetween(-1, 0, true));
double p = randomDoubleBetween(-10, 10, true);
return testCase(field, p);
}));

cases.add(new TestCaseSupplier("values between 1 and Double.MAX_VALUE", List.of(DataType.DOUBLE, DataType.DOUBLE), () -> {
List<Double> field = randomList(1, 10, () -> randomDoubleBetween(1, Double.MAX_VALUE, true));
double p = randomDoubleBetween(-10, 10, true);
return testCase(field, p);
}));

cases.add(new TestCaseSupplier("values between -Double.MAX_VALUE and 1", List.of(DataType.DOUBLE, DataType.DOUBLE), () -> {
List<Double> field = randomList(1, 10, () -> randomDoubleBetween(-Double.MAX_VALUE, 1, true));
double p = randomDoubleBetween(-10, 10, true);
return testCase(field, p);
}));
}

private static Matcher<Double> match(Double value) {
return closeTo(value, Math.abs(value * .00000001));
private static TestCaseSupplier.TestCase testCase(List<Double> field, double p) {
double expectedResult = calcPSeriesWeightedSum(field, p);

TestCaseSupplier.TestCase testCase = new TestCaseSupplier.TestCase(
List.of(
new TestCaseSupplier.TypedData(field, DataType.DOUBLE, "field"),
new TestCaseSupplier.TypedData(p, DataType.DOUBLE, "p").forceLiteral()
),
"MvPSeriesWeightedSumDoubleEvaluator[block=Attribute[channel=0], p=" + p + "]",
DataType.DOUBLE,
Double.isFinite(expectedResult) ? closeTo(expectedResult, Math.abs(expectedResult * .00000001)) : nullValue()
);
if (Double.isFinite(expectedResult) == false) {
return testCase.withWarning("Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.")
.withWarning("Line -1:-1: java.lang.ArithmeticException: double overflow");
}
return testCase;
}

private static double calcPSeriesWeightedSum(List<Double> field, double p) {
Expand Down

0 comments on commit 532f4c8

Please sign in to comment.