From 048932f5928569c6ce8940fb414d7cabc08cc6e3 Mon Sep 17 00:00:00 2001 From: machadoum Date: Fri, 2 Aug 2024 09:28:14 +0200 Subject: [PATCH] wip --- .../multivalue/MvPSeriesWeightedSumTests.java | 70 ++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java index 0f277485b874d..95175e9d871e3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java @@ -22,6 +22,7 @@ import java.util.function.Supplier; import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.is; public class MvPSeriesWeightedSumTests extends AbstractScalarFunctionTestCase { public MvPSeriesWeightedSumTests(@Name("TestCase") Supplier testCaseSupplier) { @@ -59,10 +60,77 @@ private static void doubles(List cases) { match(expectedResult) ); })); + + cases.add(new TestCaseSupplier("values between 0 and 1", List.of(DataType.DOUBLE, DataType.DOUBLE), () -> { + List field = randomList(1, 10, () -> randomDoubleBetween(0, 1, true)); + double p = randomDoubleBetween(-10, 10, true); + double expectedResult = calcPSeriesWeightedSum(field, p); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field, DataType.DOUBLE, "field"), + new TestCaseSupplier.TypedData(p, DataType.DOUBLE, "p").forceLiteral() + ), + "MvPSeriesWeightedSumDoubleEvaluator[block=Attribute[channel=0], p=" + p + "]", + DataType.DOUBLE, + match(expectedResult) + ); + })); + + cases.add(new TestCaseSupplier("values between -1 and 0", List.of(DataType.DOUBLE, DataType.DOUBLE), () -> { + List field = randomList(1, 10, () -> randomDoubleBetween(-1, 0, true)); + double p = randomDoubleBetween(-10, 10, true); + double expectedResult = calcPSeriesWeightedSum(field, p); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field, DataType.DOUBLE, "field"), + new TestCaseSupplier.TypedData(p, DataType.DOUBLE, "p").forceLiteral() + ), + "MvPSeriesWeightedSumDoubleEvaluator[block=Attribute[channel=0], p=" + p + "]", + DataType.DOUBLE, + match(expectedResult) + ); + })); + + cases.add(new TestCaseSupplier("values between 1 and Double.MAX_VALUE", List.of(DataType.DOUBLE, DataType.DOUBLE), () -> { + List field = randomList(1, 10, () -> randomDoubleBetween(1, Double.MAX_VALUE, true)); + double p = randomDoubleBetween(-10, 10, true); + double expectedResult = calcPSeriesWeightedSum(field, p); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field, DataType.DOUBLE, "field"), + new TestCaseSupplier.TypedData(p, DataType.DOUBLE, "p").forceLiteral() + ), + "MvPSeriesWeightedSumDoubleEvaluator[block=Attribute[channel=0], p=" + p + "]", + DataType.DOUBLE, + match(expectedResult) + ); + })); + + cases.add(new TestCaseSupplier("values between -Double.MAX_VALUE and 1", List.of(DataType.DOUBLE, DataType.DOUBLE), () -> { + List field = randomList(1, 10, () -> randomDoubleBetween(-Double.MAX_VALUE, 1, true)); + double p = randomDoubleBetween(-10, 10, true); + double expectedResult = calcPSeriesWeightedSum(field, p); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field, DataType.DOUBLE, "field"), + new TestCaseSupplier.TypedData(p, DataType.DOUBLE, "p").forceLiteral() + ), + "MvPSeriesWeightedSumDoubleEvaluator[block=Attribute[channel=0], p=" + p + "]", + DataType.DOUBLE, + match(expectedResult) + ); + })); } private static Matcher match(Double value) { - return closeTo(value, Math.abs(value * .00000001)); + if (Double.isFinite(value)) { + return closeTo(value, Math.abs(value * .00000001)); + } + return is(value); } private static double calcPSeriesWeightedSum(List field, double p) {