diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index f4c89e58fa431..c60088d58cf58 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -382,6 +382,52 @@ For example:
+## Aggregations
+
+The [built-in DataFrames functions](api/scala/index.html#org.apache.spark.sql.functions$) provide common
+aggregations such as `count()`, `countDistinct()`, `avg()`, `max()`, `min()`, etc.
+While those functions are designed for DataFrames, Spark SQL also has type-safe versions for some of them in
+[Scala](api/scala/index.html#org.apache.spark.sql.expressions.scalalang.typed$) and
+[Java](api/java/org/apache/spark/sql/expressions/javalang/typed.html) to work with strongly typed Datasets.
+Moreover, users are not limited to the predefined aggregate functions and can create their own.
+
+### Untyped User-Defined Aggregate Functions
+
+
+
+
+
+Users have to extend the [UserDefinedAggregateFunction](api/scala/index.html#org.apache.spark.sql.expressions.UserDefinedAggregateFunction)
+abstract class to implement a custom untyped aggregate function. For example, a user-defined average
+can look like:
+
+{% include_example untyped_custom_aggregation scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala%}
+
+
+
+
+{% include_example untyped_custom_aggregation java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java%}
+
+
+
+
+### Type-Safe User-Defined Aggregate Functions
+
+User-defined aggregations for strongly typed Datasets revolve around the [Aggregator](api/scala/index.html#org.apache.spark.sql.expressions.Aggregator) abstract class.
+For example, a type-safe user-defined average can look like:
+
+
+
+
+{% include_example typed_custom_aggregation scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala%}
+
+
+
+
+{% include_example typed_custom_aggregation java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java%}
+
+
+
# Data Sources
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java
new file mode 100644
index 0000000000000..78e9011be4705
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.examples.sql;
+
+// $example on:typed_custom_aggregation$
+import java.io.Serializable;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.sql.expressions.Aggregator;
+// $example off:typed_custom_aggregation$
+
+public class JavaUserDefinedTypedAggregation {
+
+ // $example on:typed_custom_aggregation$
+ public static class Employee implements Serializable {
+ private String name;
+ private long salary;
+
+ // Constructors, getters, setters...
+ // $example off:typed_custom_aggregation$
+ public String getName() {
+ return name;
+ }
+
+ public void setName(String name) {
+ this.name = name;
+ }
+
+ public long getSalary() {
+ return salary;
+ }
+
+ public void setSalary(long salary) {
+ this.salary = salary;
+ }
+ // $example on:typed_custom_aggregation$
+ }
+
+ public static class Average implements Serializable {
+ private long sum;
+ private long count;
+
+ // Constructors, getters, setters...
+ // $example off:typed_custom_aggregation$
+ public Average() {
+ }
+
+ public Average(long sum, long count) {
+ this.sum = sum;
+ this.count = count;
+ }
+
+ public long getSum() {
+ return sum;
+ }
+
+ public void setSum(long sum) {
+ this.sum = sum;
+ }
+
+ public long getCount() {
+ return count;
+ }
+
+ public void setCount(long count) {
+ this.count = count;
+ }
+ // $example on:typed_custom_aggregation$
+ }
+
+ public static class MyAverage extends Aggregator {
+ // A zero value for this aggregation. Should satisfy the property that any b + zero = b
+ public Average zero() {
+ return new Average(0L, 0L);
+ }
+ // Combine two values to produce a new value. For performance, the function may modify `buffer`
+ // and return it instead of constructing a new object
+ public Average reduce(Average buffer, Employee employee) {
+ long newSum = buffer.getSum() + employee.getSalary();
+ long newCount = buffer.getCount() + 1;
+ buffer.setSum(newSum);
+ buffer.setCount(newCount);
+ return buffer;
+ }
+ // Merge two intermediate values
+ public Average merge(Average b1, Average b2) {
+ long mergedSum = b1.getSum() + b2.getSum();
+ long mergedCount = b1.getCount() + b2.getCount();
+ b1.setSum(mergedSum);
+ b1.setCount(mergedCount);
+ return b1;
+ }
+ // Transform the output of the reduction
+ public Double finish(Average reduction) {
+ return ((double) reduction.getSum()) / reduction.getCount();
+ }
+ // Specifies the Encoder for the intermediate value type
+ public Encoder bufferEncoder() {
+ return Encoders.bean(Average.class);
+ }
+ // Specifies the Encoder for the final output value type
+ public Encoder outputEncoder() {
+ return Encoders.DOUBLE();
+ }
+ }
+ // $example off:typed_custom_aggregation$
+
+ public static void main(String[] args) {
+ SparkSession spark = SparkSession
+ .builder()
+ .appName("Java Spark SQL user-defined Datasets aggregation example")
+ .getOrCreate();
+
+ // $example on:typed_custom_aggregation$
+ Encoder employeeEncoder = Encoders.bean(Employee.class);
+ String path = "examples/src/main/resources/employees.json";
+ Dataset ds = spark.read().json(path).as(employeeEncoder);
+ ds.show();
+ // +-------+------+
+ // | name|salary|
+ // +-------+------+
+ // |Michael| 3000|
+ // | Andy| 4500|
+ // | Justin| 3500|
+ // | Berta| 4000|
+ // +-------+------+
+
+ MyAverage myAverage = new MyAverage();
+ // Convert the function to a `TypedColumn` and give it a name
+ TypedColumn averageSalary = myAverage.toColumn().name("average_salary");
+ Dataset result = ds.select(averageSalary);
+ result.show();
+ // +--------------+
+ // |average_salary|
+ // +--------------+
+ // | 3750.0|
+ // +--------------+
+ // $example off:typed_custom_aggregation$
+ spark.stop();
+ }
+
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java
new file mode 100644
index 0000000000000..6da60a1fc6b88
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.examples.sql;
+
+// $example on:untyped_custom_aggregation$
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.expressions.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+// $example off:untyped_custom_aggregation$
+
+public class JavaUserDefinedUntypedAggregation {
+
+ // $example on:untyped_custom_aggregation$
+ public static class MyAverage extends UserDefinedAggregateFunction {
+
+ private StructType inputSchema;
+ private StructType bufferSchema;
+
+ public MyAverage() {
+ List inputFields = new ArrayList<>();
+ inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
+ inputSchema = DataTypes.createStructType(inputFields);
+
+ List bufferFields = new ArrayList<>();
+ bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
+ bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
+ bufferSchema = DataTypes.createStructType(bufferFields);
+ }
+ // Data types of input arguments of this aggregate function
+ public StructType inputSchema() {
+ return inputSchema;
+ }
+ // Data types of values in the aggregation buffer
+ public StructType bufferSchema() {
+ return bufferSchema;
+ }
+ // The data type of the returned value
+ public DataType dataType() {
+ return DataTypes.DoubleType;
+ }
+ // Whether this function always returns the same output on the identical input
+ public boolean deterministic() {
+ return true;
+ }
+ // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
+ // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
+ // the opportunity to update its values. Note that arrays and maps inside the buffer are still
+ // immutable.
+ public void initialize(MutableAggregationBuffer buffer) {
+ buffer.update(0, 0L);
+ buffer.update(1, 0L);
+ }
+ // Updates the given aggregation buffer `buffer` with new input data from `input`
+ public void update(MutableAggregationBuffer buffer, Row input) {
+ if (!input.isNullAt(0)) {
+ long updatedSum = buffer.getLong(0) + input.getLong(0);
+ long updatedCount = buffer.getLong(1) + 1;
+ buffer.update(0, updatedSum);
+ buffer.update(1, updatedCount);
+ }
+ }
+ // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
+ public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+ long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
+ long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
+ buffer1.update(0, mergedSum);
+ buffer1.update(1, mergedCount);
+ }
+ // Calculates the final result
+ public Double evaluate(Row buffer) {
+ return ((double) buffer.getLong(0)) / buffer.getLong(1);
+ }
+ }
+ // $example off:untyped_custom_aggregation$
+
+ public static void main(String[] args) {
+ SparkSession spark = SparkSession
+ .builder()
+ .appName("Java Spark SQL user-defined DataFrames aggregation example")
+ .getOrCreate();
+
+ // $example on:untyped_custom_aggregation$
+ // Register the function to access it
+ spark.udf().register("myAverage", new MyAverage());
+
+ Dataset df = spark.read().json("examples/src/main/resources/employees.json");
+ df.createOrReplaceTempView("employees");
+ df.show();
+ // +-------+------+
+ // | name|salary|
+ // +-------+------+
+ // |Michael| 3000|
+ // | Andy| 4500|
+ // | Justin| 3500|
+ // | Berta| 4000|
+ // +-------+------+
+
+ Dataset result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
+ result.show();
+ // +--------------+
+ // |average_salary|
+ // +--------------+
+ // | 3750.0|
+ // +--------------+
+ // $example off:untyped_custom_aggregation$
+
+ spark.stop();
+ }
+}
diff --git a/examples/src/main/resources/employees.json b/examples/src/main/resources/employees.json
new file mode 100644
index 0000000000000..6b2e6329a1cb2
--- /dev/null
+++ b/examples/src/main/resources/employees.json
@@ -0,0 +1,4 @@
+{"name":"Michael", "salary":3000}
+{"name":"Andy", "salary":4500}
+{"name":"Justin", "salary":3500}
+{"name":"Berta", "salary":4000}
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala
new file mode 100644
index 0000000000000..ac617d19d36cf
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.examples.sql
+
+// $example on:typed_custom_aggregation$
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.SparkSession
+// $example off:typed_custom_aggregation$
+
+object UserDefinedTypedAggregation {
+
+ // $example on:typed_custom_aggregation$
+ case class Employee(name: String, salary: Long)
+ case class Average(var sum: Long, var count: Long)
+
+ object MyAverage extends Aggregator[Employee, Average, Double] {
+ // A zero value for this aggregation. Should satisfy the property that any b + zero = b
+ def zero: Average = Average(0L, 0L)
+ // Combine two values to produce a new value. For performance, the function may modify `buffer`
+ // and return it instead of constructing a new object
+ def reduce(buffer: Average, employee: Employee): Average = {
+ buffer.sum += employee.salary
+ buffer.count += 1
+ buffer
+ }
+ // Merge two intermediate values
+ def merge(b1: Average, b2: Average): Average = {
+ b1.sum += b2.sum
+ b1.count += b2.count
+ b1
+ }
+ // Transform the output of the reduction
+ def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
+ // Specifies the Encoder for the intermediate value type
+ def bufferEncoder: Encoder[Average] = Encoders.product
+ // Specifies the Encoder for the final output value type
+ def outputEncoder: Encoder[Double] = Encoders.scalaDouble
+ }
+ // $example off:typed_custom_aggregation$
+
+ def main(args: Array[String]): Unit = {
+ val spark = SparkSession
+ .builder()
+ .appName("Spark SQL user-defined Datasets aggregation example")
+ .getOrCreate()
+
+ import spark.implicits._
+
+ // $example on:typed_custom_aggregation$
+ val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
+ ds.show()
+ // +-------+------+
+ // | name|salary|
+ // +-------+------+
+ // |Michael| 3000|
+ // | Andy| 4500|
+ // | Justin| 3500|
+ // | Berta| 4000|
+ // +-------+------+
+
+ // Convert the function to a `TypedColumn` and give it a name
+ val averageSalary = MyAverage.toColumn.name("average_salary")
+ val result = ds.select(averageSalary)
+ result.show()
+ // +--------------+
+ // |average_salary|
+ // +--------------+
+ // | 3750.0|
+ // +--------------+
+ // $example off:typed_custom_aggregation$
+
+ spark.stop()
+ }
+
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala
new file mode 100644
index 0000000000000..9c9ebc55163de
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.examples.sql
+
+// $example on:untyped_custom_aggregation$
+import org.apache.spark.sql.expressions.MutableAggregationBuffer
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.SparkSession
+// $example off:untyped_custom_aggregation$
+
+object UserDefinedUntypedAggregation {
+
+ // $example on:untyped_custom_aggregation$
+ object MyAverage extends UserDefinedAggregateFunction {
+ // Data types of input arguments of this aggregate function
+ def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)
+ // Data types of values in the aggregation buffer
+ def bufferSchema: StructType = {
+ StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
+ }
+ // The data type of the returned value
+ def dataType: DataType = DoubleType
+ // Whether this function always returns the same output on the identical input
+ def deterministic: Boolean = true
+ // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
+ // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
+ // the opportunity to update its values. Note that arrays and maps inside the buffer are still
+ // immutable.
+ def initialize(buffer: MutableAggregationBuffer): Unit = {
+ buffer(0) = 0L
+ buffer(1) = 0L
+ }
+ // Updates the given aggregation buffer `buffer` with new input data from `input`
+ def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+ if (!input.isNullAt(0)) {
+ buffer(0) = buffer.getLong(0) + input.getLong(0)
+ buffer(1) = buffer.getLong(1) + 1
+ }
+ }
+ // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
+ def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+ buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
+ buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
+ }
+ // Calculates the final result
+ def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
+ }
+ // $example off:untyped_custom_aggregation$
+
+ def main(args: Array[String]): Unit = {
+ val spark = SparkSession
+ .builder()
+ .appName("Spark SQL user-defined DataFrames aggregation example")
+ .getOrCreate()
+
+ // $example on:untyped_custom_aggregation$
+ // Register the function to access it
+ spark.udf.register("myAverage", MyAverage)
+
+ val df = spark.read.json("examples/src/main/resources/employees.json")
+ df.createOrReplaceTempView("employees")
+ df.show()
+ // +-------+------+
+ // | name|salary|
+ // +-------+------+
+ // |Michael| 3000|
+ // | Andy| 4500|
+ // | Justin| 3500|
+ // | Berta| 4000|
+ // +-------+------+
+
+ val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
+ result.show()
+ // +--------------+
+ // |average_salary|
+ // +--------------+
+ // | 3750.0|
+ // +--------------+
+ // $example off:untyped_custom_aggregation$
+
+ spark.stop()
+ }
+
+}