Skip to content

Commit

Permalink
[SPARK-16046][DOCS] Aggregations in the Spark SQL programming guide
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

- A separate subsection for Aggregations under “Getting Started” in the Spark SQL programming guide. It mentions which aggregate functions are predefined and how users can create their own.
- Examples of using the `UserDefinedAggregateFunction` abstract class for untyped aggregations in Java and Scala.
- Examples of using the `Aggregator` abstract class for type-safe aggregations in Java and Scala.
- Python is not covered.
- The PR might not resolve the ticket since I do not know what exactly was planned by the author.

In total, there are four new standalone examples that can be executed via `spark-submit` or `run-example`. The updated Spark SQL programming guide references to these examples and does not contain hard-coded snippets.

## How was this patch tested?

The patch was tested locally by building the docs. The examples were run as well.

![image](https://cloud.githubusercontent.com/assets/6235869/21292915/04d9d084-c515-11e6-811a-999d598dffba.png)

Author: aokolnychyi <[email protected]>

Closes #16329 from aokolnychyi/SPARK-16046.
  • Loading branch information
aokolnychyi authored and gatorsmile committed Jan 25, 2017
1 parent 40a4cfc commit 3fdce81
Show file tree
Hide file tree
Showing 6 changed files with 533 additions and 0 deletions.
46 changes: 46 additions & 0 deletions docs/sql-programming-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,52 @@ For example:

</div>

## Aggregations

The [built-in DataFrames functions](api/scala/index.html#org.apache.spark.sql.functions$) provide common
aggregations such as `count()`, `countDistinct()`, `avg()`, `max()`, `min()`, etc.
While those functions are designed for DataFrames, Spark SQL also has type-safe versions for some of them in
[Scala](api/scala/index.html#org.apache.spark.sql.expressions.scalalang.typed$) and
[Java](api/java/org/apache/spark/sql/expressions/javalang/typed.html) to work with strongly typed Datasets.
Moreover, users are not limited to the predefined aggregate functions and can create their own.

### Untyped User-Defined Aggregate Functions

<div class="codetabs">

<div data-lang="scala" markdown="1">

Users have to extend the [UserDefinedAggregateFunction](api/scala/index.html#org.apache.spark.sql.expressions.UserDefinedAggregateFunction)
abstract class to implement a custom untyped aggregate function. For example, a user-defined average
can look like:

{% include_example untyped_custom_aggregation scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala%}
</div>

<div data-lang="java" markdown="1">

{% include_example untyped_custom_aggregation java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java%}
</div>

</div>

### Type-Safe User-Defined Aggregate Functions

User-defined aggregations for strongly typed Datasets revolve around the [Aggregator](api/scala/index.html#org.apache.spark.sql.expressions.Aggregator) abstract class.
For example, a type-safe user-defined average can look like:
<div class="codetabs">

<div data-lang="scala" markdown="1">

{% include_example typed_custom_aggregation scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala%}
</div>

<div data-lang="java" markdown="1">

{% include_example typed_custom_aggregation java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java%}
</div>

</div>

# Data Sources

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.sql;

// $example on:typed_custom_aggregation$
import java.io.Serializable;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.TypedColumn;
import org.apache.spark.sql.expressions.Aggregator;
// $example off:typed_custom_aggregation$

public class JavaUserDefinedTypedAggregation {

// $example on:typed_custom_aggregation$
public static class Employee implements Serializable {
private String name;
private long salary;

// Constructors, getters, setters...
// $example off:typed_custom_aggregation$
public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

public long getSalary() {
return salary;
}

public void setSalary(long salary) {
this.salary = salary;
}
// $example on:typed_custom_aggregation$
}

public static class Average implements Serializable {
private long sum;
private long count;

// Constructors, getters, setters...
// $example off:typed_custom_aggregation$
public Average() {
}

public Average(long sum, long count) {
this.sum = sum;
this.count = count;
}

public long getSum() {
return sum;
}

public void setSum(long sum) {
this.sum = sum;
}

public long getCount() {
return count;
}

public void setCount(long count) {
this.count = count;
}
// $example on:typed_custom_aggregation$
}

public static class MyAverage extends Aggregator<Employee, Average, Double> {
// A zero value for this aggregation. Should satisfy the property that any b + zero = b
public Average zero() {
return new Average(0L, 0L);
}
// Combine two values to produce a new value. For performance, the function may modify `buffer`
// and return it instead of constructing a new object
public Average reduce(Average buffer, Employee employee) {
long newSum = buffer.getSum() + employee.getSalary();
long newCount = buffer.getCount() + 1;
buffer.setSum(newSum);
buffer.setCount(newCount);
return buffer;
}
// Merge two intermediate values
public Average merge(Average b1, Average b2) {
long mergedSum = b1.getSum() + b2.getSum();
long mergedCount = b1.getCount() + b2.getCount();
b1.setSum(mergedSum);
b1.setCount(mergedCount);
return b1;
}
// Transform the output of the reduction
public Double finish(Average reduction) {
return ((double) reduction.getSum()) / reduction.getCount();
}
// Specifies the Encoder for the intermediate value type
public Encoder<Average> bufferEncoder() {
return Encoders.bean(Average.class);
}
// Specifies the Encoder for the final output value type
public Encoder<Double> outputEncoder() {
return Encoders.DOUBLE();
}
}
// $example off:typed_custom_aggregation$

public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("Java Spark SQL user-defined Datasets aggregation example")
.getOrCreate();

// $example on:typed_custom_aggregation$
Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class);
String path = "examples/src/main/resources/employees.json";
Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
ds.show();
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+

MyAverage myAverage = new MyAverage();
// Convert the function to a `TypedColumn` and give it a name
TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
Dataset<Double> result = ds.select(averageSalary);
result.show();
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+
// $example off:typed_custom_aggregation$
spark.stop();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.sql;

// $example on:untyped_custom_aggregation$
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off:untyped_custom_aggregation$

public class JavaUserDefinedUntypedAggregation {

// $example on:untyped_custom_aggregation$
public static class MyAverage extends UserDefinedAggregateFunction {

private StructType inputSchema;
private StructType bufferSchema;

public MyAverage() {
List<StructField> inputFields = new ArrayList<>();
inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
inputSchema = DataTypes.createStructType(inputFields);

List<StructField> bufferFields = new ArrayList<>();
bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
bufferSchema = DataTypes.createStructType(bufferFields);
}
// Data types of input arguments of this aggregate function
public StructType inputSchema() {
return inputSchema;
}
// Data types of values in the aggregation buffer
public StructType bufferSchema() {
return bufferSchema;
}
// The data type of the returned value
public DataType dataType() {
return DataTypes.DoubleType;
}
// Whether this function always returns the same output on the identical input
public boolean deterministic() {
return true;
}
// Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
// standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
// the opportunity to update its values. Note that arrays and maps inside the buffer are still
// immutable.
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0, 0L);
buffer.update(1, 0L);
}
// Updates the given aggregation buffer `buffer` with new input data from `input`
public void update(MutableAggregationBuffer buffer, Row input) {
if (!input.isNullAt(0)) {
long updatedSum = buffer.getLong(0) + input.getLong(0);
long updatedCount = buffer.getLong(1) + 1;
buffer.update(0, updatedSum);
buffer.update(1, updatedCount);
}
}
// Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
buffer1.update(0, mergedSum);
buffer1.update(1, mergedCount);
}
// Calculates the final result
public Double evaluate(Row buffer) {
return ((double) buffer.getLong(0)) / buffer.getLong(1);
}
}
// $example off:untyped_custom_aggregation$

public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("Java Spark SQL user-defined DataFrames aggregation example")
.getOrCreate();

// $example on:untyped_custom_aggregation$
// Register the function to access it
spark.udf().register("myAverage", new MyAverage());

Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
df.createOrReplaceTempView("employees");
df.show();
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+

Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
result.show();
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+
// $example off:untyped_custom_aggregation$

spark.stop();
}
}
4 changes: 4 additions & 0 deletions examples/src/main/resources/employees.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{"name":"Michael", "salary":3000}
{"name":"Andy", "salary":4500}
{"name":"Justin", "salary":3500}
{"name":"Berta", "salary":4000}
Loading

0 comments on commit 3fdce81

Please sign in to comment.