From 24fccc45a45b7d5b4992a6ac4583f3875944af7e Mon Sep 17 00:00:00 2001 From: Luke Manley Date: Wed, 22 Jan 2025 10:14:44 -0500 Subject: [PATCH] fix: Incorrect `Decimal` value for `fill_null(strategy="one")` (#20844) --- .../polars-core/src/chunked_array/ops/fill_null.rs | 9 +++++++++ py-polars/tests/unit/datatypes/test_decimal.py | 13 +++++++++++++ 2 files changed, 22 insertions(+) diff --git a/crates/polars-core/src/chunked_array/ops/fill_null.rs b/crates/polars-core/src/chunked_array/ops/fill_null.rs index ae8772373a33..8488685de942 100644 --- a/crates/polars-core/src/chunked_array/ops/fill_null.rs +++ b/crates/polars-core/src/chunked_array/ops/fill_null.rs @@ -93,6 +93,15 @@ impl Series { fill_backward_gather(self) }, FillNullStrategy::Backward(Some(limit)) => fill_backward_gather_limit(self, limit), + #[cfg(feature = "dtype-decimal")] + FillNullStrategy::One if self.dtype().is_decimal() => { + let ca = self.decimal().unwrap(); + let precision = ca.precision(); + let scale = ca.scale(); + let fill_value = 10i128.pow(scale as u32); + let phys = ca.as_ref().fill_null_with_values(fill_value)?; + Ok(phys.into_decimal_unchecked(precision, scale).into_series()) + }, _ => { let logical_type = self.dtype(); let s = self.to_physical_repr(); diff --git a/py-polars/tests/unit/datatypes/test_decimal.py b/py-polars/tests/unit/datatypes/test_decimal.py index 7347fbf539c5..14cf6564b91a 100644 --- a/py-polars/tests/unit/datatypes/test_decimal.py +++ b/py-polars/tests/unit/datatypes/test_decimal.py @@ -638,3 +638,16 @@ def test_shift_over_12957() -> None: ) assert result["x"].to_list() == [None, D("1.1"), None, D("2.2")] assert result["y"].to_list() == [None, 1, None, 2] + + +def test_fill_null() -> None: + s = pl.Series("a", [D("1.2"), None, D("1.4")]) + + assert s.fill_null(D("0.0")).to_list() == [D("1.2"), D("0.0"), D("1.4")] + assert s.fill_null(strategy="zero").to_list() == [D("1.2"), D("0.0"), D("1.4")] + assert s.fill_null(strategy="max").to_list() == [D("1.2"), D("1.4"), D("1.4")] + assert s.fill_null(strategy="min").to_list() == [D("1.2"), D("1.2"), D("1.4")] + assert s.fill_null(strategy="one").to_list() == [D("1.2"), D("1.0"), D("1.4")] + assert s.fill_null(strategy="forward").to_list() == [D("1.2"), D("1.2"), D("1.4")] + assert s.fill_null(strategy="backward").to_list() == [D("1.2"), D("1.4"), D("1.4")] + assert s.fill_null(strategy="mean").to_list() == [D("1.2"), D("1.3"), D("1.4")]