diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java index 8a7700678128..57bf94696343 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java @@ -145,8 +145,10 @@ public static double getCorrelation(CorrelationState state) public static void updateRegressionState(RegressionState state, double x, double y) { double oldMeanX = state.getMeanX(); + double oldMeanY = state.getMeanY(); updateCovarianceState(state, x, y); state.setM2X(state.getM2X() + (x - oldMeanX) * (x - state.getMeanX())); + state.setM2Y(state.getM2Y() + (y - oldMeanY) * (y - state.getMeanY())); } public static double getRegressionSlope(RegressionState state) @@ -167,6 +169,41 @@ public static double getRegressionIntercept(RegressionState state) return meanY - slope * meanX; } + public static double getRegressionAvgy(RegressionState state) + { + return state.getMeanY(); + } + + public static double getRegressionAvgx(RegressionState state) + { + return state.getMeanX(); + } + + public static double getRegressionSxx(RegressionState state) + { + return state.getM2X(); + } + + public static double getRegressionSxy(RegressionState state) + { + return state.getC2(); + } + + public static double getRegressionSyy(RegressionState state) + { + return state.getM2Y(); + } + + public static double getRegressionR2(RegressionState state) + { + return Math.pow(state.getC2(), 2) / (state.getM2X() * state.getM2Y()); + } + + public static double getRegressionCount(RegressionState state) + { + return state.getCount(); + } + public static void mergeVarianceState(VarianceState state, VarianceState otherState) { long count = otherState.getCount(); @@ -265,6 +302,7 @@ public static void mergeRegressionState(RegressionState state, RegressionState o long na = state.getCount(); long nb = otherState.getCount(); state.setM2X(state.getM2X() + otherState.getM2X() + na * nb * Math.pow(state.getMeanX() - otherState.getMeanX(), 2) / (double) (na + nb)); + state.setM2Y(state.getM2Y() + otherState.getM2Y() + na * nb * Math.pow(state.getMeanY() - otherState.getMeanY(), 2) / (double) (na + nb)); updateCovarianceState(state, otherState); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java index db3ad26ec5d6..e42699f2a86c 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java @@ -24,8 +24,15 @@ import com.facebook.presto.spi.function.SqlType; import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionIntercept; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSlope; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; import static com.facebook.presto.operator.aggregation.AggregationUtils.mergeRegressionState; import static com.facebook.presto.operator.aggregation.AggregationUtils.updateRegressionState; @@ -71,4 +78,95 @@ public static void regrIntercept(@AggregationState RegressionState state, BlockB out.appendNull(); } } + + @AggregationFunction("regr_sxy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSxy(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionSxy(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_sxx") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSxx(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionSxx(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_syy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSyy(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionSyy(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_r2") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrR2(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionR2(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_count") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrCount(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionCount(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrAvgy(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionAvgy(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgx") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrAvgx(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionAvgx(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java index a75222bfa93c..7b2943f8fac2 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java @@ -24,8 +24,15 @@ import com.facebook.presto.spi.function.SqlType; import static com.facebook.presto.common.type.RealType.REAL; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionIntercept; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSlope; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; import static java.lang.Float.floatToRawIntBits; import static java.lang.Float.intBitsToFloat; @@ -71,4 +78,95 @@ public static void regrIntercept(@AggregationState RegressionState state, BlockB out.appendNull(); } } + + @AggregationFunction("regr_sxy") + @OutputFunction(StandardTypes.REAL) + public static void regrSxy(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionSxy(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_sxx") + @OutputFunction(StandardTypes.REAL) + public static void regrSxx(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionSxx(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_syy") + @OutputFunction(StandardTypes.REAL) + public static void regrSyy(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionSyy(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_r2") + @OutputFunction(StandardTypes.REAL) + public static void regrR2(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionR2(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_count") + @OutputFunction(StandardTypes.REAL) + public static void regrCount(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionCount(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgy") + @OutputFunction(StandardTypes.REAL) + public static void regrAvgy(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionAvgy(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgx") + @OutputFunction(StandardTypes.REAL) + public static void regrAvgx(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionAvgx(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java index ae3af6f46dc4..79837f90c0c1 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java @@ -19,4 +19,8 @@ public interface RegressionState double getM2X(); void setM2X(double value); + + double getM2Y(); + + void setM2Y(double value); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrAvgxAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrAvgxAggregation.java new file mode 100644 index 000000000000..0687e3db9d98 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrAvgxAggregation.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createDoubleSequenceBlock; +import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestDoubleRegrAvgxAggregation + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createDoubleSequenceBlock(start, start + length), createDoubleSequenceBlock(start + 2, start + 2 + length)}; + } + + @Override + protected String getFunctionName() + { + return "regr_avgx"; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.DOUBLE, StandardTypes.DOUBLE); + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length == 0) { + return 0.0; + } + + double expected = 0.0; + for (int i = start; i < start + length; i++) { + expected += (i + 2); + } + return expected / length; + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}, new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}); + testNonTrivialAggregation(new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}, new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}); + } + + private void testNonTrivialAggregation(Double[] y, Double[] x) + { + double expected = 0.0; + for (int i = 0; i < x.length; i++) { + expected += x[i]; + } + expected = expected / x.length; + checkArgument(Double.isFinite(expected) && expected != 0., "Expected result is trivial"); + testAggregation(expected, createDoublesBlock(y), createDoublesBlock(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrAvgyAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrAvgyAggregation.java new file mode 100644 index 000000000000..b5a531f78aac --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrAvgyAggregation.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createDoubleSequenceBlock; +import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestDoubleRegrAvgyAggregation + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createDoubleSequenceBlock(start, start + length), createDoubleSequenceBlock(start + 2, start + 2 + length)}; + } + + @Override + protected String getFunctionName() + { + return "regr_avgy"; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.DOUBLE, StandardTypes.DOUBLE); + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length == 0) { + return 0.0; + } + + double expected = 0.0; + for (int i = start; i < start + length; i++) { + expected += i; + } + return expected / length; + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}, new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}); + testNonTrivialAggregation(new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}, new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}); + } + + private void testNonTrivialAggregation(Double[] y, Double[] x) + { + double expected = 0.0; + for (int i = 0; i < y.length; i++) { + expected += y[i]; + } + expected = expected / y.length; + checkArgument(Double.isFinite(expected) && expected != 0., "Expected result is trivial"); + testAggregation(expected, createDoublesBlock(y), createDoublesBlock(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrCountAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrCountAggregation.java new file mode 100644 index 000000000000..012b7b0608c5 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrCountAggregation.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createDoubleSequenceBlock; +import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestDoubleRegrCountAggregation + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createDoubleSequenceBlock(start, start + length), createDoubleSequenceBlock(start + 2, start + 2 + length)}; + } + + @Override + protected String getFunctionName() + { + return "regr_count"; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.DOUBLE, StandardTypes.DOUBLE); + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (double) length; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (double) regression.getN(); + } + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}, new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}); + testNonTrivialAggregation(new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}, new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}); + } + + private void testNonTrivialAggregation(Double[] y, Double[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + double expected = (double) regression.getN(); + checkArgument(Double.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createDoublesBlock(y), createDoublesBlock(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrR2Aggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrR2Aggregation.java new file mode 100644 index 000000000000..7e6de6e0ef0a --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrR2Aggregation.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createDoubleSequenceBlock; +import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestDoubleRegrR2Aggregation + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createDoubleSequenceBlock(start, start + length), createDoubleSequenceBlock(start + 2, start + 2 + length)}; + } + + @Override + protected String getFunctionName() + { + return "regr_r2"; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.DOUBLE, StandardTypes.DOUBLE); + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 0) { + return null; + } + else if (length == 1) { + return (double) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (double) regression.getRSquare(); + } + } + + @Test + public void testSinglePosition() + { + testAggregation(getExpectedValue(0, 2), getSequenceBlocks(0, 2)); + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}, new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}); + testNonTrivialAggregation(new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}, new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}); + } + + private void testNonTrivialAggregation(Double[] y, Double[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + double expected = (double) regression.getRSquare(); + checkArgument(Double.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createDoublesBlock(y), createDoublesBlock(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSxxAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSxxAggregation.java new file mode 100644 index 000000000000..904e7c23cb30 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSxxAggregation.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createDoubleSequenceBlock; +import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestDoubleRegrSxxAggregation + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createDoubleSequenceBlock(start, start + length), createDoubleSequenceBlock(start + 2, start + 2 + length)}; + } + + @Override + protected String getFunctionName() + { + return "regr_sxx"; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.DOUBLE, StandardTypes.DOUBLE); + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (double) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (double) regression.getXSumSquares(); + } + } + + @Test + public void testSinglePosition() + { + testAggregation(getExpectedValue(0, 2), getSequenceBlocks(0, 2)); + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}, new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}); + testNonTrivialAggregation(new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}, new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}); + } + + private void testNonTrivialAggregation(Double[] y, Double[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + double expected = (double) regression.getXSumSquares(); + checkArgument(Double.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createDoublesBlock(y), createDoublesBlock(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSxyAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSxyAggregation.java new file mode 100644 index 000000000000..31b813571f73 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSxyAggregation.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createDoubleSequenceBlock; +import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestDoubleRegrSxyAggregation + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createDoubleSequenceBlock(start, start + length), createDoubleSequenceBlock(start + 2, start + 2 + length)}; + } + + @Override + protected String getFunctionName() + { + return "regr_sxy"; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.DOUBLE, StandardTypes.DOUBLE); + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (double) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (double) regression.getSumOfCrossProducts(); + } + } + + @Test + public void testSinglePosition() + { + testAggregation(getExpectedValue(0, 2), getSequenceBlocks(0, 2)); + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}, new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}); + testNonTrivialAggregation(new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}, new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}); + } + + private void testNonTrivialAggregation(Double[] y, Double[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + double expected = (double) regression.getSumOfCrossProducts(); + checkArgument(Double.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createDoublesBlock(y), createDoublesBlock(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSyyAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSyyAggregation.java new file mode 100644 index 000000000000..c05686e37ad9 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSyyAggregation.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createDoubleSequenceBlock; +import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestDoubleRegrSyyAggregation + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createDoubleSequenceBlock(start, start + length), createDoubleSequenceBlock(start + 2, start + 2 + length)}; + } + + @Override + protected String getFunctionName() + { + return "regr_syy"; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.DOUBLE, StandardTypes.DOUBLE); + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (double) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (double) regression.getTotalSumSquares(); + } + } + + @Test + public void testSinglePosition() + { + testAggregation(getExpectedValue(0, 2), getSequenceBlocks(0, 2)); + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}, new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}); + testNonTrivialAggregation(new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}, new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}); + } + + private void testNonTrivialAggregation(Double[] y, Double[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + double expected = (double) regression.getTotalSumSquares(); + checkArgument(Double.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createDoublesBlock(y), createDoublesBlock(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrAvgxAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrAvgxAggregation.java new file mode 100644 index 000000000000..f0932737eb03 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrAvgxAggregation.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; +import static com.facebook.presto.block.BlockAssertions.createSequenceBlockOfReal; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestRealRegrAvgxAggregation + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createSequenceBlockOfReal(start, start + length), createSequenceBlockOfReal(start + 2, start + 2 + length)}; + } + + @Override + protected String getFunctionName() + { + return "regr_avgx"; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.REAL, StandardTypes.REAL); + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length == 0) { + return 0.0f; + } + + float expected = 0.0f; + for (int i = start; i < start + length; i++) { + expected += (i + 2); + } + return expected / length; + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}); + testNonTrivialAggregation(new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}, new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + } + + private void testNonTrivialAggregation(Float[] y, Float[] x) + { + float expected = 0.0f; + for (int i = 0; i < x.length; i++) { + expected += x[i]; + } + expected = expected / x.length; + checkArgument(Float.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createBlockOfReals(y), createBlockOfReals(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrAvgyAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrAvgyAggregation.java new file mode 100644 index 000000000000..cf936d659a0a --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrAvgyAggregation.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; +import static com.facebook.presto.block.BlockAssertions.createSequenceBlockOfReal; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestRealRegrAvgyAggregation + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createSequenceBlockOfReal(start, start + length), createSequenceBlockOfReal(start + 2, start + 2 + length)}; + } + + @Override + protected String getFunctionName() + { + return "regr_avgy"; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.REAL, StandardTypes.REAL); + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length == 0) { + return 0.0f; + } + + float expected = 0.0f; + for (int i = start; i < start + length; i++) { + expected += i; + } + return expected / length; + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}); + testNonTrivialAggregation(new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}, new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + } + + private void testNonTrivialAggregation(Float[] y, Float[] x) + { + float expected = 0.0f; + for (int i = 0; i < y.length; i++) { + expected += y[i]; + } + expected = expected / y.length; + checkArgument(Float.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createBlockOfReals(y), createBlockOfReals(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrCountAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrCountAggregation.java new file mode 100644 index 000000000000..4d20dada35c6 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrCountAggregation.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; +import static com.facebook.presto.block.BlockAssertions.createSequenceBlockOfReal; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestRealRegrCountAggregation + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createSequenceBlockOfReal(start, start + length), createSequenceBlockOfReal(start + 2, start + 2 + length)}; + } + + @Override + protected String getFunctionName() + { + return "regr_count"; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.REAL, StandardTypes.REAL); + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (float) length; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (float) regression.getN(); + } + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}); + testNonTrivialAggregation(new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}, new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + } + + private void testNonTrivialAggregation(Float[] y, Float[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + float expected = (float) regression.getN(); + checkArgument(Float.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createBlockOfReals(y), createBlockOfReals(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrR2Aggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrR2Aggregation.java new file mode 100644 index 000000000000..b41c01765e41 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrR2Aggregation.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; +import static com.facebook.presto.block.BlockAssertions.createSequenceBlockOfReal; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestRealRegrR2Aggregation + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createSequenceBlockOfReal(start, start + length), createSequenceBlockOfReal(start + 2, start + 2 + length)}; + } + + @Override + protected String getFunctionName() + { + return "regr_r2"; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.REAL, StandardTypes.REAL); + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 0) { + return null; + } + else if (length == 1) { + return (float) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (float) regression.getRSquare(); + } + } + + @Test + public void testSinglePosition() + { + testAggregation(getExpectedValue(0, 2), getSequenceBlocks(0, 2)); + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}); + testNonTrivialAggregation(new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}, new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + } + + private void testNonTrivialAggregation(Float[] y, Float[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + float expected = (float) regression.getRSquare(); + checkArgument(Float.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createBlockOfReals(y), createBlockOfReals(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSxxAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSxxAggregation.java new file mode 100644 index 000000000000..8566d83ba1fe --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSxxAggregation.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; +import static com.facebook.presto.block.BlockAssertions.createSequenceBlockOfReal; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestRealRegrSxxAggregation + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createSequenceBlockOfReal(start, start + length), createSequenceBlockOfReal(start + 2, start + 2 + length)}; + } + + @Override + protected String getFunctionName() + { + return "regr_sxx"; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.REAL, StandardTypes.REAL); + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (float) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (float) regression.getXSumSquares(); + } + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}); + testNonTrivialAggregation(new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}, new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + } + + private void testNonTrivialAggregation(Float[] y, Float[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + float expected = (float) regression.getXSumSquares(); + checkArgument(Float.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createBlockOfReals(y), createBlockOfReals(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSxyAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSxyAggregation.java new file mode 100644 index 000000000000..f3b9eab48ec5 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSxyAggregation.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; +import static com.facebook.presto.block.BlockAssertions.createSequenceBlockOfReal; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestRealRegrSxyAggregation + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createSequenceBlockOfReal(start, start + length), createSequenceBlockOfReal(start + 2, start + 2 + length)}; + } + + @Override + protected String getFunctionName() + { + return "regr_sxy"; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.REAL, StandardTypes.REAL); + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (float) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (float) regression.getSumOfCrossProducts(); + } + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}); + testNonTrivialAggregation(new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}, new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + } + + private void testNonTrivialAggregation(Float[] y, Float[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + float expected = (float) regression.getSumOfCrossProducts(); + checkArgument(Float.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createBlockOfReals(y), createBlockOfReals(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSyyAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSyyAggregation.java new file mode 100644 index 000000000000..163200798b63 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSyyAggregation.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; +import static com.facebook.presto.block.BlockAssertions.createSequenceBlockOfReal; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestRealRegrSyyAggregation + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createSequenceBlockOfReal(start, start + length), createSequenceBlockOfReal(start + 2, start + 2 + length)}; + } + + @Override + protected String getFunctionName() + { + return "regr_syy"; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.REAL, StandardTypes.REAL); + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (float) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (float) regression.getTotalSumSquares(); + } + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}); + testNonTrivialAggregation(new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}, new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + } + + private void testNonTrivialAggregation(Float[] y, Float[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + float expected = (float) regression.getTotalSumSquares(); + checkArgument(Float.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createBlockOfReals(y), createBlockOfReals(x)); + } +}