diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index f4c89e58fa431..c60088d58cf58 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -382,6 +382,52 @@ For example: +## Aggregations + +The [built-in DataFrames functions](api/scala/index.html#org.apache.spark.sql.functions$) provide common +aggregations such as `count()`, `countDistinct()`, `avg()`, `max()`, `min()`, etc. +While those functions are designed for DataFrames, Spark SQL also has type-safe versions for some of them in +[Scala](api/scala/index.html#org.apache.spark.sql.expressions.scalalang.typed$) and +[Java](api/java/org/apache/spark/sql/expressions/javalang/typed.html) to work with strongly typed Datasets. +Moreover, users are not limited to the predefined aggregate functions and can create their own. + +### Untyped User-Defined Aggregate Functions + +
+ +
+ +Users have to extend the [UserDefinedAggregateFunction](api/scala/index.html#org.apache.spark.sql.expressions.UserDefinedAggregateFunction) +abstract class to implement a custom untyped aggregate function. For example, a user-defined average +can look like: + +{% include_example untyped_custom_aggregation scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala%} +
+ +
+ +{% include_example untyped_custom_aggregation java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java%} +
+ +
+ +### Type-Safe User-Defined Aggregate Functions + +User-defined aggregations for strongly typed Datasets revolve around the [Aggregator](api/scala/index.html#org.apache.spark.sql.expressions.Aggregator) abstract class. +For example, a type-safe user-defined average can look like: +
+ +
+ +{% include_example typed_custom_aggregation scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala%} +
+ +
+ +{% include_example typed_custom_aggregation java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java%} +
+ +
# Data Sources diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java new file mode 100644 index 0000000000000..78e9011be4705 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.sql; + +// $example on:typed_custom_aggregation$ +import java.io.Serializable; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.TypedColumn; +import org.apache.spark.sql.expressions.Aggregator; +// $example off:typed_custom_aggregation$ + +public class JavaUserDefinedTypedAggregation { + + // $example on:typed_custom_aggregation$ + public static class Employee implements Serializable { + private String name; + private long salary; + + // Constructors, getters, setters... + // $example off:typed_custom_aggregation$ + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public long getSalary() { + return salary; + } + + public void setSalary(long salary) { + this.salary = salary; + } + // $example on:typed_custom_aggregation$ + } + + public static class Average implements Serializable { + private long sum; + private long count; + + // Constructors, getters, setters... + // $example off:typed_custom_aggregation$ + public Average() { + } + + public Average(long sum, long count) { + this.sum = sum; + this.count = count; + } + + public long getSum() { + return sum; + } + + public void setSum(long sum) { + this.sum = sum; + } + + public long getCount() { + return count; + } + + public void setCount(long count) { + this.count = count; + } + // $example on:typed_custom_aggregation$ + } + + public static class MyAverage extends Aggregator { + // A zero value for this aggregation. Should satisfy the property that any b + zero = b + public Average zero() { + return new Average(0L, 0L); + } + // Combine two values to produce a new value. For performance, the function may modify `buffer` + // and return it instead of constructing a new object + public Average reduce(Average buffer, Employee employee) { + long newSum = buffer.getSum() + employee.getSalary(); + long newCount = buffer.getCount() + 1; + buffer.setSum(newSum); + buffer.setCount(newCount); + return buffer; + } + // Merge two intermediate values + public Average merge(Average b1, Average b2) { + long mergedSum = b1.getSum() + b2.getSum(); + long mergedCount = b1.getCount() + b2.getCount(); + b1.setSum(mergedSum); + b1.setCount(mergedCount); + return b1; + } + // Transform the output of the reduction + public Double finish(Average reduction) { + return ((double) reduction.getSum()) / reduction.getCount(); + } + // Specifies the Encoder for the intermediate value type + public Encoder bufferEncoder() { + return Encoders.bean(Average.class); + } + // Specifies the Encoder for the final output value type + public Encoder outputEncoder() { + return Encoders.DOUBLE(); + } + } + // $example off:typed_custom_aggregation$ + + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("Java Spark SQL user-defined Datasets aggregation example") + .getOrCreate(); + + // $example on:typed_custom_aggregation$ + Encoder employeeEncoder = Encoders.bean(Employee.class); + String path = "examples/src/main/resources/employees.json"; + Dataset ds = spark.read().json(path).as(employeeEncoder); + ds.show(); + // +-------+------+ + // | name|salary| + // +-------+------+ + // |Michael| 3000| + // | Andy| 4500| + // | Justin| 3500| + // | Berta| 4000| + // +-------+------+ + + MyAverage myAverage = new MyAverage(); + // Convert the function to a `TypedColumn` and give it a name + TypedColumn averageSalary = myAverage.toColumn().name("average_salary"); + Dataset result = ds.select(averageSalary); + result.show(); + // +--------------+ + // |average_salary| + // +--------------+ + // | 3750.0| + // +--------------+ + // $example off:typed_custom_aggregation$ + spark.stop(); + } + +} diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java new file mode 100644 index 0000000000000..6da60a1fc6b88 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.sql; + +// $example on:untyped_custom_aggregation$ +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.expressions.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off:untyped_custom_aggregation$ + +public class JavaUserDefinedUntypedAggregation { + + // $example on:untyped_custom_aggregation$ + public static class MyAverage extends UserDefinedAggregateFunction { + + private StructType inputSchema; + private StructType bufferSchema; + + public MyAverage() { + List inputFields = new ArrayList<>(); + inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true)); + inputSchema = DataTypes.createStructType(inputFields); + + List bufferFields = new ArrayList<>(); + bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true)); + bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true)); + bufferSchema = DataTypes.createStructType(bufferFields); + } + // Data types of input arguments of this aggregate function + public StructType inputSchema() { + return inputSchema; + } + // Data types of values in the aggregation buffer + public StructType bufferSchema() { + return bufferSchema; + } + // The data type of the returned value + public DataType dataType() { + return DataTypes.DoubleType; + } + // Whether this function always returns the same output on the identical input + public boolean deterministic() { + return true; + } + // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to + // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides + // the opportunity to update its values. Note that arrays and maps inside the buffer are still + // immutable. + public void initialize(MutableAggregationBuffer buffer) { + buffer.update(0, 0L); + buffer.update(1, 0L); + } + // Updates the given aggregation buffer `buffer` with new input data from `input` + public void update(MutableAggregationBuffer buffer, Row input) { + if (!input.isNullAt(0)) { + long updatedSum = buffer.getLong(0) + input.getLong(0); + long updatedCount = buffer.getLong(1) + 1; + buffer.update(0, updatedSum); + buffer.update(1, updatedCount); + } + } + // Merges two aggregation buffers and stores the updated buffer values back to `buffer1` + public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + long mergedSum = buffer1.getLong(0) + buffer2.getLong(0); + long mergedCount = buffer1.getLong(1) + buffer2.getLong(1); + buffer1.update(0, mergedSum); + buffer1.update(1, mergedCount); + } + // Calculates the final result + public Double evaluate(Row buffer) { + return ((double) buffer.getLong(0)) / buffer.getLong(1); + } + } + // $example off:untyped_custom_aggregation$ + + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("Java Spark SQL user-defined DataFrames aggregation example") + .getOrCreate(); + + // $example on:untyped_custom_aggregation$ + // Register the function to access it + spark.udf().register("myAverage", new MyAverage()); + + Dataset df = spark.read().json("examples/src/main/resources/employees.json"); + df.createOrReplaceTempView("employees"); + df.show(); + // +-------+------+ + // | name|salary| + // +-------+------+ + // |Michael| 3000| + // | Andy| 4500| + // | Justin| 3500| + // | Berta| 4000| + // +-------+------+ + + Dataset result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees"); + result.show(); + // +--------------+ + // |average_salary| + // +--------------+ + // | 3750.0| + // +--------------+ + // $example off:untyped_custom_aggregation$ + + spark.stop(); + } +} diff --git a/examples/src/main/resources/employees.json b/examples/src/main/resources/employees.json new file mode 100644 index 0000000000000..6b2e6329a1cb2 --- /dev/null +++ b/examples/src/main/resources/employees.json @@ -0,0 +1,4 @@ +{"name":"Michael", "salary":3000} +{"name":"Andy", "salary":4500} +{"name":"Justin", "salary":3500} +{"name":"Berta", "salary":4000} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala new file mode 100644 index 0000000000000..ac617d19d36cf --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.sql + +// $example on:typed_custom_aggregation$ +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.SparkSession +// $example off:typed_custom_aggregation$ + +object UserDefinedTypedAggregation { + + // $example on:typed_custom_aggregation$ + case class Employee(name: String, salary: Long) + case class Average(var sum: Long, var count: Long) + + object MyAverage extends Aggregator[Employee, Average, Double] { + // A zero value for this aggregation. Should satisfy the property that any b + zero = b + def zero: Average = Average(0L, 0L) + // Combine two values to produce a new value. For performance, the function may modify `buffer` + // and return it instead of constructing a new object + def reduce(buffer: Average, employee: Employee): Average = { + buffer.sum += employee.salary + buffer.count += 1 + buffer + } + // Merge two intermediate values + def merge(b1: Average, b2: Average): Average = { + b1.sum += b2.sum + b1.count += b2.count + b1 + } + // Transform the output of the reduction + def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count + // Specifies the Encoder for the intermediate value type + def bufferEncoder: Encoder[Average] = Encoders.product + // Specifies the Encoder for the final output value type + def outputEncoder: Encoder[Double] = Encoders.scalaDouble + } + // $example off:typed_custom_aggregation$ + + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder() + .appName("Spark SQL user-defined Datasets aggregation example") + .getOrCreate() + + import spark.implicits._ + + // $example on:typed_custom_aggregation$ + val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee] + ds.show() + // +-------+------+ + // | name|salary| + // +-------+------+ + // |Michael| 3000| + // | Andy| 4500| + // | Justin| 3500| + // | Berta| 4000| + // +-------+------+ + + // Convert the function to a `TypedColumn` and give it a name + val averageSalary = MyAverage.toColumn.name("average_salary") + val result = ds.select(averageSalary) + result.show() + // +--------------+ + // |average_salary| + // +--------------+ + // | 3750.0| + // +--------------+ + // $example off:typed_custom_aggregation$ + + spark.stop() + } + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala new file mode 100644 index 0000000000000..9c9ebc55163de --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.sql + +// $example on:untyped_custom_aggregation$ +import org.apache.spark.sql.expressions.MutableAggregationBuffer +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction +import org.apache.spark.sql.types._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.SparkSession +// $example off:untyped_custom_aggregation$ + +object UserDefinedUntypedAggregation { + + // $example on:untyped_custom_aggregation$ + object MyAverage extends UserDefinedAggregateFunction { + // Data types of input arguments of this aggregate function + def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil) + // Data types of values in the aggregation buffer + def bufferSchema: StructType = { + StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil) + } + // The data type of the returned value + def dataType: DataType = DoubleType + // Whether this function always returns the same output on the identical input + def deterministic: Boolean = true + // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to + // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides + // the opportunity to update its values. Note that arrays and maps inside the buffer are still + // immutable. + def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer(0) = 0L + buffer(1) = 0L + } + // Updates the given aggregation buffer `buffer` with new input data from `input` + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!input.isNullAt(0)) { + buffer(0) = buffer.getLong(0) + input.getLong(0) + buffer(1) = buffer.getLong(1) + 1 + } + } + // Merges two aggregation buffers and stores the updated buffer values back to `buffer1` + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) + buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) + } + // Calculates the final result + def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1) + } + // $example off:untyped_custom_aggregation$ + + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder() + .appName("Spark SQL user-defined DataFrames aggregation example") + .getOrCreate() + + // $example on:untyped_custom_aggregation$ + // Register the function to access it + spark.udf.register("myAverage", MyAverage) + + val df = spark.read.json("examples/src/main/resources/employees.json") + df.createOrReplaceTempView("employees") + df.show() + // +-------+------+ + // | name|salary| + // +-------+------+ + // |Michael| 3000| + // | Andy| 4500| + // | Justin| 3500| + // | Berta| 4000| + // +-------+------+ + + val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees") + result.show() + // +--------------+ + // |average_salary| + // +--------------+ + // | 3750.0| + // +--------------+ + // $example off:untyped_custom_aggregation$ + + spark.stop() + } + +}