-
Notifications
You must be signed in to change notification settings - Fork 28.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-16046][DOCS] Aggregations in the Spark SQL programming guide
## What changes were proposed in this pull request? - A separate subsection for Aggregations under “Getting Started” in the Spark SQL programming guide. It mentions which aggregate functions are predefined and how users can create their own. - Examples of using the `UserDefinedAggregateFunction` abstract class for untyped aggregations in Java and Scala. - Examples of using the `Aggregator` abstract class for type-safe aggregations in Java and Scala. - Python is not covered. - The PR might not resolve the ticket since I do not know what exactly was planned by the author. In total, there are four new standalone examples that can be executed via `spark-submit` or `run-example`. The updated Spark SQL programming guide references to these examples and does not contain hard-coded snippets. ## How was this patch tested? The patch was tested locally by building the docs. The examples were run as well. ![image](https://cloud.githubusercontent.com/assets/6235869/21292915/04d9d084-c515-11e6-811a-999d598dffba.png) Author: aokolnychyi <[email protected]> Closes #16329 from aokolnychyi/SPARK-16046.
- Loading branch information
1 parent
40a4cfc
commit 3fdce81
Showing
6 changed files
with
533 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
160 changes: 160 additions & 0 deletions
160
examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.spark.examples.sql; | ||
|
||
// $example on:typed_custom_aggregation$ | ||
import java.io.Serializable; | ||
|
||
import org.apache.spark.sql.Dataset; | ||
import org.apache.spark.sql.Encoder; | ||
import org.apache.spark.sql.Encoders; | ||
import org.apache.spark.sql.SparkSession; | ||
import org.apache.spark.sql.TypedColumn; | ||
import org.apache.spark.sql.expressions.Aggregator; | ||
// $example off:typed_custom_aggregation$ | ||
|
||
public class JavaUserDefinedTypedAggregation { | ||
|
||
// $example on:typed_custom_aggregation$ | ||
public static class Employee implements Serializable { | ||
private String name; | ||
private long salary; | ||
|
||
// Constructors, getters, setters... | ||
// $example off:typed_custom_aggregation$ | ||
public String getName() { | ||
return name; | ||
} | ||
|
||
public void setName(String name) { | ||
this.name = name; | ||
} | ||
|
||
public long getSalary() { | ||
return salary; | ||
} | ||
|
||
public void setSalary(long salary) { | ||
this.salary = salary; | ||
} | ||
// $example on:typed_custom_aggregation$ | ||
} | ||
|
||
public static class Average implements Serializable { | ||
private long sum; | ||
private long count; | ||
|
||
// Constructors, getters, setters... | ||
// $example off:typed_custom_aggregation$ | ||
public Average() { | ||
} | ||
|
||
public Average(long sum, long count) { | ||
this.sum = sum; | ||
this.count = count; | ||
} | ||
|
||
public long getSum() { | ||
return sum; | ||
} | ||
|
||
public void setSum(long sum) { | ||
this.sum = sum; | ||
} | ||
|
||
public long getCount() { | ||
return count; | ||
} | ||
|
||
public void setCount(long count) { | ||
this.count = count; | ||
} | ||
// $example on:typed_custom_aggregation$ | ||
} | ||
|
||
public static class MyAverage extends Aggregator<Employee, Average, Double> { | ||
// A zero value for this aggregation. Should satisfy the property that any b + zero = b | ||
public Average zero() { | ||
return new Average(0L, 0L); | ||
} | ||
// Combine two values to produce a new value. For performance, the function may modify `buffer` | ||
// and return it instead of constructing a new object | ||
public Average reduce(Average buffer, Employee employee) { | ||
long newSum = buffer.getSum() + employee.getSalary(); | ||
long newCount = buffer.getCount() + 1; | ||
buffer.setSum(newSum); | ||
buffer.setCount(newCount); | ||
return buffer; | ||
} | ||
// Merge two intermediate values | ||
public Average merge(Average b1, Average b2) { | ||
long mergedSum = b1.getSum() + b2.getSum(); | ||
long mergedCount = b1.getCount() + b2.getCount(); | ||
b1.setSum(mergedSum); | ||
b1.setCount(mergedCount); | ||
return b1; | ||
} | ||
// Transform the output of the reduction | ||
public Double finish(Average reduction) { | ||
return ((double) reduction.getSum()) / reduction.getCount(); | ||
} | ||
// Specifies the Encoder for the intermediate value type | ||
public Encoder<Average> bufferEncoder() { | ||
return Encoders.bean(Average.class); | ||
} | ||
// Specifies the Encoder for the final output value type | ||
public Encoder<Double> outputEncoder() { | ||
return Encoders.DOUBLE(); | ||
} | ||
} | ||
// $example off:typed_custom_aggregation$ | ||
|
||
public static void main(String[] args) { | ||
SparkSession spark = SparkSession | ||
.builder() | ||
.appName("Java Spark SQL user-defined Datasets aggregation example") | ||
.getOrCreate(); | ||
|
||
// $example on:typed_custom_aggregation$ | ||
Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class); | ||
String path = "examples/src/main/resources/employees.json"; | ||
Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder); | ||
ds.show(); | ||
// +-------+------+ | ||
// | name|salary| | ||
// +-------+------+ | ||
// |Michael| 3000| | ||
// | Andy| 4500| | ||
// | Justin| 3500| | ||
// | Berta| 4000| | ||
// +-------+------+ | ||
|
||
MyAverage myAverage = new MyAverage(); | ||
// Convert the function to a `TypedColumn` and give it a name | ||
TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary"); | ||
Dataset<Double> result = ds.select(averageSalary); | ||
result.show(); | ||
// +--------------+ | ||
// |average_salary| | ||
// +--------------+ | ||
// | 3750.0| | ||
// +--------------+ | ||
// $example off:typed_custom_aggregation$ | ||
spark.stop(); | ||
} | ||
|
||
} |
132 changes: 132 additions & 0 deletions
132
examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.spark.examples.sql; | ||
|
||
// $example on:untyped_custom_aggregation$ | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
import org.apache.spark.sql.Dataset; | ||
import org.apache.spark.sql.Row; | ||
import org.apache.spark.sql.SparkSession; | ||
import org.apache.spark.sql.expressions.MutableAggregationBuffer; | ||
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; | ||
import org.apache.spark.sql.types.DataType; | ||
import org.apache.spark.sql.types.DataTypes; | ||
import org.apache.spark.sql.types.StructField; | ||
import org.apache.spark.sql.types.StructType; | ||
// $example off:untyped_custom_aggregation$ | ||
|
||
public class JavaUserDefinedUntypedAggregation { | ||
|
||
// $example on:untyped_custom_aggregation$ | ||
public static class MyAverage extends UserDefinedAggregateFunction { | ||
|
||
private StructType inputSchema; | ||
private StructType bufferSchema; | ||
|
||
public MyAverage() { | ||
List<StructField> inputFields = new ArrayList<>(); | ||
inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true)); | ||
inputSchema = DataTypes.createStructType(inputFields); | ||
|
||
List<StructField> bufferFields = new ArrayList<>(); | ||
bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true)); | ||
bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true)); | ||
bufferSchema = DataTypes.createStructType(bufferFields); | ||
} | ||
// Data types of input arguments of this aggregate function | ||
public StructType inputSchema() { | ||
return inputSchema; | ||
} | ||
// Data types of values in the aggregation buffer | ||
public StructType bufferSchema() { | ||
return bufferSchema; | ||
} | ||
// The data type of the returned value | ||
public DataType dataType() { | ||
return DataTypes.DoubleType; | ||
} | ||
// Whether this function always returns the same output on the identical input | ||
public boolean deterministic() { | ||
return true; | ||
} | ||
// Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to | ||
// standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides | ||
// the opportunity to update its values. Note that arrays and maps inside the buffer are still | ||
// immutable. | ||
public void initialize(MutableAggregationBuffer buffer) { | ||
buffer.update(0, 0L); | ||
buffer.update(1, 0L); | ||
} | ||
// Updates the given aggregation buffer `buffer` with new input data from `input` | ||
public void update(MutableAggregationBuffer buffer, Row input) { | ||
if (!input.isNullAt(0)) { | ||
long updatedSum = buffer.getLong(0) + input.getLong(0); | ||
long updatedCount = buffer.getLong(1) + 1; | ||
buffer.update(0, updatedSum); | ||
buffer.update(1, updatedCount); | ||
} | ||
} | ||
// Merges two aggregation buffers and stores the updated buffer values back to `buffer1` | ||
public void merge(MutableAggregationBuffer buffer1, Row buffer2) { | ||
long mergedSum = buffer1.getLong(0) + buffer2.getLong(0); | ||
long mergedCount = buffer1.getLong(1) + buffer2.getLong(1); | ||
buffer1.update(0, mergedSum); | ||
buffer1.update(1, mergedCount); | ||
} | ||
// Calculates the final result | ||
public Double evaluate(Row buffer) { | ||
return ((double) buffer.getLong(0)) / buffer.getLong(1); | ||
} | ||
} | ||
// $example off:untyped_custom_aggregation$ | ||
|
||
public static void main(String[] args) { | ||
SparkSession spark = SparkSession | ||
.builder() | ||
.appName("Java Spark SQL user-defined DataFrames aggregation example") | ||
.getOrCreate(); | ||
|
||
// $example on:untyped_custom_aggregation$ | ||
// Register the function to access it | ||
spark.udf().register("myAverage", new MyAverage()); | ||
|
||
Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json"); | ||
df.createOrReplaceTempView("employees"); | ||
df.show(); | ||
// +-------+------+ | ||
// | name|salary| | ||
// +-------+------+ | ||
// |Michael| 3000| | ||
// | Andy| 4500| | ||
// | Justin| 3500| | ||
// | Berta| 4000| | ||
// +-------+------+ | ||
|
||
Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees"); | ||
result.show(); | ||
// +--------------+ | ||
// |average_salary| | ||
// +--------------+ | ||
// | 3750.0| | ||
// +--------------+ | ||
// $example off:untyped_custom_aggregation$ | ||
|
||
spark.stop(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
{"name":"Michael", "salary":3000} | ||
{"name":"Andy", "salary":4500} | ||
{"name":"Justin", "salary":3500} | ||
{"name":"Berta", "salary":4000} |
Oops, something went wrong.