diff --git a/presto-docs/src/main/sphinx/functions/aggregate.rst b/presto-docs/src/main/sphinx/functions/aggregate.rst index b04db2091be5..1fed8619a1a4 100644 --- a/presto-docs/src/main/sphinx/functions/aggregate.rst +++ b/presto-docs/src/main/sphinx/functions/aggregate.rst @@ -590,6 +590,41 @@ Statistical Aggregate Functions Returns linear regression slope of input values. ``y`` is the dependent value. ``x`` is the independent value. +.. function:: regr_avgx(y, x) -> double + + Returns the average of the independent value in a group. ``y`` is the dependent + value. ``x`` is the independent value. + +.. function:: regr_avgy(y, x) -> double + + Returns the average of the dependent value in a group. ``y`` is the dependent + value. ``x`` is the independent value. + +.. function:: regr_count(y, x) -> double + + Returns the number of non-null pairs of input values. ``y`` is the dependent + value. ``x`` is the independent value. + +.. function:: regr_r2(y, x) -> double + + Returns the coefficient of determination of the linear regression. ``y`` is the dependent + value. ``x`` is the independent value. + +.. function:: regr_sxy(y, x) -> double + + Returns the sum of the product of the dependent and independent values in a group. ``y`` is the dependent + value. ``x`` is the independent value. + +.. function:: regr_syy(y, x) -> double + + Returns the sum of the squares of the dependent values in a group. ``y`` is the dependent + value. ``x`` is the independent value. + +.. function:: regr_sxx(y, x) -> double + + Returns the sum of the squares of the independent values in a group. ``y`` is the dependent + value. ``x`` is the independent value. + .. function:: skewness(x) -> double Returns the skewness of all input values. diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java index 8a7700678128..57bf94696343 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java @@ -145,8 +145,10 @@ public static double getCorrelation(CorrelationState state) public static void updateRegressionState(RegressionState state, double x, double y) { double oldMeanX = state.getMeanX(); + double oldMeanY = state.getMeanY(); updateCovarianceState(state, x, y); state.setM2X(state.getM2X() + (x - oldMeanX) * (x - state.getMeanX())); + state.setM2Y(state.getM2Y() + (y - oldMeanY) * (y - state.getMeanY())); } public static double getRegressionSlope(RegressionState state) @@ -167,6 +169,41 @@ public static double getRegressionIntercept(RegressionState state) return meanY - slope * meanX; } + public static double getRegressionAvgy(RegressionState state) + { + return state.getMeanY(); + } + + public static double getRegressionAvgx(RegressionState state) + { + return state.getMeanX(); + } + + public static double getRegressionSxx(RegressionState state) + { + return state.getM2X(); + } + + public static double getRegressionSxy(RegressionState state) + { + return state.getC2(); + } + + public static double getRegressionSyy(RegressionState state) + { + return state.getM2Y(); + } + + public static double getRegressionR2(RegressionState state) + { + return Math.pow(state.getC2(), 2) / (state.getM2X() * state.getM2Y()); + } + + public static double getRegressionCount(RegressionState state) + { + return state.getCount(); + } + public static void mergeVarianceState(VarianceState state, VarianceState otherState) { long count = otherState.getCount(); @@ -265,6 +302,7 @@ public static void mergeRegressionState(RegressionState state, RegressionState o long na = state.getCount(); long nb = otherState.getCount(); state.setM2X(state.getM2X() + otherState.getM2X() + na * nb * Math.pow(state.getMeanX() - otherState.getMeanX(), 2) / (double) (na + nb)); + state.setM2Y(state.getM2Y() + otherState.getM2Y() + na * nb * Math.pow(state.getMeanY() - otherState.getMeanY(), 2) / (double) (na + nb)); updateCovarianceState(state, otherState); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java index db3ad26ec5d6..e42699f2a86c 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java @@ -24,8 +24,15 @@ import com.facebook.presto.spi.function.SqlType; import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionIntercept; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSlope; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; import static com.facebook.presto.operator.aggregation.AggregationUtils.mergeRegressionState; import static com.facebook.presto.operator.aggregation.AggregationUtils.updateRegressionState; @@ -71,4 +78,95 @@ public static void regrIntercept(@AggregationState RegressionState state, BlockB out.appendNull(); } } + + @AggregationFunction("regr_sxy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSxy(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionSxy(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_sxx") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSxx(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionSxx(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_syy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSyy(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionSyy(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_r2") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrR2(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionR2(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_count") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrCount(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionCount(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrAvgy(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionAvgy(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgx") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrAvgx(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionAvgx(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java index a75222bfa93c..7b2943f8fac2 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java @@ -24,8 +24,15 @@ import com.facebook.presto.spi.function.SqlType; import static com.facebook.presto.common.type.RealType.REAL; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionIntercept; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSlope; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; import static java.lang.Float.floatToRawIntBits; import static java.lang.Float.intBitsToFloat; @@ -71,4 +78,95 @@ public static void regrIntercept(@AggregationState RegressionState state, BlockB out.appendNull(); } } + + @AggregationFunction("regr_sxy") + @OutputFunction(StandardTypes.REAL) + public static void regrSxy(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionSxy(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_sxx") + @OutputFunction(StandardTypes.REAL) + public static void regrSxx(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionSxx(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_syy") + @OutputFunction(StandardTypes.REAL) + public static void regrSyy(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionSyy(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_r2") + @OutputFunction(StandardTypes.REAL) + public static void regrR2(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionR2(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_count") + @OutputFunction(StandardTypes.REAL) + public static void regrCount(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionCount(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgy") + @OutputFunction(StandardTypes.REAL) + public static void regrAvgy(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionAvgy(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgx") + @OutputFunction(StandardTypes.REAL) + public static void regrAvgx(@AggregationState RegressionState state, BlockBuilder out) + { + double result = getRegressionAvgx(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java index ae3af6f46dc4..79837f90c0c1 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java @@ -19,4 +19,8 @@ public interface RegressionState double getM2X(); void setM2X(double value); + + double getM2Y(); + + void setM2Y(double value); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AbstractTestDoubleRegrAggregationFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AbstractTestDoubleRegrAggregationFunction.java new file mode 100644 index 000000000000..deb2f174d656 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AbstractTestDoubleRegrAggregationFunction.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createDoubleSequenceBlock; + +public abstract class AbstractTestDoubleRegrAggregationFunction + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createDoubleSequenceBlock(start, start + length), createDoubleSequenceBlock(start + 2, start + 2 + length)}; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.DOUBLE, StandardTypes.DOUBLE); + } + + @Test + public void testSinglePosition() + { + testAggregation(getExpectedValue(0, 2), getSequenceBlocks(0, 2)); + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}, new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}); + testNonTrivialAggregation(new Double[] {1.0, 4.0, 9.0, 16.0, 25.0}, new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}); + } + + protected abstract void testNonTrivialAggregation(Double[] y, Double[] x); +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AbstractTestRealRegrAggregationFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AbstractTestRealRegrAggregationFunction.java new file mode 100644 index 000000000000..81cbe269cf03 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AbstractTestRealRegrAggregationFunction.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.block.BlockAssertions.createSequenceBlockOfReal; + +public abstract class AbstractTestRealRegrAggregationFunction + extends AbstractTestAggregationFunction +{ + @Override + public Block[] getSequenceBlocks(int start, int length) + { + return new Block[] {createSequenceBlockOfReal(start, start + length), createSequenceBlockOfReal(start + 2, start + 2 + length)}; + } + + @Override + protected List getFunctionParameterTypes() + { + return ImmutableList.of(StandardTypes.REAL, StandardTypes.REAL); + } + + @Test + public void testSinglePosition() + { + testAggregation(getExpectedValue(0, 2), getSequenceBlocks(0, 2)); + } + + @Test + public void testNonTrivialResult() + { + testNonTrivialAggregation(new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}); + testNonTrivialAggregation(new Float[] {1.0f, 4.0f, 9.0f, 16.0f, 25.0f}, new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + } + + protected abstract void testNonTrivialAggregation(Float[] y, Float[] x); +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrAvgxAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrAvgxAggregation.java new file mode 100644 index 000000000000..22c39514c64c --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrAvgxAggregation.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestDoubleRegrAvgxAggregation + extends AbstractTestDoubleRegrAggregationFunction +{ + @Override + protected String getFunctionName() + { + return "regr_avgx"; + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length == 0) { + return 0.0; + } + + double expected = 0.0; + for (int i = start; i < start + length; i++) { + expected += (i + 2); + } + return expected / length; + } + + @Override + protected void testNonTrivialAggregation(Double[] y, Double[] x) + { + double expected = 0.0; + for (int i = 0; i < x.length; i++) { + expected += x[i]; + } + expected = expected / x.length; + checkArgument(Double.isFinite(expected) && expected != 0., "Expected result is trivial"); + testAggregation(expected, createDoublesBlock(y), createDoublesBlock(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrAvgyAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrAvgyAggregation.java new file mode 100644 index 000000000000..02446f7e716e --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrAvgyAggregation.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestDoubleRegrAvgyAggregation + extends AbstractTestDoubleRegrAggregationFunction +{ + @Override + protected String getFunctionName() + { + return "regr_avgy"; + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length == 0) { + return 0.0; + } + + double expected = 0.0; + for (int i = start; i < start + length; i++) { + expected += i; + } + return expected / length; + } + + @Override + protected void testNonTrivialAggregation(Double[] y, Double[] x) + { + double expected = 0.0; + for (int i = 0; i < y.length; i++) { + expected += y[i]; + } + expected = expected / y.length; + checkArgument(Double.isFinite(expected) && expected != 0., "Expected result is trivial"); + testAggregation(expected, createDoublesBlock(y), createDoublesBlock(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrCountAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrCountAggregation.java new file mode 100644 index 000000000000..8a471cfd8c82 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrCountAggregation.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import org.apache.commons.math3.stat.regression.SimpleRegression; + +import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestDoubleRegrCountAggregation + extends AbstractTestDoubleRegrAggregationFunction +{ + @Override + protected String getFunctionName() + { + return "regr_count"; + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (double) length; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (double) regression.getN(); + } + } + + @Override + protected void testNonTrivialAggregation(Double[] y, Double[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + double expected = (double) regression.getN(); + checkArgument(Double.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createDoublesBlock(y), createDoublesBlock(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrR2Aggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrR2Aggregation.java new file mode 100644 index 000000000000..db0337a00079 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrR2Aggregation.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import org.apache.commons.math3.stat.regression.SimpleRegression; + +import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestDoubleRegrR2Aggregation + extends AbstractTestDoubleRegrAggregationFunction +{ + @Override + protected String getFunctionName() + { + return "regr_r2"; + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 0) { + return null; + } + else if (length == 1) { + return (double) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (double) regression.getRSquare(); + } + } + + @Override + protected void testNonTrivialAggregation(Double[] y, Double[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + double expected = (double) regression.getRSquare(); + checkArgument(Double.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createDoublesBlock(y), createDoublesBlock(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSxxAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSxxAggregation.java new file mode 100644 index 000000000000..7d07463e0603 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSxxAggregation.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import org.apache.commons.math3.stat.regression.SimpleRegression; + +import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestDoubleRegrSxxAggregation + extends AbstractTestDoubleRegrAggregationFunction +{ + @Override + protected String getFunctionName() + { + return "regr_sxx"; + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (double) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (double) regression.getXSumSquares(); + } + } + + @Override + protected void testNonTrivialAggregation(Double[] y, Double[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + double expected = (double) regression.getXSumSquares(); + checkArgument(Double.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createDoublesBlock(y), createDoublesBlock(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSxyAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSxyAggregation.java new file mode 100644 index 000000000000..d7b37ad9ac1a --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSxyAggregation.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import org.apache.commons.math3.stat.regression.SimpleRegression; + +import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestDoubleRegrSxyAggregation + extends AbstractTestDoubleRegrAggregationFunction +{ + @Override + protected String getFunctionName() + { + return "regr_sxy"; + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (double) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (double) regression.getSumOfCrossProducts(); + } + } + + @Override + protected void testNonTrivialAggregation(Double[] y, Double[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + double expected = (double) regression.getSumOfCrossProducts(); + checkArgument(Double.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createDoublesBlock(y), createDoublesBlock(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSyyAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSyyAggregation.java new file mode 100644 index 000000000000..49de8dbe1d63 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrSyyAggregation.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import org.apache.commons.math3.stat.regression.SimpleRegression; + +import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestDoubleRegrSyyAggregation + extends AbstractTestDoubleRegrAggregationFunction +{ + @Override + protected String getFunctionName() + { + return "regr_syy"; + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (double) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (double) regression.getTotalSumSquares(); + } + } + + @Override + protected void testNonTrivialAggregation(Double[] y, Double[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + double expected = (double) regression.getTotalSumSquares(); + checkArgument(Double.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createDoublesBlock(y), createDoublesBlock(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrAvgxAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrAvgxAggregation.java new file mode 100644 index 000000000000..8b2500874f1a --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrAvgxAggregation.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestRealRegrAvgxAggregation + extends AbstractTestRealRegrAggregationFunction +{ + @Override + protected String getFunctionName() + { + return "regr_avgx"; + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length == 0) { + return 0.0f; + } + + float expected = 0.0f; + for (int i = start; i < start + length; i++) { + expected += (i + 2); + } + return expected / length; + } + + @Override + protected void testNonTrivialAggregation(Float[] y, Float[] x) + { + float expected = 0.0f; + for (int i = 0; i < x.length; i++) { + expected += x[i]; + } + expected = expected / x.length; + checkArgument(Float.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createBlockOfReals(y), createBlockOfReals(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrAvgyAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrAvgyAggregation.java new file mode 100644 index 000000000000..6bb7e1ca203b --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrAvgyAggregation.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestRealRegrAvgyAggregation + extends AbstractTestRealRegrAggregationFunction +{ + @Override + protected String getFunctionName() + { + return "regr_avgy"; + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length == 0) { + return 0.0f; + } + + float expected = 0.0f; + for (int i = start; i < start + length; i++) { + expected += i; + } + return expected / length; + } + + @Override + protected void testNonTrivialAggregation(Float[] y, Float[] x) + { + float expected = 0.0f; + for (int i = 0; i < y.length; i++) { + expected += y[i]; + } + expected = expected / y.length; + checkArgument(Float.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createBlockOfReals(y), createBlockOfReals(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrCountAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrCountAggregation.java new file mode 100644 index 000000000000..cb4f9148d33e --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrCountAggregation.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import org.apache.commons.math3.stat.regression.SimpleRegression; + +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestRealRegrCountAggregation + extends AbstractTestRealRegrAggregationFunction +{ + @Override + protected String getFunctionName() + { + return "regr_count"; + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (float) length; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (float) regression.getN(); + } + } + + @Override + protected void testNonTrivialAggregation(Float[] y, Float[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + float expected = (float) regression.getN(); + checkArgument(Float.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createBlockOfReals(y), createBlockOfReals(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrR2Aggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrR2Aggregation.java new file mode 100644 index 000000000000..54cd9db4016d --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrR2Aggregation.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import org.apache.commons.math3.stat.regression.SimpleRegression; + +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestRealRegrR2Aggregation + extends AbstractTestRealRegrAggregationFunction +{ + @Override + protected String getFunctionName() + { + return "regr_r2"; + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 0) { + return null; + } + else if (length == 1) { + return (float) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (float) regression.getRSquare(); + } + } + + @Override + protected void testNonTrivialAggregation(Float[] y, Float[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + float expected = (float) regression.getRSquare(); + checkArgument(Float.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createBlockOfReals(y), createBlockOfReals(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSxxAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSxxAggregation.java new file mode 100644 index 000000000000..6bf77d3b114d --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSxxAggregation.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import org.apache.commons.math3.stat.regression.SimpleRegression; + +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestRealRegrSxxAggregation + extends AbstractTestRealRegrAggregationFunction +{ + @Override + protected String getFunctionName() + { + return "regr_sxx"; + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (float) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (float) regression.getXSumSquares(); + } + } + + @Override + protected void testNonTrivialAggregation(Float[] y, Float[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + float expected = (float) regression.getXSumSquares(); + checkArgument(Float.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createBlockOfReals(y), createBlockOfReals(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSxyAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSxyAggregation.java new file mode 100644 index 000000000000..ff4a8b94cab7 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSxyAggregation.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import org.apache.commons.math3.stat.regression.SimpleRegression; + +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestRealRegrSxyAggregation + extends AbstractTestRealRegrAggregationFunction +{ + @Override + protected String getFunctionName() + { + return "regr_sxy"; + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (float) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (float) regression.getSumOfCrossProducts(); + } + } + + @Override + protected void testNonTrivialAggregation(Float[] y, Float[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + float expected = (float) regression.getSumOfCrossProducts(); + checkArgument(Float.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createBlockOfReals(y), createBlockOfReals(x)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSyyAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSyyAggregation.java new file mode 100644 index 000000000000..cefe9eb54741 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrSyyAggregation.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import org.apache.commons.math3.stat.regression.SimpleRegression; + +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; +import static com.google.common.base.Preconditions.checkArgument; + +public class TestRealRegrSyyAggregation + extends AbstractTestRealRegrAggregationFunction +{ + @Override + protected String getFunctionName() + { + return "regr_syy"; + } + + @Override + public Object getExpectedValue(int start, int length) + { + if (length <= 1) { + return (float) 0; + } + else { + SimpleRegression regression = new SimpleRegression(); + for (int i = start; i < start + length; i++) { + regression.addData(i + 2, i); + } + return (float) regression.getTotalSumSquares(); + } + } + + @Override + protected void testNonTrivialAggregation(Float[] y, Float[] x) + { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < x.length; i++) { + regression.addData(x[i], y[i]); + } + float expected = (float) regression.getTotalSumSquares(); + checkArgument(Float.isFinite(expected) && expected != 0.0f, "Expected result is trivial"); + testAggregation(expected, createBlockOfReals(y), createBlockOfReals(x)); + } +}