diff --git a/docs/changelog/116531.yaml b/docs/changelog/116531.yaml new file mode 100644 index 0000000000000..908bbff487973 --- /dev/null +++ b/docs/changelog/116531.yaml @@ -0,0 +1,5 @@ +pr: 116531 +summary: "Add a standard deviation aggregating function: STD_DEV" +area: ES|QL +type: enhancement +issues: [] diff --git a/docs/reference/esql/functions/aggregation-functions.asciidoc b/docs/reference/esql/functions/aggregation-functions.asciidoc index 3a27e1944a684..c2c2508ad5de2 100644 --- a/docs/reference/esql/functions/aggregation-functions.asciidoc +++ b/docs/reference/esql/functions/aggregation-functions.asciidoc @@ -17,6 +17,7 @@ The <> command supports these aggregate functions: * <> * <> * experimental:[] <> +* <> * <> * <> * <> @@ -32,6 +33,7 @@ include::layout/median_absolute_deviation.asciidoc[] include::layout/min.asciidoc[] include::layout/percentile.asciidoc[] include::layout/st_centroid_agg.asciidoc[] +include::layout/std_dev.asciidoc[] include::layout/sum.asciidoc[] include::layout/top.asciidoc[] include::layout/values.asciidoc[] diff --git a/docs/reference/esql/functions/description/std_dev.asciidoc b/docs/reference/esql/functions/description/std_dev.asciidoc new file mode 100644 index 0000000000000..b78ddd7dbba13 --- /dev/null +++ b/docs/reference/esql/functions/description/std_dev.asciidoc @@ -0,0 +1,5 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Description* + +The standard deviation of a numeric field. diff --git a/docs/reference/esql/functions/examples/std_dev.asciidoc b/docs/reference/esql/functions/examples/std_dev.asciidoc new file mode 100644 index 0000000000000..2e6dc996aae9a --- /dev/null +++ b/docs/reference/esql/functions/examples/std_dev.asciidoc @@ -0,0 +1,22 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Examples* + +[source.merge.styled,esql] +---- +include::{esql-specs}/stats.csv-spec[tag=stdev] +---- +[%header.monospaced.styled,format=dsv,separator=|] +|=== +include::{esql-specs}/stats.csv-spec[tag=stdev-result] +|=== +The expression can use inline functions. For example, to calculate the standard deviation of each employee's maximum salary changes, first use `MV_MAX` on each row, and then use `STD_DEV` on the result +[source.merge.styled,esql] +---- +include::{esql-specs}/stats.csv-spec[tag=docsStatsStdDevNestedExpression] +---- +[%header.monospaced.styled,format=dsv,separator=|] +|=== +include::{esql-specs}/stats.csv-spec[tag=docsStatsStdDevNestedExpression-result] +|=== + diff --git a/docs/reference/esql/functions/kibana/definition/std_dev.json b/docs/reference/esql/functions/kibana/definition/std_dev.json new file mode 100644 index 0000000000000..f31d3345421d9 --- /dev/null +++ b/docs/reference/esql/functions/kibana/definition/std_dev.json @@ -0,0 +1,50 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", + "type" : "agg", + "name" : "std_dev", + "description" : "The standard deviation of a numeric field.", + "signatures" : [ + { + "params" : [ + { + "name" : "number", + "type" : "double", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "number", + "type" : "integer", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "number", + "type" : "long", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + } + ], + "examples" : [ + "FROM employees\n| STATS STD_DEV(height)", + "FROM employees\n| STATS stddev_salary_change = STD_DEV(MV_MAX(salary_change))" + ], + "preview" : false, + "snapshot_only" : false +} diff --git a/docs/reference/esql/functions/kibana/docs/std_dev.md b/docs/reference/esql/functions/kibana/docs/std_dev.md new file mode 100644 index 0000000000000..a6afca7b8f6b3 --- /dev/null +++ b/docs/reference/esql/functions/kibana/docs/std_dev.md @@ -0,0 +1,11 @@ + + +### STD_DEV +The standard deviation of a numeric field. + +``` +FROM employees +| STATS STD_DEV(height) +``` diff --git a/docs/reference/esql/functions/layout/std_dev.asciidoc b/docs/reference/esql/functions/layout/std_dev.asciidoc new file mode 100644 index 0000000000000..a7a34b1331d17 --- /dev/null +++ b/docs/reference/esql/functions/layout/std_dev.asciidoc @@ -0,0 +1,15 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +[discrete] +[[esql-std_dev]] +=== `STD_DEV` + +*Syntax* + +[.text-center] +image::esql/functions/signature/std_dev.svg[Embedded,opts=inline] + +include::../parameters/std_dev.asciidoc[] +include::../description/std_dev.asciidoc[] +include::../types/std_dev.asciidoc[] +include::../examples/std_dev.asciidoc[] diff --git a/docs/reference/esql/functions/parameters/std_dev.asciidoc b/docs/reference/esql/functions/parameters/std_dev.asciidoc new file mode 100644 index 0000000000000..91c56709d182a --- /dev/null +++ b/docs/reference/esql/functions/parameters/std_dev.asciidoc @@ -0,0 +1,6 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Parameters* + +`number`:: + diff --git a/docs/reference/esql/functions/signature/std_dev.svg b/docs/reference/esql/functions/signature/std_dev.svg new file mode 100644 index 0000000000000..606d285154f59 --- /dev/null +++ b/docs/reference/esql/functions/signature/std_dev.svg @@ -0,0 +1 @@ +STD_DEV(number) \ No newline at end of file diff --git a/docs/reference/esql/functions/types/std_dev.asciidoc b/docs/reference/esql/functions/types/std_dev.asciidoc new file mode 100644 index 0000000000000..273dae4af76c2 --- /dev/null +++ b/docs/reference/esql/functions/types/std_dev.asciidoc @@ -0,0 +1,11 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Supported types* + +[%header.monospaced.styled,format=dsv,separator=|] +|=== +number | result +double | double +integer | double +long | double +|=== diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index 3deac4925c951..609c778df5929 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -608,6 +608,27 @@ tasks.named('stringTemplates').configure { it.outputFile = "org/elasticsearch/compute/aggregation/RateDoubleAggregator.java" } + File stdDevAggregatorInputFile = file("src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st") + template { + it.properties = intProperties + it.inputFile = stdDevAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/StdDevIntAggregator.java" + } + template { + it.properties = longProperties + it.inputFile = stdDevAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/StdDevLongAggregator.java" + } + template { + it.properties = floatProperties + it.inputFile = stdDevAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java" + } + template { + it.properties = doubleProperties + it.inputFile = stdDevAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java" + } File topAggregatorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st") template { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java new file mode 100644 index 0000000000000..3a1185d34fa23 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * A standard deviation aggregation definition for double. + * This class is generated. Edit `X-StdDevAggregator.java.st` instead. + */ +@Aggregator( + { + @IntermediateState(name = "mean", type = "DOUBLE"), + @IntermediateState(name = "m2", type = "DOUBLE"), + @IntermediateState(name = "count", type = "LONG") } +) +@GroupingAggregator +public class StdDevDoubleAggregator { + + public static StdDevStates.SingleState initSingle() { + return new StdDevStates.SingleState(); + } + + public static void combine(StdDevStates.SingleState state, double value) { + state.add(value); + } + + public static void combineIntermediate(StdDevStates.SingleState state, double mean, double m2, long count) { + state.combine(mean, m2, count); + } + + public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) { + return state.evaluateFinal(driverContext); + } + + public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) { + return new StdDevStates.GroupingState(bigArrays); + } + + public static void combine(StdDevStates.GroupingState current, int groupId, double value) { + current.add(groupId, value); + } + + public static void combineStates(StdDevStates.GroupingState current, int groupId, StdDevStates.GroupingState state, int statePosition) { + current.combine(groupId, state.getOrNull(statePosition)); + } + + public static void combineIntermediate(StdDevStates.GroupingState state, int groupId, double mean, double m2, long count) { + state.combine(groupId, mean, m2, count); + } + + public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) { + return state.evaluateFinal(selected, driverContext); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java new file mode 100644 index 0000000000000..51c22e7e29c1e --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * A standard deviation aggregation definition for float. + * This class is generated. Edit `X-StdDevAggregator.java.st` instead. + */ +@Aggregator( + { + @IntermediateState(name = "mean", type = "DOUBLE"), + @IntermediateState(name = "m2", type = "DOUBLE"), + @IntermediateState(name = "count", type = "LONG") } +) +@GroupingAggregator +public class StdDevFloatAggregator { + + public static StdDevStates.SingleState initSingle() { + return new StdDevStates.SingleState(); + } + + public static void combine(StdDevStates.SingleState state, float value) { + state.add(value); + } + + public static void combineIntermediate(StdDevStates.SingleState state, double mean, double m2, long count) { + state.combine(mean, m2, count); + } + + public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) { + return state.evaluateFinal(driverContext); + } + + public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) { + return new StdDevStates.GroupingState(bigArrays); + } + + public static void combine(StdDevStates.GroupingState current, int groupId, float value) { + current.add(groupId, value); + } + + public static void combineStates(StdDevStates.GroupingState current, int groupId, StdDevStates.GroupingState state, int statePosition) { + current.combine(groupId, state.getOrNull(statePosition)); + } + + public static void combineIntermediate(StdDevStates.GroupingState state, int groupId, double mean, double m2, long count) { + state.combine(groupId, mean, m2, count); + } + + public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) { + return state.evaluateFinal(selected, driverContext); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java new file mode 100644 index 0000000000000..24eae35cb3249 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * A standard deviation aggregation definition for int. + * This class is generated. Edit `X-StdDevAggregator.java.st` instead. + */ +@Aggregator( + { + @IntermediateState(name = "mean", type = "DOUBLE"), + @IntermediateState(name = "m2", type = "DOUBLE"), + @IntermediateState(name = "count", type = "LONG") } +) +@GroupingAggregator +public class StdDevIntAggregator { + + public static StdDevStates.SingleState initSingle() { + return new StdDevStates.SingleState(); + } + + public static void combine(StdDevStates.SingleState state, int value) { + state.add(value); + } + + public static void combineIntermediate(StdDevStates.SingleState state, double mean, double m2, long count) { + state.combine(mean, m2, count); + } + + public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) { + return state.evaluateFinal(driverContext); + } + + public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) { + return new StdDevStates.GroupingState(bigArrays); + } + + public static void combine(StdDevStates.GroupingState current, int groupId, int value) { + current.add(groupId, value); + } + + public static void combineStates(StdDevStates.GroupingState current, int groupId, StdDevStates.GroupingState state, int statePosition) { + current.combine(groupId, state.getOrNull(statePosition)); + } + + public static void combineIntermediate(StdDevStates.GroupingState state, int groupId, double mean, double m2, long count) { + state.combine(groupId, mean, m2, count); + } + + public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) { + return state.evaluateFinal(selected, driverContext); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java new file mode 100644 index 0000000000000..888ace30a0c8e --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * A standard deviation aggregation definition for long. + * This class is generated. Edit `X-StdDevAggregator.java.st` instead. + */ +@Aggregator( + { + @IntermediateState(name = "mean", type = "DOUBLE"), + @IntermediateState(name = "m2", type = "DOUBLE"), + @IntermediateState(name = "count", type = "LONG") } +) +@GroupingAggregator +public class StdDevLongAggregator { + + public static StdDevStates.SingleState initSingle() { + return new StdDevStates.SingleState(); + } + + public static void combine(StdDevStates.SingleState state, long value) { + state.add(value); + } + + public static void combineIntermediate(StdDevStates.SingleState state, double mean, double m2, long count) { + state.combine(mean, m2, count); + } + + public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) { + return state.evaluateFinal(driverContext); + } + + public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) { + return new StdDevStates.GroupingState(bigArrays); + } + + public static void combine(StdDevStates.GroupingState current, int groupId, long value) { + current.add(groupId, value); + } + + public static void combineStates(StdDevStates.GroupingState current, int groupId, StdDevStates.GroupingState state, int statePosition) { + current.combine(groupId, state.getOrNull(statePosition)); + } + + public static void combineIntermediate(StdDevStates.GroupingState state, int groupId, double mean, double m2, long count) { + state.combine(groupId, mean, m2, count); + } + + public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) { + return state.evaluateFinal(selected, driverContext); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunction.java new file mode 100644 index 0000000000000..dd6cc89401a99 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunction.java @@ -0,0 +1,178 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link StdDevDoubleAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDevDoubleAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final StdDevStates.SingleState state; + + private final List channels; + + public StdDevDoubleAggregatorFunction(DriverContext driverContext, List channels, + StdDevStates.SingleState state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static StdDevDoubleAggregatorFunction create(DriverContext driverContext, + List channels) { + return new StdDevDoubleAggregatorFunction(driverContext, channels, StdDevDoubleAggregator.initSingle()); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + return; + } + if (mask.allTrue()) { + // No masking + DoubleBlock block = page.getBlock(channels.get(0)); + DoubleVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + DoubleBlock block = page.getBlock(channels.get(0)); + DoubleVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(DoubleVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + StdDevDoubleAggregator.combine(state, vector.getDouble(i)); + } + } + + private void addRawVector(DoubleVector vector, BooleanVector mask) { + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + StdDevDoubleAggregator.combine(state, vector.getDouble(i)); + } + } + + private void addRawBlock(DoubleBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDevDoubleAggregator.combine(state, block.getDouble(i)); + } + } + } + + private void addRawBlock(DoubleBlock block, BooleanVector mask) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDevDoubleAggregator.combine(state, block.getDouble(i)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + assert mean.getPositionCount() == 1; + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + assert m2.getPositionCount() == 1; + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert count.getPositionCount() == 1; + StdDevDoubleAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = StdDevDoubleAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..313eed4ae97ae --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunctionSupplier.java @@ -0,0 +1,38 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link StdDevDoubleAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDevDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + public StdDevDoubleAggregatorFunctionSupplier(List channels) { + this.channels = channels; + } + + @Override + public StdDevDoubleAggregatorFunction aggregator(DriverContext driverContext) { + return StdDevDoubleAggregatorFunction.create(driverContext, channels); + } + + @Override + public StdDevDoubleGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return StdDevDoubleGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "std_dev of doubles"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..da49c254e353a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java @@ -0,0 +1,223 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link StdDevDoubleAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDevDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final StdDevStates.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + public StdDevDoubleGroupingAggregatorFunction(List channels, + StdDevStates.GroupingState state, DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static StdDevDoubleGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new StdDevDoubleGroupingAggregatorFunction(channels, StdDevDoubleAggregator.initGrouping(driverContext.bigArrays()), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + DoubleBlock valuesBlock = page.getBlock(channels.get(0)); + DoubleVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDevDoubleAggregator.combine(state, groupId, values.getDouble(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDevDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); + } + } + + private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDevDoubleAggregator.combine(state, groupId, values.getDouble(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); + } + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDevDoubleAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + StdDevStates.GroupingState inState = ((StdDevDoubleGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + StdDevDoubleAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext driverContext) { + blocks[offset] = StdDevDoubleAggregator.evaluateFinal(state, selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunction.java new file mode 100644 index 0000000000000..bf8c4854f6b93 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunction.java @@ -0,0 +1,180 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link StdDevFloatAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDevFloatAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final StdDevStates.SingleState state; + + private final List channels; + + public StdDevFloatAggregatorFunction(DriverContext driverContext, List channels, + StdDevStates.SingleState state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static StdDevFloatAggregatorFunction create(DriverContext driverContext, + List channels) { + return new StdDevFloatAggregatorFunction(driverContext, channels, StdDevFloatAggregator.initSingle()); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + return; + } + if (mask.allTrue()) { + // No masking + FloatBlock block = page.getBlock(channels.get(0)); + FloatVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + FloatBlock block = page.getBlock(channels.get(0)); + FloatVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(FloatVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + StdDevFloatAggregator.combine(state, vector.getFloat(i)); + } + } + + private void addRawVector(FloatVector vector, BooleanVector mask) { + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + StdDevFloatAggregator.combine(state, vector.getFloat(i)); + } + } + + private void addRawBlock(FloatBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDevFloatAggregator.combine(state, block.getFloat(i)); + } + } + } + + private void addRawBlock(FloatBlock block, BooleanVector mask) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDevFloatAggregator.combine(state, block.getFloat(i)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + assert mean.getPositionCount() == 1; + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + assert m2.getPositionCount() == 1; + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert count.getPositionCount() == 1; + StdDevFloatAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = StdDevFloatAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..25dfa54895eda --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunctionSupplier.java @@ -0,0 +1,38 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link StdDevFloatAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDevFloatAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + public StdDevFloatAggregatorFunctionSupplier(List channels) { + this.channels = channels; + } + + @Override + public StdDevFloatAggregatorFunction aggregator(DriverContext driverContext) { + return StdDevFloatAggregatorFunction.create(driverContext, channels); + } + + @Override + public StdDevFloatGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return StdDevFloatGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "std_dev of floats"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..bf994aaf2840e --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java @@ -0,0 +1,225 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link StdDevFloatAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDevFloatGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final StdDevStates.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + public StdDevFloatGroupingAggregatorFunction(List channels, + StdDevStates.GroupingState state, DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static StdDevFloatGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new StdDevFloatGroupingAggregatorFunction(channels, StdDevFloatAggregator.initGrouping(driverContext.bigArrays()), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + FloatBlock valuesBlock = page.getBlock(channels.get(0)); + FloatVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDevFloatAggregator.combine(state, groupId, values.getFloat(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, FloatVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDevFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); + } + } + + private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDevFloatAggregator.combine(state, groupId, values.getFloat(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, FloatVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); + } + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDevFloatAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + StdDevStates.GroupingState inState = ((StdDevFloatGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + StdDevFloatAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext driverContext) { + blocks[offset] = StdDevFloatAggregator.evaluateFinal(state, selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunction.java new file mode 100644 index 0000000000000..4a5585a7dd454 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunction.java @@ -0,0 +1,180 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link StdDevIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDevIntAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final StdDevStates.SingleState state; + + private final List channels; + + public StdDevIntAggregatorFunction(DriverContext driverContext, List channels, + StdDevStates.SingleState state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static StdDevIntAggregatorFunction create(DriverContext driverContext, + List channels) { + return new StdDevIntAggregatorFunction(driverContext, channels, StdDevIntAggregator.initSingle()); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + return; + } + if (mask.allTrue()) { + // No masking + IntBlock block = page.getBlock(channels.get(0)); + IntVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + IntBlock block = page.getBlock(channels.get(0)); + IntVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(IntVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + StdDevIntAggregator.combine(state, vector.getInt(i)); + } + } + + private void addRawVector(IntVector vector, BooleanVector mask) { + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + StdDevIntAggregator.combine(state, vector.getInt(i)); + } + } + + private void addRawBlock(IntBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDevIntAggregator.combine(state, block.getInt(i)); + } + } + } + + private void addRawBlock(IntBlock block, BooleanVector mask) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDevIntAggregator.combine(state, block.getInt(i)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + assert mean.getPositionCount() == 1; + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + assert m2.getPositionCount() == 1; + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert count.getPositionCount() == 1; + StdDevIntAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = StdDevIntAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..5a762d6606a25 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunctionSupplier.java @@ -0,0 +1,38 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link StdDevIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDevIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + public StdDevIntAggregatorFunctionSupplier(List channels) { + this.channels = channels; + } + + @Override + public StdDevIntAggregatorFunction aggregator(DriverContext driverContext) { + return StdDevIntAggregatorFunction.create(driverContext, channels); + } + + @Override + public StdDevIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return StdDevIntGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "std_dev of ints"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..139cc24d3541f --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java @@ -0,0 +1,223 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link StdDevIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDevIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final StdDevStates.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + public StdDevIntGroupingAggregatorFunction(List channels, + StdDevStates.GroupingState state, DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static StdDevIntGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new StdDevIntGroupingAggregatorFunction(channels, StdDevIntAggregator.initGrouping(driverContext.bigArrays()), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + IntBlock valuesBlock = page.getBlock(channels.get(0)); + IntVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDevIntAggregator.combine(state, groupId, values.getInt(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDevIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); + } + } + + private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDevIntAggregator.combine(state, groupId, values.getInt(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, IntVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); + } + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDevIntAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + StdDevStates.GroupingState inState = ((StdDevIntGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + StdDevIntAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext driverContext) { + blocks[offset] = StdDevIntAggregator.evaluateFinal(state, selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunction.java new file mode 100644 index 0000000000000..b5ed31116a90c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunction.java @@ -0,0 +1,178 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link StdDevLongAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDevLongAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final StdDevStates.SingleState state; + + private final List channels; + + public StdDevLongAggregatorFunction(DriverContext driverContext, List channels, + StdDevStates.SingleState state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static StdDevLongAggregatorFunction create(DriverContext driverContext, + List channels) { + return new StdDevLongAggregatorFunction(driverContext, channels, StdDevLongAggregator.initSingle()); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + return; + } + if (mask.allTrue()) { + // No masking + LongBlock block = page.getBlock(channels.get(0)); + LongVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + LongBlock block = page.getBlock(channels.get(0)); + LongVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(LongVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + StdDevLongAggregator.combine(state, vector.getLong(i)); + } + } + + private void addRawVector(LongVector vector, BooleanVector mask) { + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + StdDevLongAggregator.combine(state, vector.getLong(i)); + } + } + + private void addRawBlock(LongBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDevLongAggregator.combine(state, block.getLong(i)); + } + } + } + + private void addRawBlock(LongBlock block, BooleanVector mask) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDevLongAggregator.combine(state, block.getLong(i)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + assert mean.getPositionCount() == 1; + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + assert m2.getPositionCount() == 1; + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert count.getPositionCount() == 1; + StdDevLongAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = StdDevLongAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..09b996201ef16 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunctionSupplier.java @@ -0,0 +1,38 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link StdDevLongAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDevLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + public StdDevLongAggregatorFunctionSupplier(List channels) { + this.channels = channels; + } + + @Override + public StdDevLongAggregatorFunction aggregator(DriverContext driverContext) { + return StdDevLongAggregatorFunction.create(driverContext, channels); + } + + @Override + public StdDevLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return StdDevLongGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "std_dev of longs"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..da7a5f4bdea0d --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java @@ -0,0 +1,223 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link StdDevLongAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDevLongGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final StdDevStates.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + public StdDevLongGroupingAggregatorFunction(List channels, + StdDevStates.GroupingState state, DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static StdDevLongGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new StdDevLongGroupingAggregatorFunction(channels, StdDevLongAggregator.initGrouping(driverContext.bigArrays()), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + LongBlock valuesBlock = page.getBlock(channels.get(0)); + LongVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDevLongAggregator.combine(state, groupId, values.getLong(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDevLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); + } + } + + private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDevLongAggregator.combine(state, groupId, values.getLong(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, LongVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); + } + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDevLongAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + StdDevStates.GroupingState inState = ((StdDevLongGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + StdDevLongAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext driverContext) { + blocks[offset] = StdDevLongAggregator.evaluateFinal(state, selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java new file mode 100644 index 0000000000000..bff8903fd3bec --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java @@ -0,0 +1,211 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +public final class StdDevStates { + + private StdDevStates() {} + + static final class SingleState implements AggregatorState { + + private final WelfordAlgorithm welfordAlgorithm; + + SingleState() { + this(0, 0, 0); + } + + SingleState(double mean, double m2, long count) { + this.welfordAlgorithm = new WelfordAlgorithm(mean, m2, count); + } + + public void add(long value) { + welfordAlgorithm.add(value); + } + + public void add(double value) { + welfordAlgorithm.add(value); + } + + public void add(int value) { + welfordAlgorithm.add(value); + } + + public void combine(double mean, double m2, long count) { + welfordAlgorithm.add(mean, m2, count); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + assert blocks.length >= offset + 3; + BlockFactory blockFactory = driverContext.blockFactory(); + blocks[offset + 0] = blockFactory.newConstantDoubleBlockWith(mean(), 1); + blocks[offset + 1] = blockFactory.newConstantDoubleBlockWith(m2(), 1); + blocks[offset + 2] = blockFactory.newConstantLongBlockWith(count(), 1); + } + + @Override + public void close() {} + + public double mean() { + return welfordAlgorithm.mean(); + } + + public double m2() { + return welfordAlgorithm.m2(); + } + + public long count() { + return welfordAlgorithm.count(); + } + + public double evaluateFinal() { + return welfordAlgorithm.evaluate(); + } + + public Block evaluateFinal(DriverContext driverContext) { + final long count = count(); + final double m2 = m2(); + if (count == 0 || Double.isFinite(m2) == false) { + return driverContext.blockFactory().newConstantNullBlock(1); + } + return driverContext.blockFactory().newConstantDoubleBlockWith(evaluateFinal(), 1); + } + } + + static final class GroupingState implements GroupingAggregatorState { + + private ObjectArray states; + private final BigArrays bigArrays; + + GroupingState(BigArrays bigArrays) { + this.states = bigArrays.newObjectArray(1); + this.bigArrays = bigArrays; + } + + WelfordAlgorithm getOrNull(int position) { + if (position < states.size()) { + return states.get(position); + } else { + return null; + } + } + + public void combine(int groupId, WelfordAlgorithm state) { + if (state == null) { + return; + } + combine(groupId, state.mean(), state.m2(), state.count()); + } + + public void combine(int groupId, double meanValue, double m2Value, long countValue) { + ensureCapacity(groupId); + var state = states.get(groupId); + if (state == null) { + state = new WelfordAlgorithm(meanValue, m2Value, countValue); + states.set(groupId, state); + } else { + state.add(meanValue, m2Value, countValue); + } + } + + public WelfordAlgorithm getOrSet(int groupId) { + ensureCapacity(groupId); + var state = states.get(groupId); + if (state == null) { + state = new WelfordAlgorithm(); + states.set(groupId, state); + } + return state; + } + + public void add(int groupId, long value) { + var state = getOrSet(groupId); + state.add(value); + } + + public void add(int groupId, double value) { + var state = getOrSet(groupId); + state.add(value); + } + + public void add(int groupId, int value) { + var state = getOrSet(groupId); + state.add(value); + } + + private void ensureCapacity(int groupId) { + states = bigArrays.grow(states, groupId + 1); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + assert blocks.length >= offset + 3 : "blocks=" + blocks.length + ",offset=" + offset; + try ( + var meanBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + var m2Builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + var countBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + final var groupId = selected.getInt(i); + final var state = groupId < states.size() ? states.get(groupId) : null; + if (state != null) { + meanBuilder.appendDouble(state.mean()); + m2Builder.appendDouble(state.m2()); + countBuilder.appendLong(state.count()); + } else { + meanBuilder.appendDouble(0.0); + m2Builder.appendDouble(0.0); + countBuilder.appendLong(0); + } + } + blocks[offset + 0] = meanBuilder.build(); + blocks[offset + 1] = m2Builder.build(); + blocks[offset + 2] = countBuilder.build(); + } + } + + public Block evaluateFinal(IntVector selected, DriverContext driverContext) { + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + final var groupId = selected.getInt(i); + final var st = getOrNull(groupId); + if (st != null) { + final var m2 = st.m2(); + final var count = st.count(); + if (count == 0 || Double.isFinite(m2) == false) { + builder.appendNull(); + } else { + builder.appendDouble(st.evaluate()); + } + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + @Override + public void close() { + Releasables.close(states); + } + + void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java new file mode 100644 index 0000000000000..8ccb985507247 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +/** + * Algorithm for calculating standard deviation, one value at a time. + * + * @see + * Welford's_online_algorithm and + * + * Parallel algorithm + */ +public final class WelfordAlgorithm { + private double mean; + private double m2; + private long count; + + public double mean() { + return mean; + } + + public double m2() { + return m2; + } + + public long count() { + return count; + } + + public WelfordAlgorithm() { + this(0, 0, 0); + } + + public WelfordAlgorithm(double mean, double m2, long count) { + this.mean = mean; + this.m2 = m2; + this.count = count; + } + + public void add(int value) { + add((double) value); + } + + public void add(long value) { + add((double) value); + } + + public void add(double value) { + final double delta = value - mean; + count += 1; + mean += delta / count; + m2 += delta * (value - mean); + } + + public void add(double meanValue, double m2Value, long countValue) { + if (countValue == 0) { + return; + } + if (count == 0) { + mean = meanValue; + m2 = m2Value; + count = countValue; + return; + } + double delta = mean - meanValue; + m2 += m2Value + delta * delta * count * countValue / (count + countValue); + mean = (mean * count + meanValue * countValue) / (count + countValue); + count += countValue; + } + + public double evaluate() { + return count < 2 ? 0 : Math.sqrt(m2 / count); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st new file mode 100644 index 0000000000000..510d770f90d62 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * A standard deviation aggregation definition for $type$. + * This class is generated. Edit `X-StdDevAggregator.java.st` instead. + */ +@Aggregator( + { + @IntermediateState(name = "mean", type = "DOUBLE"), + @IntermediateState(name = "m2", type = "DOUBLE"), + @IntermediateState(name = "count", type = "LONG") } +) +@GroupingAggregator +public class StdDev$Type$Aggregator { + + public static StdDevStates.SingleState initSingle() { + return new StdDevStates.SingleState(); + } + + public static void combine(StdDevStates.SingleState state, $type$ value) { + state.add(value); + } + + public static void combineIntermediate(StdDevStates.SingleState state, double mean, double m2, long count) { + state.combine(mean, m2, count); + } + + public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) { + return state.evaluateFinal(driverContext); + } + + public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) { + return new StdDevStates.GroupingState(bigArrays); + } + + public static void combine(StdDevStates.GroupingState current, int groupId, $type$ value) { + current.add(groupId, value); + } + + public static void combineStates(StdDevStates.GroupingState current, int groupId, StdDevStates.GroupingState state, int statePosition) { + current.combine(groupId, state.getOrNull(statePosition)); + } + + public static void combineIntermediate(StdDevStates.GroupingState state, int groupId, double mean, double m2, long count) { + state.combine(groupId, mean, m2, count); + } + + public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) { + return state.evaluateFinal(selected, driverContext); + } +} diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec index 859f06ed5f22e..804e5eabea949 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec @@ -2883,3 +2883,143 @@ max:integer | job_positions:keyword 39878 | Business Analyst 67492 | Data Scientist ; + +stdDeviation +required_capability: std_dev +// tag::stdev[] +FROM employees +| STATS STD_DEV(height) +// end::stdev[] +; + +// tag::stdev-result[] +STD_DEV(height):double +0.20637044362020449 +// end::stdev-result[] +; + +stdDeviationNested +required_capability: std_dev +// tag::docsStatsStdDevNestedExpression[] +FROM employees +| STATS stddev_salary_change = STD_DEV(MV_MAX(salary_change)) +// end::docsStatsStdDevNestedExpression[] +; + +// tag::docsStatsStdDevNestedExpression-result[] +stddev_salary_change:double +6.875829592924112 +// end::docsStatsStdDevNestedExpression-result[] +; + + +stdDeviationWithLongs +required_capability: std_dev +FROM employees +| STATS STD_DEV(avg_worked_seconds) +; + +STD_DEV(avg_worked_seconds):double +5.76010425971634E7 +; + +stdDeviationWithInts +required_capability: std_dev +FROM employees +| STATS STD_DEV(salary) +; + +STD_DEV(salary):double +13765.12550278783 +; + +stdDeviationConstantValue +required_capability: std_dev +FROM employees +| WHERE languages == 2 +| STATS STD_DEV(languages) +; + +STD_DEV(languages):double +0.0 +; + +stdDeviationGroupedDoublesOnly +required_capability: std_dev +FROM employees +| STATS STD_DEV(height) BY languages +| SORT languages asc +; + +STD_DEV(height):double | languages:integer +0.22106409327010415 | 1 +0.22797190865484734 | 2 +0.18893070075713295 | 3 +0.14656141004227627 | 4 +0.17733860152780256 | 5 +0.2486543786061287 | null +; + +stdDeviationGroupedAllTypes +required_capability: std_dev +FROM employees +| WHERE languages < 3 +| STATS + double_std_dev = STD_DEV(height), + int_std_dev = STD_DEV(salary), + long_std_dev = STD_DEV(avg_worked_seconds) + BY languages +| SORT languages asc +; + +double_std_dev:double | int_std_dev:double | long_std_dev:double | languages:integer +0.22106409327010415 | 15166.244178730898 | 5.1998715922156096E7 | 1 +0.22797190865484734 | 12139.61099378116 | 5.309085506583288E7 | 2 +; + +stdDeviationNoRows +required_capability: std_dev +FROM employees +| WHERE languages IS null +| STATS STD_DEV(languages) +; + +STD_DEV(languages):double +null +; + +stdDevMultiValue +required_capability: std_dev +FROM employees +| STATS STD_DEV(salary_change) +; + +STD_DEV(salary_change):double +7.062226788733394 +; + +stdDevFilter +required_capability: std_dev +FROM employees +| STATS greater_than = STD_DEV(salary_change) WHERE languages > 3 +, less_than = STD_DEV(salary_change) WHERE languages <= 3 +, salary = STD_DEV(salary * 2) +, count = COUNT(*) BY gender +| SORT gender asc +; + +greater_than:double | less_than:double | salary:double | count:long | gender:keyword +6.4543266953142835 | 7.57786788789264 | 29045.770666969744 | 33 | F +6.975232333891946 | 6.604807075547775 | 26171.331109641273 | 57 | M +6.949207097931448 | 7.127229475750027 | 27921.220736207077 | 10 | null +; + +stdDevRow +required_capability: std_dev +ROW a = [1,2,3], b = 5 +| STATS STD_DEV(a), STD_DEV(b) +; + +STD_DEV(a):double | STD_DEV(b):double +0.816496580927726 | 0.0 +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 82942119dead4..7aa2e782bb7f8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -440,6 +440,11 @@ public enum Cap { */ PER_AGG_FILTERING_ORDS, + /** + * Support for {@code STD_DEV} aggregation. + */ + STD_DEV, + /** * Fix for https://github.com/elastic/elasticsearch/issues/114714 */ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index eafb1fdbcbdcb..ea1669ccc7a4f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -28,6 +28,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile; import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; +import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDev; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; import org.elasticsearch.xpack.esql.expression.function.aggregate.Top; import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; @@ -276,6 +277,7 @@ private static FunctionDefinition[][] functions() { def(MedianAbsoluteDeviation.class, uni(MedianAbsoluteDeviation::new), "median_absolute_deviation"), def(Min.class, uni(Min::new), "min"), def(Percentile.class, bi(Percentile::new), "percentile"), + def(StdDev.class, uni(StdDev::new), "std_dev"), def(Sum.class, uni(Sum::new), "sum"), def(Top.class, tri(Top::new), "top"), def(Values.class, uni(Values::new), "values"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java index b9cfd8892dd69..d74b5c8b386b8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java @@ -25,6 +25,7 @@ public static List getNamedWriteables() { Percentile.ENTRY, Rate.ENTRY, SpatialCentroid.ENTRY, + StdDev.ENTRY, Sum.ENTRY, Top.ENTRY, Values.ENTRY, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java new file mode 100644 index 0000000000000..189b6a81912cb --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.StdDevDoubleAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.StdDevIntAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.StdDevLongAggregatorFunctionSupplier; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.planner.ToAggregator; + +import java.io.IOException; +import java.util.List; + +import static java.util.Collections.emptyList; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; + +public class StdDev extends AggregateFunction implements ToAggregator { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "StdDev", StdDev::new); + + @FunctionInfo( + returnType = "double", + description = "The standard deviation of a numeric field.", + isAggregation = true, + examples = { + @Example(file = "stats", tag = "stdev"), + @Example( + description = "The expression can use inline functions. For example, to calculate the standard " + + "deviation of each employee's maximum salary changes, first use `MV_MAX` on each row, " + + "and then use `STD_DEV` on the result", + file = "stats", + tag = "docsStatsStdDevNestedExpression" + ) } + ) + public StdDev(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) { + this(source, field, Literal.TRUE); + } + + public StdDev(Source source, Expression field, Expression filter) { + super(source, field, filter, emptyList()); + } + + private StdDev(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + public DataType dataType() { + return DataType.DOUBLE; + } + + @Override + protected Expression.TypeResolution resolveType() { + return isType( + field(), + dt -> dt.isNumeric() && dt != DataType.UNSIGNED_LONG, + sourceText(), + DEFAULT, + "numeric except unsigned_long or counter types" + ); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, StdDev::new, field(), filter()); + } + + @Override + public StdDev replaceChildren(List newChildren) { + return new StdDev(source(), newChildren.get(0), newChildren.get(1)); + } + + public StdDev withFilter(Expression filter) { + return new StdDev(source(), field(), filter); + } + + @Override + public final AggregatorFunctionSupplier supplier(List inputChannels) { + DataType type = field().dataType(); + if (type == DataType.LONG) { + return new StdDevLongAggregatorFunctionSupplier(inputChannels); + } + if (type == DataType.INTEGER) { + return new StdDevIntAggregatorFunctionSupplier(inputChannels); + } + if (type == DataType.DOUBLE) { + return new StdDevDoubleAggregatorFunctionSupplier(inputChannels); + } + throw EsqlIllegalArgumentException.illegalDataType(type); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java index 9bb0ab4144bed..41a6a17a50dcb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; +import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDev; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; import org.elasticsearch.xpack.esql.expression.function.aggregate.ToPartial; import org.elasticsearch.xpack.esql.expression.function.aggregate.Top; @@ -48,9 +49,6 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_POINT; -import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_POINT; - /** * Static class used to convert aggregate expressions to the named expressions that represent their intermediate state. *

@@ -78,6 +76,7 @@ final class AggregateMapper { Min.class, Percentile.class, SpatialCentroid.class, + StdDev.class, Sum.class, Values.class, Top.class, @@ -171,7 +170,7 @@ private static Stream, Tuple>> typeAndNames(Class types = List.of("Int", "Long", "Double", "Boolean", "BytesRef"); } else if (Top.class.isAssignableFrom(clazz)) { types = List.of("Boolean", "Int", "Long", "Double", "Ip", "BytesRef"); - } else if (Rate.class.isAssignableFrom(clazz)) { + } else if (Rate.class.isAssignableFrom(clazz) || StdDev.class.isAssignableFrom(clazz)) { types = List.of("Int", "Long", "Double"); } else if (FromPartial.class.isAssignableFrom(clazz) || ToPartial.class.isAssignableFrom(clazz)) { types = List.of(""); // no type diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDevTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDevTests.java new file mode 100644 index 0000000000000..85b96e29d1f6a --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDevTests.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.compute.aggregation.WelfordAlgorithm; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class StdDevTests extends AbstractAggregationTestCase { + public StdDevTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + var suppliers = new ArrayList(); + + Stream.of( + MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true), + MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true), + MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true) + ).flatMap(List::stream).map(StdDevTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers)); + + return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers); + } + + @Override + protected Expression build(Source source, List args) { + return new StdDev(source, args.get(0)); + } + + private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { + return new TestCaseSupplier(List.of(fieldSupplier.type()), () -> { + var fieldTypedData = fieldSupplier.get(); + var fieldValues = fieldTypedData.multiRowData(); + + WelfordAlgorithm welfordAlgorithm = new WelfordAlgorithm(); + + for (var fieldValue : fieldValues) { + var value = ((Number) fieldValue).doubleValue(); + welfordAlgorithm.add(value); + } + var result = welfordAlgorithm.evaluate(); + var expected = Double.isInfinite(result) ? null : result; + return new TestCaseSupplier.TestCase( + List.of(fieldTypedData), + "StdDev[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(expected) + ); + }); + } +} diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml index 4c3b16c5dc309..72c7c51655378 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml @@ -92,7 +92,7 @@ setup: - gt: {esql.functions.to_long: $functions_to_long} - match: {esql.functions.coalesce: $functions_coalesce} # Testing for the entire function set isn't feasbile, so we just check that we return the correct count as an approximation. - - length: {esql.functions: 120} # check the "sister" test below for a likely update to the same esql.functions length check + - length: {esql.functions: 121} # check the "sister" test below for a likely update to the same esql.functions length check --- "Basic ESQL usage output (telemetry) non-snapshot version": @@ -163,4 +163,4 @@ setup: - match: {esql.functions.cos: $functions_cos} - gt: {esql.functions.to_long: $functions_to_long} - match: {esql.functions.coalesce: $functions_coalesce} - - length: {esql.functions: 117} # check the "sister" test above for a likely update to the same esql.functions length check + - length: {esql.functions: 118} # check the "sister" test above for a likely update to the same esql.functions length check