From a0d94f10c6fcb8149e34e59b6bd49138ef89698f Mon Sep 17 00:00:00 2001
From: Mfon Ekpo <58835748+mfonekpo@users.noreply.github.com>
Date: Wed, 7 Aug 2024 14:46:02 +0100
Subject: [PATCH] Feat: `Series.shift` Pyarrow Backend Implementation (#590)

---
 narwhals/_arrow/expr.py             |  4 +++
 narwhals/_arrow/series.py           | 21 ++++++++++++
 tests/expr_and_series/shift_test.py | 52 +++++++++++++++++++++--------
 3 files changed, 62 insertions(+), 15 deletions(-)

diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py
index 844901da5..a539040f0 100644
--- a/narwhals/_arrow/expr.py
+++ b/narwhals/_arrow/expr.py
@@ -203,6 +203,10 @@ def sum(self) -> Self:
     def drop_nulls(self) -> Self:
         return reuse_series_implementation(self, "drop_nulls")
 
+    def shift(self, n: int) -> Self:
+        """Shift values by `n` positions via the series-level `shift`."""
+        return reuse_series_implementation(self, "shift", n)
+
     def alias(self, name: str) -> Self:
         # Define this one manually, so that we can
         # override `output_names` and not increase depth
diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py
index ff3ee94c6..3c13be2db 100644
--- a/narwhals/_arrow/series.py
+++ b/narwhals/_arrow/series.py
@@ -221,6 +221,27 @@ def drop_nulls(self) -> ArrowSeries:
         pc = get_pyarrow_compute()
         return self._from_native_series(pc.drop_null(self._native_series))
 
+    def shift(self, n: int) -> Self:
+        """Shift values by `n` positions, filling vacated slots with nulls.
+
+        A positive `n` moves values towards the end and a negative `n`
+        towards the start. The result always has the same length as the
+        input: when `abs(n)` exceeds that length, every value is null.
+        """
+        pa = get_pyarrow()
+        ca = self._native_series
+        # Clamp the offset: without this, `abs(n) > len(ca)` would produce
+        # `abs(n)` nulls and change the length of the series.
+        offset = min(abs(n), len(ca))
+
+        if n > 0:
+            arrays = [pa.nulls(offset, ca.type), *ca[:-offset].chunks]
+        elif n < 0:
+            arrays = [*ca[offset:].chunks, pa.nulls(offset, ca.type)]
+        else:
+            return self._from_native_series(ca)
+        return self._from_native_series(pa.concat_arrays(arrays))
+
     def std(self, ddof: int = 1) -> int:
         pc = get_pyarrow_compute()
         return pc.stddev(self._native_series, ddof=ddof)  # type: ignore[no-any-return]
diff --git a/tests/expr_and_series/shift_test.py b/tests/expr_and_series/shift_test.py
index 9460357d6..02dbed6b0 100644
--- a/tests/expr_and_series/shift_test.py
+++ b/tests/expr_and_series/shift_test.py
@@ -1,6 +1,6 @@
 from typing import Any
 
-import pytest
+import pyarrow as pa
 
 import narwhals.stable.v1 as nw
 from tests.utils import compare_dicts
@@ -13,10 +13,7 @@
 }
 
 
-def test_shift(request: Any, constructor: Any) -> None:
-    if "pyarrow_table" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
-
+def test_shift(constructor: Any) -> None:
     df = nw.from_native(constructor(data))
     result = df.with_columns(nw.col("a", "b", "c").shift(2)).filter(nw.col("i") > 1)
     expected = {
@@ -28,21 +25,46 @@ def test_shift(request: Any, constructor: Any) -> None:
     compare_dicts(result, expected)
 
 
-def test_shift_series(request: Any, constructor_eager: Any) -> None:
-    if "pyarrow_table" in str(constructor_eager):
-        request.applymarker(pytest.mark.xfail)
-
+def test_shift_series(constructor_eager: Any) -> None:
     df = nw.from_native(constructor_eager(data), eager_only=True)
+    result = df.with_columns(
+        df["a"].shift(2),
+        df["b"].shift(2),
+        df["c"].shift(2),
+    ).filter(nw.col("i") > 1)
     expected = {
         "i": [2, 3, 4],
         "a": [0, 1, 2],
         "b": [1, 2, 3],
         "c": [5, 4, 3],
     }
-    result = df.select(
-        df["i"],
-        df["a"].shift(2),
-        df["b"].shift(2),
-        df["c"].shift(2),
-    ).filter(nw.col("i") > 1)
+    compare_dicts(result, expected)
+
+
+def test_shift_multi_chunk_pyarrow() -> None:
+    tbl = pa.table({"a": [1, 2, 3]})
+    tbl = pa.concat_tables([tbl, tbl, tbl])
+    df = nw.from_native(tbl, eager_only=True)
+
+    result = df.select(nw.col("a").shift(1))
+    expected = {"a": [None, 1, 2, 3, 1, 2, 3, 1, 2]}
+    compare_dicts(result, expected)
+
+    result = df.select(nw.col("a").shift(-1))
+    expected = {"a": [2, 3, 1, 2, 3, 1, 2, 3, None]}
+    compare_dicts(result, expected)
+
+    result = df.select(nw.col("a").shift(0))
+    expected = {"a": [1, 2, 3, 1, 2, 3, 1, 2, 3]}
     compare_dicts(result, expected)
+
+
+def test_shift_beyond_length_pyarrow() -> None:
+    df = nw.from_native(pa.table({"a": [1, 2, 3]}), eager_only=True)
+
+    result = df.select(nw.col("a").shift(5))
+    expected = {"a": [None, None, None]}
+    compare_dicts(result, expected)
+
+    result = df.select(nw.col("a").shift(-5))
+    compare_dicts(result, expected)