Skip to content

Commit

Permalink
[SPARK-19342][SPARKR] bug fixed in collect method for collecting time…
Browse files Browse the repository at this point in the history
…stamp column

## What changes were proposed in this pull request?

Fix a bug in collect method for collecting timestamp column, the bug can be reproduced as shown in the following codes and outputs:

```
library(SparkR)
sparkR.session(master = "local")
df <- data.frame(col1 = c(0, 1, 2),
                 col2 = c(as.POSIXct("2017-01-01 00:00:01"), NA, as.POSIXct("2017-01-01 12:00:01")))

sdf1 <- createDataFrame(df)
print(dtypes(sdf1))
df1 <- collect(sdf1)
print(lapply(df1, class))

sdf2 <- filter(sdf1, "col1 > 0")
print(dtypes(sdf2))
df2 <- collect(sdf2)
print(lapply(df2, class))
```

As we can see from the printed output, the column type of col2 in df2 is converted to numeric unexpectedly, when NA exists at the top of the column.

This is caused by method `do.call(c, list)`, if we convert a list, i.e. `do.call(c, list(NA, as.POSIXct("2017-01-01 12:00:01"))`, the class of the result is numeric instead of POSIXct.

Therefore, we need to cast the data type of the vector explicitly.

## How was this patch tested?

The patch can be tested manually with the same code above.

Author: titicaca <[email protected]>

Closes #16689 from titicaca/sparkr-dev.

(cherry picked from commit bc0a0e6)
Signed-off-by: Felix Cheung <[email protected]>
  • Loading branch information
titicaca authored and Felix Cheung committed Feb 12, 2017
1 parent e580bb0 commit 173c238
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 4 deletions.
3 changes: 2 additions & 1 deletion R/pkg/R/DataFrame.R
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ setMethod("coltypes",
type <- PRIMITIVE_TYPES[[specialtype]]
}
}
type
type[[1]]
})

# Find which types don't have mapping to R
Expand Down Expand Up @@ -1132,6 +1132,7 @@ setMethod("collect",
if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") {
vec <- do.call(c, col)
stopifnot(class(vec) != "list")
class(vec) <- PRIMITIVE_TYPES[[colType]]
df[[colIndex]] <- vec
} else {
df[[colIndex]] <- col
Expand Down
2 changes: 1 addition & 1 deletion R/pkg/R/types.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ PRIMITIVE_TYPES <- as.environment(list(
"string" = "character",
"binary" = "raw",
"boolean" = "logical",
"timestamp" = "POSIXct",
"timestamp" = c("POSIXct", "POSIXt"),
"date" = "Date",
# following types are not SQL types returned by dtypes(). They are listed here for usage
# by checkType() in schema.R.
Expand Down
42 changes: 40 additions & 2 deletions R/pkg/inst/tests/testthat/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -1297,9 +1297,9 @@ test_that("column functions", {

# Test first(), last()
df <- read.json(jsonPath)
expect_equal(collect(select(df, first(df$age)))[[1]], NA)
expect_equal(collect(select(df, first(df$age)))[[1]], NA_real_)
expect_equal(collect(select(df, first(df$age, TRUE)))[[1]], 30)
expect_equal(collect(select(df, first("age")))[[1]], NA)
expect_equal(collect(select(df, first("age")))[[1]], NA_real_)
expect_equal(collect(select(df, first("age", TRUE)))[[1]], 30)
expect_equal(collect(select(df, last(df$age)))[[1]], 19)
expect_equal(collect(select(df, last(df$age, TRUE)))[[1]], 19)
Expand Down Expand Up @@ -2767,6 +2767,44 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume
"Unnamed arguments ignored: 2, 3, a.")
})

test_that("Collect on DataFrame when NAs exists at the top of a timestamp column", {
ldf <- data.frame(col1 = c(0, 1, 2),
col2 = c(as.POSIXct("2017-01-01 00:00:01"),
NA,
as.POSIXct("2017-01-01 12:00:01")),
col3 = c(as.POSIXlt("2016-01-01 00:59:59"),
NA,
as.POSIXlt("2016-01-01 12:01:01")))
sdf1 <- createDataFrame(ldf)
ldf1 <- collect(sdf1)
expect_equal(dtypes(sdf1), list(c("col1", "double"),
c("col2", "timestamp"),
c("col3", "timestamp")))
expect_equal(class(ldf1$col1), "numeric")
expect_equal(class(ldf1$col2), c("POSIXct", "POSIXt"))
expect_equal(class(ldf1$col3), c("POSIXct", "POSIXt"))

# Columns with NAs at the top
sdf2 <- filter(sdf1, "col1 > 1")
ldf2 <- collect(sdf2)
expect_equal(dtypes(sdf2), list(c("col1", "double"),
c("col2", "timestamp"),
c("col3", "timestamp")))
expect_equal(class(ldf2$col1), "numeric")
expect_equal(class(ldf2$col2), c("POSIXct", "POSIXt"))
expect_equal(class(ldf2$col3), c("POSIXct", "POSIXt"))

# Columns with only NAs, the type will also be cast to PRIMITIVE_TYPE
sdf3 <- filter(sdf1, "col1 == 0")
ldf3 <- collect(sdf3)
expect_equal(dtypes(sdf3), list(c("col1", "double"),
c("col2", "timestamp"),
c("col3", "timestamp")))
expect_equal(class(ldf3$col1), "numeric")
expect_equal(class(ldf3$col2), c("POSIXct", "POSIXt"))
expect_equal(class(ldf3$col3), c("POSIXct", "POSIXt"))
})

unlink(parquetPath)
unlink(orcPath)
unlink(jsonPath)
Expand Down

0 comments on commit 173c238

Please sign in to comment.