Skip to content

Commit

Permalink
GH-38033: [R] Allow code() to return package name prefix. (#38144)
Browse files Browse the repository at this point in the history
### Rationale for this change
#38033 

### What changes are included in this PR?
- ~~Added `get_pkg_ns()` helper.~~
- ~~Added `call_name` private method to `DataType` class to store the string name used in the code call. Refactored `code()` public method to use `call_name`.~~
- Converted all `$code() call(...)` to `$code(namespace = FALSE) call2(..., .ns = if(namespace) "arrow")` in `DataType`, `Schema`, and `DictionaryType`.
- Added `code` to `Schema` docstring.
- Updated `expect_code_roundtrip` to test roundtrip with and without namespace, and check match/no match for `arrow::` depending on namespace argument.

### Are these changes tested?
* All tests pass, including lintr checks.

### Are there any user-facing changes?
Yes, user-facing changes, but no breaking changes to any public APIs.
* Closes: #38033

Lead-authored-by: orgadish <[email protected]>
Co-authored-by: Nic Crane <[email protected]>
Signed-off-by: Nic Crane <[email protected]>
  • Loading branch information
orgadish and thisisnic authored Oct 19, 2023
1 parent 02ad5ae commit a0e58f1
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 39 deletions.
8 changes: 4 additions & 4 deletions r/R/dictionary.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,18 @@ DictionaryType <- R6Class("DictionaryType",
ToString = function() {
prettier_dictionary_type(DataType__ToString(self))
},
code = function() {
code = function(namespace = FALSE) {
details <- list()
if (self$index_type != int32()) {
details$index_type <- self$index_type$code()
details$index_type <- self$index_type$code(namespace)
}
if (self$value_type != utf8()) {
details$value_type <- self$value_type$code()
details$value_type <- self$value_type$code(namespace)
}
if (isTRUE(self$ordered)) {
details$ordered <- TRUE
}
call2("dictionary", !!!details)
call2("dictionary", !!!details, .ns = if (namespace) "arrow")
}
),
active = list(
Expand Down
8 changes: 4 additions & 4 deletions r/R/schema.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#' - `$WithMetadata(metadata)`: returns a new `Schema` with the key-value
#' `metadata` set. Note that all list elements in `metadata` will be coerced
#' to `character`.
#' - `$code(namespace)`: returns the R code needed to generate this schema. Use `namespace=TRUE` to call with `arrow::`.
#'
#' @section Active bindings:
#'
Expand Down Expand Up @@ -107,14 +108,13 @@ Schema <- R6Class("Schema",
inherits(other, "Schema") && Schema__Equals(self, other, isTRUE(check_metadata))
},
export_to_c = function(ptr) ExportSchema(self, ptr),
code = function() {
code = function(namespace = FALSE) {
names <- self$names
codes <- map2(names, self$fields, function(name, field) {
field$type$code()
field$type$code(namespace)
})
codes <- set_names(codes, names)

call2("schema", !!!codes)
call2("schema", !!!codes, .ns = if (namespace) "arrow")
},
WithNames = function(names) {
if (!inherits(names, "character")) {
Expand Down
62 changes: 33 additions & 29 deletions r/R/type.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#' - `$ToString()`: String representation of the DataType
#' - `$Equals(other)`: Is the DataType equal to `other`
#' - `$fields()`: The children fields associated with this type
#' - `$code()`: Produces an R call of the data type.
#' - `$code(namespace)`: Produces an R call of the data type. Use `namespace=TRUE` to call with `arrow::`.
#'
#' There are also some active bindings:
#' - `$id`: integer Arrow type id.
Expand All @@ -51,7 +51,7 @@ DataType <- R6Class("DataType",
DataType__fields(self)
},
export_to_c = function(ptr) ExportType(self, ptr),
code = function() call("stop", paste0("Unsupported type: <", self$ToString(), ">."))
code = function(namespace = FALSE) call("stop", paste0("Unsupported type: <", self$ToString(), ">."))
),
active = list(
id = function() DataType__id(self),
Expand Down Expand Up @@ -158,7 +158,7 @@ infer_type.Expression <- function(x, ...) x$type()
FixedWidthType <- R6Class("FixedWidthType",
inherit = DataType,
public = list(
code = function() call(tolower(self$name))
code = function(namespace = FALSE) call2(tolower(self$name), .ns = if (namespace) "arrow")
),
active = list(
bit_width = function() FixedWidthType__bit_width(self)
Expand All @@ -178,45 +178,47 @@ Float32 <- R6Class("Float32", inherit = FixedWidthType)
Float64 <- R6Class("Float64",
inherit = FixedWidthType,
public = list(
code = function() call("float64")
code = function(namespace = FALSE) call2("float64", .ns = if (namespace) "arrow")
)
)
Boolean <- R6Class("Boolean", inherit = FixedWidthType)
Utf8 <- R6Class("Utf8",
inherit = DataType,
public = list(
code = function() call("utf8")
code = function(namespace = FALSE) call2("utf8", .ns = if (namespace) "arrow")
)
)
LargeUtf8 <- R6Class("LargeUtf8",
inherit = DataType,
public = list(
code = function() call("large_utf8")
code = function(namespace = FALSE) call2("large_utf8", .ns = if (namespace) "arrow")
)
)
Binary <- R6Class("Binary",
inherit = DataType,
public = list(
code = function() call("binary")
code = function(namespace = FALSE) call2("binary", .ns = if (namespace) "arrow")
)
)
LargeBinary <- R6Class("LargeBinary",
inherit = DataType, public = list(
code = function() call("large_binary")
code = function(namespace = FALSE) call2("large_binary", .ns = if (namespace) "arrow")
)
)
FixedSizeBinary <- R6Class("FixedSizeBinary",
inherit = FixedWidthType,
public = list(
byte_width = function() FixedSizeBinary__byte_width(self),
code = function() call2("fixed_size_binary", byte_width = self$byte_width())
code = function(namespace = FALSE) {
call2("fixed_size_binary", byte_width = self$byte_width(), .ns = if (namespace) "arrow")
}
)
)

DateType <- R6Class("DateType",
inherit = FixedWidthType,
public = list(
code = function() call2(tolower(self$name)),
code = function(namespace = FALSE) call2(tolower(self$name), .ns = if (namespace) "arrow"),
unit = function() DateType__unit(self)
)
)
Expand All @@ -232,26 +234,26 @@ TimeType <- R6Class("TimeType",
Time32 <- R6Class("Time32",
inherit = TimeType,
public = list(
code = function() {
code = function(namespace = FALSE) {
unit <- if (self$unit() == TimeUnit$MILLI) {
"ms"
} else {
"s"
}
call2("time32", unit = unit)
call2("time32", unit = unit, .ns = if (namespace) "arrow")
}
)
)
Time64 <- R6Class("Time64",
inherit = TimeType,
public = list(
code = function() {
code = function(namespace = FALSE) {
unit <- if (self$unit() == TimeUnit$NANO) {
"ns"
} else {
"us"
}
call2("time64", unit = unit)
call2("time64", unit = unit, .ns = if (namespace) "arrow")
}
)
)
Expand All @@ -266,20 +268,20 @@ DurationType <- R6Class("DurationType",
Null <- R6Class("Null",
inherit = DataType,
public = list(
code = function() call("null")
code = function(namespace = FALSE) call2("null", .ns = if (namespace) "arrow")
)
)

Timestamp <- R6Class("Timestamp",
inherit = FixedWidthType,
public = list(
code = function() {
code = function(namespace = FALSE) {
unit <- c("s", "ms", "us", "ns")[self$unit() + 1L]
tz <- self$timezone()
if (identical(tz, "")) {
call2("timestamp", unit = unit)
call2("timestamp", unit = unit, .ns = if (namespace) "arrow")
} else {
call2("timestamp", unit = unit, timezone = tz)
call2("timestamp", unit = unit, timezone = tz, .ns = if (namespace) "arrow")
}
},
timezone = function() TimestampType__timezone(self),
Expand All @@ -290,8 +292,8 @@ Timestamp <- R6Class("Timestamp",
DecimalType <- R6Class("DecimalType",
inherit = FixedWidthType,
public = list(
code = function() {
call2("decimal", precision = self$precision(), scale = self$scale())
code = function(namespace = FALSE) {
call2("decimal", precision = self$precision(), scale = self$scale(), .ns = if (namespace) "arrow")
},
precision = function() DecimalType__precision(self),
scale = function() DecimalType__scale(self)
Expand Down Expand Up @@ -624,13 +626,13 @@ check_decimal_args <- function(precision, scale) {
StructType <- R6Class("StructType",
inherit = NestedType,
public = list(
code = function() {
code = function(namespace = FALSE) {
field_names <- StructType__field_names(self)
codes <- map(field_names, function(name) {
self$GetFieldByName(name)$type$code()
self$GetFieldByName(name)$type$code(namespace)
})
codes <- set_names(codes, field_names)
call2("struct", !!!codes)
call2("struct", !!!codes, .ns = if (namespace) "arrow")
},
GetFieldByName = function(name) StructType__GetFieldByName(self, name),
GetFieldIndex = function(name) StructType__GetFieldIndex(self, name)
Expand All @@ -648,8 +650,8 @@ names.StructType <- function(x) StructType__field_names(x)
ListType <- R6Class("ListType",
inherit = NestedType,
public = list(
code = function() {
call("list_of", self$value_type$code())
code = function(namespace = FALSE) {
call2("list_of", self$value_type$code(namespace), .ns = if (namespace) "arrow")
}
),
active = list(
Expand All @@ -665,8 +667,8 @@ list_of <- function(type) list__(type)
LargeListType <- R6Class("LargeListType",
inherit = NestedType,
public = list(
code = function() {
call2("large_list_of", self$value_type$code())
code = function(namespace = FALSE) {
call2("large_list_of", self$value_type$code(namespace), .ns = if (namespace) "arrow")
}
),
active = list(
Expand All @@ -684,8 +686,10 @@ large_list_of <- function(type) large_list__(type)
FixedSizeListType <- R6Class("FixedSizeListType",
inherit = NestedType,
public = list(
code = function() {
call2("fixed_size_list_of", self$value_type$code(), list_size = self$list_size)
code = function(namespace = FALSE) {
call2("fixed_size_list_of", self$value_type$code(namespace),
list_size = self$list_size, .ns = if (namespace) "arrow"
)
}
),
active = list(
Expand Down
2 changes: 1 addition & 1 deletion r/man/DataType-class.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions r/man/Schema-class.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion r/tests/testthat/helper-roundtrip.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,13 @@ expect_array_roundtrip <- function(x, type, as = NULL) {
}

expect_code_roundtrip <- function(x) {
expect_equal(eval(x$code()), x)
code <- x$code()
code_with_ns <- x$code(namespace = TRUE)

pkg_prefix_pattern <- "^arrow[:][:]"
expect_no_match(as.character(code), pkg_prefix_pattern)
expect_match(as.character(code_with_ns)[1], pkg_prefix_pattern)

expect_equal(eval(code), x)
expect_equal(eval(code_with_ns), x)
}

0 comments on commit a0e58f1

Please sign in to comment.