Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement $str$head() and $str$tail() #1074

Merged
merged 3 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- `pl$read_ipc()` can read a raw vector of Apache Arrow IPC file (#1072).
- New method `<DataFrame>$to_raw_ipc()` to serialize a DataFrame to a raw vector
of Apache Arrow IPC file format (#1072).
- New methods `$str$head()` and `$str$tail()` (#1074).

## Polars R Package 0.16.3

Expand Down
70 changes: 70 additions & 0 deletions R/expr__string.R
Original file line number Diff line number Diff line change
Expand Up @@ -1028,3 +1028,73 @@ ExprStr_find = function(pattern, ..., literal = FALSE, strict = TRUE) {
.pr$Expr$str_find(self, pattern, literal, strict) |>
unwrap("in $str$find():")
}

#' Return the first n characters of each string
#'
#' @param n Length of the slice (integer or expression). Strings are parsed as
#' column names. Negative indexing is supported.
#'
#' @details
#' The `n` input is defined in terms of the number of characters in the (UTF-8)
#' string. A character is defined as a Unicode scalar value. A single character
#' is represented by a single byte when working with ASCII text, and a maximum
#' of 4 bytes otherwise.
#'
#' When the `n` input is negative, `head()` returns characters up to the `n`th
#' from the end of the string. For example, if `n = -3`, then all characters
#' except the last three are returned.
#'
#' If the string has fewer than `n` characters, the full string is returned.
#'
#' @return A string Expr
etiennebacher marked this conversation as resolved.
Show resolved Hide resolved
#'
#' @examples
#' df = pl$DataFrame(
#' s = c("pear", NA, "papaya", "dragonfruit"),
#' n = c(3, 4, -2, -5)
#' )
#'
#' df$with_columns(
#' s_head_5 = pl$col("s")$str$head(5),
#' s_head_n = pl$col("s")$str$head("n")
#' )
ExprStr_head = function(n) {
  # Call the private `str_head` binding on this expression, then hand the
  # result to unwrap() together with the context string for this method.
  unwrap(.pr$Expr$str_head(self, n), "in $str$head():")
}

#' Return the last n characters of each string
#'
#' @param n Length of the slice (integer or expression). Strings are parsed as
#' column names. Negative indexing is supported.
etiennebacher marked this conversation as resolved.
Show resolved Hide resolved
#'
#' @details
#' The `n` input is defined in terms of the number of characters in the (UTF-8)
#' string. A character is defined as a Unicode scalar value. A single character
#' is represented by a single byte when working with ASCII text, and a maximum
#' of 4 bytes otherwise.
#'
#' When the `n` input is negative, `tail()` returns characters starting from the
#' `n`th from the beginning of the string. For example, if `n = -3`, then all
#' characters except the first three are returned.
#'
#' If the string has fewer than `n` characters, the full string is returned.
#'
#' @return A string Expr
#'
#' @examples
#' df = pl$DataFrame(
#' s = c("pear", NA, "papaya", "dragonfruit"),
#' n = c(3, 4, -2, -5)
#' )
#'
#' df$with_columns(
#' s_tail_5 = pl$col("s")$str$tail(5),
#' s_tail_n = pl$col("s")$str$tail("n")
#' )
ExprStr_tail = function(n) {
  # Call the private `str_tail` binding on this expression, then hand the
  # result to unwrap() together with the context string for this method.
  unwrap(.pr$Expr$str_tail(self, n), "in $str$tail():")
}
4 changes: 4 additions & 0 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,10 @@ RPolarsExpr$str_replace_many <- function(patterns, replace_with, ascii_case_inse

RPolarsExpr$str_find <- function(pat, literal, strict) .Call(wrap__RPolarsExpr__str_find, self, pat, literal, strict)

# Binding for `str_head`: forwards `self` and `n` to the compiled routine via .Call().
RPolarsExpr$str_head <- function(n) .Call(wrap__RPolarsExpr__str_head, self, n)

# Binding for `str_tail`: forwards `self` and `n` to the compiled routine via .Call().
RPolarsExpr$str_tail <- function(n) .Call(wrap__RPolarsExpr__str_tail, self, n)

RPolarsExpr$bin_contains <- function(lit) .Call(wrap__RPolarsExpr__bin_contains, self, lit)

RPolarsExpr$bin_starts_with <- function(sub) .Call(wrap__RPolarsExpr__bin_starts_with, self, sub)
Expand Down
42 changes: 42 additions & 0 deletions man/ExprStr_head.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 42 additions & 0 deletions man/ExprStr_tail.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions src/rust/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2426,6 +2426,15 @@ impl RPolarsExpr {
_ => Ok(self.0.clone().str().find(pat, strict).into()),
}
}

// Builds the string-namespace `head` expression from this expression.
// `n` is converted with `robj_to!(PLExprCol, ..)`; a conversion error is
// propagated to the caller through `?`.
fn str_head(&self, n: Robj) -> RResult<Self> {
    let n = robj_to!(PLExprCol, n)?;
    let head_expr = self.0.clone().str().head(n);
    Ok(head_expr.into())
}

// Builds the string-namespace `tail` expression from this expression.
// `n` is converted with `robj_to!(PLExprCol, ..)`; a conversion error is
// propagated to the caller through `?`.
fn str_tail(&self, n: Robj) -> RResult<Self> {
    let n = robj_to!(PLExprCol, n)?;
    let tail_expr = self.0.clone().str().tail(n);
    Ok(tail_expr.into())
}

//binary methods
pub fn bin_contains(&self, lit: Robj) -> RResult<Self> {
Ok(self
Expand Down
47 changes: 24 additions & 23 deletions tests/testthat/_snaps/after-wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -414,29 +414,30 @@
[271] "str_count_matches" "str_ends_with"
[273] "str_explode" "str_extract"
[275] "str_extract_all" "str_extract_groups"
[277] "str_find" "str_hex_decode"
[279] "str_hex_encode" "str_json_decode"
[281] "str_json_path_match" "str_len_bytes"
[283] "str_len_chars" "str_pad_end"
[285] "str_pad_start" "str_replace"
[287] "str_replace_all" "str_replace_many"
[289] "str_reverse" "str_slice"
[291] "str_split" "str_split_exact"
[293] "str_splitn" "str_starts_with"
[295] "str_strip_chars" "str_strip_chars_end"
[297] "str_strip_chars_start" "str_to_date"
[299] "str_to_datetime" "str_to_integer"
[301] "str_to_lowercase" "str_to_time"
[303] "str_to_titlecase" "str_to_uppercase"
[305] "str_zfill" "struct_field_by_name"
[307] "struct_rename_fields" "sub"
[309] "sum" "tail"
[311] "tan" "tanh"
[313] "timestamp" "to_physical"
[315] "top_k" "unique"
[317] "unique_counts" "unique_stable"
[319] "upper_bound" "value_counts"
[321] "var" "xor"
[277] "str_find" "str_head"
[279] "str_hex_decode" "str_hex_encode"
[281] "str_json_decode" "str_json_path_match"
[283] "str_len_bytes" "str_len_chars"
[285] "str_pad_end" "str_pad_start"
[287] "str_replace" "str_replace_all"
[289] "str_replace_many" "str_reverse"
[291] "str_slice" "str_split"
[293] "str_split_exact" "str_splitn"
[295] "str_starts_with" "str_strip_chars"
[297] "str_strip_chars_end" "str_strip_chars_start"
[299] "str_tail" "str_to_date"
[301] "str_to_datetime" "str_to_integer"
[303] "str_to_lowercase" "str_to_time"
[305] "str_to_titlecase" "str_to_uppercase"
[307] "str_zfill" "struct_field_by_name"
[309] "struct_rename_fields" "sub"
[311] "sum" "tail"
[313] "tan" "tanh"
[315] "timestamp" "to_physical"
[317] "top_k" "unique"
[319] "unique_counts" "unique_stable"
[321] "upper_bound" "value_counts"
[323] "var" "xor"

# public and private methods of each class When

Expand Down
37 changes: 37 additions & 0 deletions tests/testthat/test-expr_string.R
Original file line number Diff line number Diff line change
Expand Up @@ -884,3 +884,40 @@ test_that("str$find() works", {
test$select(lit = pl$col("s")$str$find("(?iAa", strict = TRUE, literal = TRUE))
)
})

test_that("$str$head() works", {
  df = pl$DataFrame(
    s = c("pear", NA, "papaya", "dragonfruit"),
    n = c(3, 4, -2, -5)
  )

  # `n` is passed once as a literal and once as the name of column "n".
  actual = df$select(
    s_head_5 = pl$col("s")$str$head(5),
    s_head_n = pl$col("s")$str$head("n")
  )$to_list()

  expected = list(
    s_head_5 = c("pear", NA, "papay", "drago"),
    s_head_n = c("pea", NA, "papa", "dragon")
  )

  expect_equal(actual, expected)
})


test_that("$str$tail() works", {
  df = pl$DataFrame(
    s = c("pear", NA, "papaya", "dragonfruit"),
    n = c(3, 4, -2, -5)
  )

  # `n` is passed once as a literal and once as the name of column "n".
  actual = df$select(
    s_tail_5 = pl$col("s")$str$tail(5),
    s_tail_n = pl$col("s")$str$tail("n")
  )$to_list()

  expected = list(
    s_tail_5 = c("pear", NA, "apaya", "fruit"),
    s_tail_n = c("ear", NA, "paya", "nfruit")
  )

  expect_equal(actual, expected)
})
Loading