Skip to content

Commit

Permalink
Merge pull request apache#180 from davies/group
Browse files Browse the repository at this point in the history
[SPARKR-191] groupBy and agg() API for DataFrame
  • Loading branch information
shivaram committed Mar 5, 2015
2 parents 4d0fb56 + 9dd6a5a commit bcb0bf5
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 13 deletions.
1 change: 1 addition & 0 deletions pkg/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Collate:
'RDD.R'
'pairRDD.R'
'column.R'
'group.R'
'DataFrame.R'
'SQLContext.R'
'broadcast.R'
Expand Down
4 changes: 4 additions & 0 deletions pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ exportMethods("columns",
"dtypes",
"explain",
"filter",
"groupBy",
"head",
"isLocal",
"limit",
Expand Down Expand Up @@ -134,6 +135,9 @@ exportMethods("asc",
"countDistinct",
"sumDistinct")

exportClasses("GroupedData")
exportMethods("agg")

export("cacheTable",
"clearCache",
"createExternalTable",
Expand Down
29 changes: 28 additions & 1 deletion pkg/R/DataFrame.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# DataFrame.R - DataFrame class and methods implemented in S4 OO classes

#' @include jobj.R SQLTypes.R RDD.R pairRDD.R column.R
#' @include jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R
NULL

setOldClass("jobj")
Expand Down Expand Up @@ -663,6 +663,33 @@ setMethod("toRDD",
})
})

#' GroupBy
#'
#' Groups the DataFrame using the specified columns, so we can run aggregation
#' on them.
#'
#' @param x a DataFrame
#' @param ... character column name(s), or Column object(s), to group by
#' @return a GroupedData
setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") })

setMethod("groupBy",
          signature(x = "DataFrame"),
          function(x, ...) {
            cols <- list(...)
            # Dispatch on the type of the first grouping argument.
            # is.character() is used instead of class(...) == "character":
            # class() may return a vector of length > 1, which would error
            # inside the `&&` condition.
            if (length(cols) >= 1 && is.character(cols[[1]])) {
              # Character form: first name plus a Seq of the remaining names.
              sgd <- callJMethod(x@sdf, "groupBy", cols[[1]],
                                 listToSeq(cols[-1]))
            } else {
              # Column form: unwrap each Column's Java reference.
              jcol <- lapply(cols, function(c) { c@jc })
              sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol))
            }
            groupedData(sgd)
          })


#' Aggregates on the entire DataFrame without groups.
#' Shorthand for agg(groupBy(x), ...); see the GroupedData "agg" method
#' in group.R for the supported argument forms.
setMethod("agg",
          signature(x = "DataFrame"),
          function(x, ...) {
            # Delegate to the GroupedData implementation with no grouping keys.
            agg(groupBy(x), ...)
          })


############################## RDD Map Functions ##################################
# All of the following functions mirror the existing RDD map functions, #
# but allow for use with DataFrames by first converting to an RRDD before calling #
Expand Down
2 changes: 1 addition & 1 deletion pkg/R/column.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ createMethods <- function() {
createOperator(op)
}

# Generics for column functions. NOTE: "avg" takes "..." (unlike its
# neighbors) so that a GroupedData method taking extra column names can
# share the same generic (see group.R).
setGeneric("avg", function(x, ...) { standardGeneric("avg") })
setGeneric("last", function(x) { standardGeneric("last") })
setGeneric("lower", function(x) { standardGeneric("lower") })
setGeneric("upper", function(x) { standardGeneric("upper") })
Expand Down
81 changes: 81 additions & 0 deletions pkg/R/group.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
############################## GroupedData ########################################

#' S4 class wrapping a JVM-side GroupedData reference (slot "sgd"), created
#' by groupBy() on a DataFrame. The "env" slot holds per-object mutable state.
setClass("GroupedData",
         slots = c(env = "environment",
                   sgd = "jobj"))

setMethod("initialize", "GroupedData", function(.Object, sgd) {
  # Each object gets its own fresh environment.
  .Object@env <- new.env()
  .Object@sgd <- sgd
  .Object
})

# Internal constructor: wrap a JVM GroupedData reference.
groupedData <- function(sgd) new("GroupedData", sgd)

#' Count the number of rows in each group.
#'
#' @param x a GroupedData
#' @return a DataFrame holding one count per group
setMethod("count",
          signature(x = "GroupedData"),
          function(x) {
            # Forward to the JVM-side GroupedData.count() and rewrap.
            counted <- callJMethod(x@sgd, "count")
            dataFrame(counted)
          })

#' Agg
#'
#' Aggregates on the entire DataFrame without groups.
#'
#' df2 <- agg(df, <column> = <aggFunction>)
#' df2 <- agg(df, newColName = aggFunction(column))
#' @examples
#' \dontrun{
#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)'
#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum
#' }
setGeneric("agg", function(x, ...) { standardGeneric("agg") })

setMethod("agg",
          signature(x = "GroupedData"),
          function(x, ...) {
            cols <- list(...)
            stopifnot(length(cols) > 0)
            if (is.character(cols[[1]])) {
              # "name" = "aggFunction" form: pass the named arguments to the
              # JVM as an environment (a map of column name -> function name).
              cols <- varargsToEnv(...)
              sdf <- callJMethod(x@sgd, "agg", cols)
            } else if (inherits(cols[[1]], "Column")) {
              # Column form. inherits() is robust where class(x) == "Column"
              # is not: class() may return a vector of length > 1.
              ns <- names(cols)
              if (!is.null(ns)) {
                # A named argument supplies the result column's name.
                for (n in ns) {
                  if (n != "") {
                    cols[[n]] <- alias(cols[[n]], n)
                  }
                }
              }
              jcols <- lapply(cols, function(c) { c@jc })
              # JVM signature: agg(firstColumn, remainingColumns: Seq)
              sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1]))
            } else {
              stop("agg can only support Column or character")
            }
            dataFrame(sdf)
          })

#' sum/mean/avg/min/max

# Aggregation shortcuts registered on GroupedData; each one forwards its
# character column names to the same-named JVM GroupedData method.
methods <- c("sum", "mean", "avg", "min", "max")

# Register a single shortcut method for the given aggregation name.
createMethod <- function(name) {
  setMethod(name,
            signature(x = "GroupedData"),
            function(x, ...) {
              dataFrame(callJMethod(x@sgd, name, toSeq(...)))
            })
}

# Register every shortcut in `methods`.
createMethods <- function() {
  lapply(methods, createMethod)
  invisible(NULL)
}

createMethods()

1 change: 0 additions & 1 deletion pkg/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -370,4 +370,3 @@ toSeq <- function(...) {
# Convert an R list into a JVM-side Scala Seq via SQLUtils.toSeq.
listToSeq <- function(l) {
  callJStatic("edu.berkeley.cs.amplab.sparkr.SQLUtils", "toSeq", l)
}

25 changes: 25 additions & 0 deletions pkg/inst/tests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,31 @@ test_that("column functions", {
c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string")
})

test_that("group by", {
  df <- jsonFile(sqlCtx, jsonPath)

  # Whole-DataFrame aggregation, "name" = "function" form.
  df1 <- agg(df, name = "max", age = "sum")
  expect_equal(count(df1), 1)

  # Column form with an explicit result column name.
  df1 <- agg(df, age2 = max(df$age))
  expect_equal(count(df1), 1)
  # expect_true() on a vector comparison errors if lengths differ;
  # expect_equal() is the correct assertion and gives better diagnostics.
  expect_equal(columns(df1), c("age2"))

  gd <- groupBy(df, "name")
  expect_true(inherits(gd, "GroupedData"))
  df2 <- count(gd)
  expect_true(inherits(df2, "DataFrame"))
  expect_equal(count(df2), 3)

  df3 <- agg(gd, age = "sum")
  expect_true(inherits(df3, "DataFrame"))
  expect_equal(count(df3), 3)

  # Generated aggregation shortcuts (sum/mean/max) on GroupedData.
  df4 <- sum(gd, "age")
  expect_true(inherits(df4, "DataFrame"))
  expect_equal(count(df4), 3)
  expect_equal(count(mean(gd, "age")), 3)
  expect_equal(count(max(gd, "age")), 3)
})

test_that("sortDF() and orderBy() on a DataFrame", {
df <- jsonFile(sqlCtx, jsonPath)
sorted <- sortDF(df, df$age)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package edu.berkeley.cs.amplab.sparkr

import scala.collection.mutable.HashMap
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.DataInputStream
import java.io.DataOutputStream
import scala.collection.mutable.HashMap

import io.netty.channel.ChannelHandler.Sharable
import io.netty.channel.ChannelHandlerContext
Expand All @@ -19,7 +16,8 @@ import edu.berkeley.cs.amplab.sparkr.SerDe._
* this across connections ?
*/
@Sharable
class SparkRBackendHandler(server: SparkRBackend) extends SimpleChannelInboundHandler[Array[Byte]] {
class SparkRBackendHandler(server: SparkRBackend)
extends SimpleChannelInboundHandler[Array[Byte]] {

override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]) {
val bis = new ByteArrayInputStream(msg)
Expand Down Expand Up @@ -100,11 +98,18 @@ class SparkRBackendHandler(server: SparkRBackend) extends SimpleChannelInboundHa
val methods = cls.get.getMethods
val selectedMethods = methods.filter(m => m.getName == methodName)
if (selectedMethods.length > 0) {
val selectedMethod = selectedMethods.filter { x =>
val methods = selectedMethods.filter { x =>
matchMethod(numArgs, args, x.getParameterTypes)
}.head

val ret = selectedMethod.invoke(obj, args:_*)
}
if (methods.isEmpty) {
System.err.println(s"cannot find matching method ${cls.get}.$methodName. "
+ s"Candidates are:")
selectedMethods.foreach { method =>
System.err.println(s"$methodName(${method.getParameterTypes.mkString(",")})")
}
throw new Exception(s"No matched method found for $cls.$methodName")
}
val ret = methods.head.invoke(obj, args:_*)

// Write status bit
writeInt(dos, 0)
Expand Down Expand Up @@ -160,6 +165,7 @@ class SparkRBackendHandler(server: SparkRBackend) extends SimpleChannelInboundHa
}
}
if (!parameterWrapperType.isInstance(args(i))) {
System.err.println(s"arg $i not match: expected type $parameterWrapperType, but got ${args(i).getClass()}")
return false
}
}
Expand Down

0 comments on commit bcb0bf5

Please sign in to comment.