diff --git a/pkg/DESCRIPTION b/pkg/DESCRIPTION index b553fdec3014c..786f8245b42dc 100644 --- a/pkg/DESCRIPTION +++ b/pkg/DESCRIPTION @@ -20,6 +20,7 @@ Collate: 'RDD.R' 'pairRDD.R' 'column.R' + 'group.R' 'DataFrame.R' 'SQLContext.R' 'broadcast.R' diff --git a/pkg/NAMESPACE b/pkg/NAMESPACE index 2471c6a526a16..087d077fd6d51 100644 --- a/pkg/NAMESPACE +++ b/pkg/NAMESPACE @@ -90,6 +90,7 @@ exportMethods("columns", "dtypes", "explain", "filter", + "groupBy", "head", "isLocal", "limit", @@ -134,6 +135,9 @@ exportMethods("asc", "countDistinct", "sumDistinct") +exportClasses("GroupedData") +exportMethods("agg") + export("cacheTable", "clearCache", "createExternalTable", diff --git a/pkg/R/DataFrame.R b/pkg/R/DataFrame.R index b729d734cb2fd..f90bf69806101 100644 --- a/pkg/R/DataFrame.R +++ b/pkg/R/DataFrame.R @@ -1,6 +1,6 @@ # DataFrame.R - DataFrame class and methods implemented in S4 OO classes -#' @include jobj.R SQLTypes.R RDD.R pairRDD.R column.R +#' @include jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R NULL setOldClass("jobj") @@ -663,6 +663,33 @@ setMethod("toRDD", }) }) +#' GroupBy +#' +#' Groups the DataFrame using the specified columns, so we can run aggregation on them. +#' +setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) + +setMethod("groupBy", + signature(x = "DataFrame"), + function(x, ...) { + cols <- list(...) + if (length(cols) >= 1 && class(cols[[1]]) == "character") { + sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1])) + } else { + jcol <- lapply(cols, function(c) { c@jc }) + sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol)) + } + groupedData(sgd) + }) + + +setMethod("agg", + signature(x = "DataFrame"), + function(x, ...) { + agg(groupBy(x), ...) 
+ }) + + ############################## RDD Map Functions ################################## # All of the following functions mirror the existing RDD map functions, # # but allow for use with DataFrames by first converting to an RRDD before calling # diff --git a/pkg/R/column.R b/pkg/R/column.R index ba3343f799919..799d31072c785 100644 --- a/pkg/R/column.R +++ b/pkg/R/column.R @@ -73,7 +73,7 @@ createMethods <- function() { createOperator(op) } - setGeneric("avg", function(x) { standardGeneric("avg") }) + setGeneric("avg", function(x, ...) { standardGeneric("avg") }) setGeneric("last", function(x) { standardGeneric("last") }) setGeneric("lower", function(x) { standardGeneric("lower") }) setGeneric("upper", function(x) { standardGeneric("upper") }) diff --git a/pkg/R/group.R b/pkg/R/group.R new file mode 100644 index 0000000000000..bf17efffc4b20 --- /dev/null +++ b/pkg/R/group.R @@ -0,0 +1,81 @@ +############################## GroupedData ######################################## + +setClass("GroupedData", + slots = list(env = "environment", + sgd = "jobj")) + +setMethod("initialize", "GroupedData", function(.Object, sgd) { + .Object@env <- new.env() + .Object@sgd <- sgd + .Object +}) + +groupedData <- function(sgd) { + new("GroupedData", sgd) +} + +setMethod("count", + signature(x = "GroupedData"), + function(x) { + dataFrame(callJMethod(x@sgd, "count")) + }) + +#' Agg +#' +#' Aggregates on the entire DataFrame without groups. +#' +#' df2 <- agg(df, column = "aggFunction") +#' df2 <- agg(df, newColName = aggFunction(column)) +#' @examples +#' \dontrun{ +#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' +#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' } +setGeneric("agg", function (x, ...) { standardGeneric("agg") }) + +setMethod("agg", + signature(x = "GroupedData"), + function(x, ...) { + cols = list(...) + stopifnot(length(cols) > 0) + if (is.character(cols[[1]])) { + cols <- varargsToEnv(...) 
+ sdf <- callJMethod(x@sgd, "agg", cols) + } else if (class(cols[[1]]) == "Column") { + ns <- names(cols) + if (!is.null(ns)) { + for (n in ns) { + if (n != "") { + cols[[n]] = alias(cols[[n]], n) + } + } + } + jcols <- lapply(cols, function(c) { c@jc }) + sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1])) + } else { + stop("agg can only support Column or character") + } + dataFrame(sdf) + }) + +#' sum/mean/avg/min/max + +methods <- c("sum", "mean", "avg", "min", "max") + +createMethod <- function(name) { + setMethod(name, + signature(x = "GroupedData"), + function(x, ...) { + sdf <- callJMethod(x@sgd, name, toSeq(...)) + dataFrame(sdf) + }) +} + +createMethods <- function() { + for (name in methods) { + createMethod(name) + } +} + +createMethods() + diff --git a/pkg/R/utils.R b/pkg/R/utils.R index eadb8f3fe278f..73775650fb9d0 100644 --- a/pkg/R/utils.R +++ b/pkg/R/utils.R @@ -370,4 +370,3 @@ toSeq <- function(...) { listToSeq <- function(l) { callJStatic("edu.berkeley.cs.amplab.sparkr.SQLUtils", "toSeq", l) } - diff --git a/pkg/inst/tests/test_sparkSQL.R b/pkg/inst/tests/test_sparkSQL.R index 4fabec337558e..f907309a45926 100644 --- a/pkg/inst/tests/test_sparkSQL.R +++ b/pkg/inst/tests/test_sparkSQL.R @@ -329,6 +329,31 @@ test_that("column functions", { c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") }) +test_that("group by", { + df <- jsonFile(sqlCtx, jsonPath) + df1 <- agg(df, name = "max", age = "sum") + expect_true(1 == count(df1)) + df1 <- agg(df, age2 = max(df$age)) + expect_true(1 == count(df1)) + expect_true(columns(df1) == c("age2")) + + gd <- groupBy(df, "name") + expect_true(inherits(gd, "GroupedData")) + df2 <- count(gd) + expect_true(inherits(df2, "DataFrame")) + expect_true(3 == count(df2)) + + df3 <- agg(gd, age = "sum") + expect_true(inherits(df3, "DataFrame")) + expect_true(3 == count(df3)) + + df4 <- sum(gd, "age") + expect_true(inherits(df4, "DataFrame")) + expect_true(3 == count(df4)) + expect_true(3 == 
count(mean(gd, "age"))) + expect_true(3 == count(max(gd, "age"))) +}) + test_that("sortDF() and orderBy() on a DataFrame", { df <- jsonFile(sqlCtx, jsonPath) sorted <- sortDF(df, df$age) diff --git a/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SparkRBackendHandler.scala b/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SparkRBackendHandler.scala index 4d1d2c0aba934..c292c59d63f29 100644 --- a/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SparkRBackendHandler.scala +++ b/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SparkRBackendHandler.scala @@ -1,11 +1,8 @@ package edu.berkeley.cs.amplab.sparkr -import scala.collection.mutable.HashMap +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import java.io.ByteArrayInputStream -import java.io.ByteArrayOutputStream -import java.io.DataInputStream -import java.io.DataOutputStream +import scala.collection.mutable.HashMap import io.netty.channel.ChannelHandler.Sharable import io.netty.channel.ChannelHandlerContext @@ -19,7 +16,8 @@ import edu.berkeley.cs.amplab.sparkr.SerDe._ * this across connections ? 
*/ @Sharable -class SparkRBackendHandler(server: SparkRBackend) extends SimpleChannelInboundHandler[Array[Byte]] { +class SparkRBackendHandler(server: SparkRBackend) + extends SimpleChannelInboundHandler[Array[Byte]] { override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]) { val bis = new ByteArrayInputStream(msg) @@ -100,11 +98,18 @@ class SparkRBackendHandler(server: SparkRBackend) extends SimpleChannelInboundHa val methods = cls.get.getMethods val selectedMethods = methods.filter(m => m.getName == methodName) if (selectedMethods.length > 0) { - val selectedMethod = selectedMethods.filter { x => + val methods = selectedMethods.filter { x => matchMethod(numArgs, args, x.getParameterTypes) - }.head - - val ret = selectedMethod.invoke(obj, args:_*) + } + if (methods.isEmpty) { + System.err.println(s"cannot find matching method ${cls.get}.$methodName. " + + s"Candidates are:") + selectedMethods.foreach { method => + System.err.println(s"$methodName(${method.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched method found for $cls.$methodName") + } + val ret = methods.head.invoke(obj, args:_*) // Write status bit writeInt(dos, 0) @@ -160,6 +165,7 @@ class SparkRBackendHandler(server: SparkRBackend) extends SimpleChannelInboundHa } } if (!parameterWrapperType.isInstance(args(i))) { + System.err.println(s"arg $i not match: expected type $parameterWrapperType, but got ${args(i).getClass()}") return false } }