From 1709691e02d3eeec36c4f022881a8af6c92cd3c9 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Wed, 21 Jun 2023 14:47:20 +0200 Subject: [PATCH] feat(rust, python): add infer schema len to json_extract (#9478) --- polars/Cargo.toml | 7 ++++- polars/polars-json/src/ndjson/file.rs | 3 +- polars/polars-lazy/Cargo.toml | 1 + polars/polars-lazy/polars-plan/Cargo.toml | 1 + .../polars-plan/src/dsl/function_expr/mod.rs | 5 ++++ .../src/dsl/function_expr/schema.rs | 6 ++++ .../src/dsl/function_expr/strings.rs | 19 +++++++++++++ .../polars-lazy/polars-plan/src/dsl/string.rs | 9 ++++++ .../src/chunked_array/strings/json_path.rs | 23 ++++++++++----- py-polars/polars/expr/string.py | 9 ++++-- py-polars/polars/series/string.py | 7 ++++- py-polars/src/expr/string.rs | 28 ++++++------------- 12 files changed, 87 insertions(+), 31 deletions(-) diff --git a/polars/Cargo.toml b/polars/Cargo.toml index 0288dac51773..53dd46d1bbdc 100644 --- a/polars/Cargo.toml +++ b/polars/Cargo.toml @@ -100,7 +100,12 @@ decompress = ["polars-io/decompress"] decompress-fast = ["polars-io/decompress-fast"] mode = ["polars-core/mode", "polars-lazy/mode"] take_opt_iter = ["polars-core/take_opt_iter"] -extract_jsonpath = ["polars-core/strings", "polars-ops/extract_jsonpath", "polars-ops/strings"] +extract_jsonpath = [ + "polars-core/strings", + "polars-ops/extract_jsonpath", + "polars-ops/strings", + "polars-lazy/extract_jsonpath", +] string_encoding = ["polars-ops/string_encoding", "polars-core/strings"] binary_encoding = ["polars-ops/binary_encoding"] groupby_list = ["polars-core/groupby_list"] diff --git a/polars/polars-json/src/ndjson/file.rs b/polars/polars-json/src/ndjson/file.rs index 2ae02dba97c1..83ca6c93e22c 100644 --- a/polars/polars-json/src/ndjson/file.rs +++ b/polars/polars-json/src/ndjson/file.rs @@ -144,5 +144,6 @@ pub fn infer_iter>(rows: impl Iterator) -> PolarsResult< } let v: Vec<&DataType> = data_types.iter().collect(); - Ok(crate::json::infer_schema::coerce_data_type(&v)) + dbg!(&v); + dbg!(Ok(crate::json::infer_schema::coerce_data_type(&v))) } diff --git a/polars/polars-lazy/Cargo.toml b/polars/polars-lazy/Cargo.toml index f13ff9342db3..fcd02c0b45e9 100644 --- a/polars/polars-lazy/Cargo.toml +++ b/polars/polars-lazy/Cargo.toml @@ -71,6 +71,7 @@ list_take = ["polars-ops/list_take", "polars-plan/list_take"] list_count = ["polars-ops/list_count", "polars-plan/list_count"] true_div = ["polars-plan/true_div"] +extract_jsonpath = ["polars-plan/extract_jsonpath", "polars-ops/extract_jsonpath"] # operations approx_unique = ["polars-plan/approx_unique"] diff --git a/polars/polars-lazy/polars-plan/Cargo.toml b/polars/polars-lazy/polars-plan/Cargo.toml index 3a11dd0a148e..63fb519e6ed3 100644 --- a/polars/polars-lazy/polars-plan/Cargo.toml +++ b/polars/polars-lazy/polars-plan/Cargo.toml @@ -68,6 +68,7 @@ timezones = ["chrono-tz", "polars-time/timezones", "polars-core/timezones", "reg binary_encoding = ["polars-ops/binary_encoding"] true_div = [] nightly = ["polars-utils/nightly", "polars-ops/nightly"] +extract_jsonpath = [] # operations approx_unique = ["polars-ops/approx_unique"] diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs index 69a23522a973..e8db2c9ddb51 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs @@ -556,6 +556,11 @@ impl From for SpecialEq> { Explode => map!(strings::explode), #[cfg(feature = "dtype-decimal")] ToDecimal(infer_len) => map!(strings::to_decimal, infer_len), + #[cfg(feature = "extract_jsonpath")] + JsonExtract { + dtype, + infer_schema_len, + } => map!(strings::json_extract, dtype.clone(), infer_schema_len), } } } diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs index c5976d797972..b29173db914a 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs @@ -339,4 +339,10 @@ impl<'a> FieldsMapper<'a> { } Ok(first) } + + #[cfg(feature = "extract_jsonpath")] + pub(super) fn with_opt_dtype(&self, dtype: Option) -> PolarsResult { + let dtype = dtype.unwrap_or(DataType::Unknown); + self.with_dtype(dtype) + } } diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/strings.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/strings.rs index 6eb9dd5e3c76..da1d7350a76a 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/function_expr/strings.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/strings.rs @@ -68,6 +68,11 @@ pub enum StringFunction { Explode, #[cfg(feature = "dtype-decimal")] ToDecimal(usize), + #[cfg(feature = "extract_jsonpath")] + JsonExtract { + dtype: Option, + infer_schema_len: Option, + }, } impl StringFunction { @@ -98,6 +103,8 @@ impl StringFunction { Explode => mapper.with_same_dtype(), #[cfg(feature = "dtype-decimal")] ToDecimal(_) => mapper.with_dtype(DataType::Decimal(None, None)), + #[cfg(feature = "extract_jsonpath")] + JsonExtract { dtype, .. } => mapper.with_opt_dtype(dtype.clone()), } } } @@ -139,6 +146,8 @@ impl Display for StringFunction { StringFunction::Explode => "explode", #[cfg(feature = "dtype-decimal")] StringFunction::ToDecimal(_) => "to_decimal", + #[cfg(feature = "extract_jsonpath")] + StringFunction::JsonExtract { .. } => "json_extract", }; write!(f, "str.{s}") @@ -675,3 +684,13 @@ pub(super) fn to_decimal(s: &Series, infer_len: usize) -> PolarsResult { let ca = s.utf8()?; ca.to_decimal(infer_len) } + +#[cfg(feature = "extract_jsonpath")] +pub(super) fn json_extract( + s: &Series, + dtype: Option, + infer_schema_len: Option, +) -> PolarsResult { + let ca = s.utf8()?; + ca.json_extract(dtype, infer_schema_len) +} diff --git a/polars/polars-lazy/polars-plan/src/dsl/string.rs b/polars/polars-lazy/polars-plan/src/dsl/string.rs index 63344bc9c2b8..f9eb6acfd61b 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/string.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/string.rs @@ -476,4 +476,13 @@ impl StringNameSpace { self.0 .apply_private(FunctionExpr::StringExpr(StringFunction::Explode)) } + + #[cfg(feature = "extract_jsonpath")] + pub fn json_extract(self, dtype: Option, infer_schema_len: Option) -> Expr { + self.0 + .map_private(FunctionExpr::StringExpr(StringFunction::JsonExtract { + dtype, + infer_schema_len, + })) + } } diff --git a/polars/polars-ops/src/chunked_array/strings/json_path.rs b/polars/polars-ops/src/chunked_array/strings/json_path.rs index 3aded4542c6f..0f10895c689b 100644 --- a/polars/polars-ops/src/chunked_array/strings/json_path.rs +++ b/polars/polars-ops/src/chunked_array/strings/json_path.rs @@ -64,11 +64,15 @@ pub trait Utf8JsonPathImpl: AsUtf8 { } /// Extracts a typed-JSON value for each row in the Utf8Chunked - fn json_extract(&self, dtype: Option) -> PolarsResult { + fn json_extract( + &self, + dtype: Option, + infer_schema_len: Option, + ) -> PolarsResult { let ca = self.as_utf8(); let dtype = match dtype { Some(dt) => dt, - None => ca.json_infer(None)?, + None => ca.json_infer(infer_schema_len)?, }; let buf_size = ca.get_values_size() + ca.null_count() * "null".len(); @@ -92,9 +96,14 @@ pub trait Utf8JsonPathImpl: AsUtf8 { .apply_on_opt(|opt_s| opt_s.and_then(|s| select_json(&pat, s)))) } - fn json_path_extract(&self, json_path: &str, dtype: Option) -> PolarsResult { + fn json_path_extract( + &self, + json_path: &str, + dtype: Option, + infer_schema_len: Option, + ) -> PolarsResult { let selected_json = self.as_utf8().json_path_select(json_path)?; - selected_json.json_extract(dtype) + selected_json.json_extract(dtype, infer_schema_len) } } @@ -178,11 +187,11 @@ mod tests { let expected_dtype = expected_series.dtype().clone(); assert!(ca - .json_extract(None) + .json_extract(None, None) .unwrap() .series_equal_missing(&expected_series)); assert!(ca - .json_extract(Some(expected_dtype)) + .json_extract(Some(expected_dtype), None) .unwrap() .series_equal_missing(&expected_series)); } @@ -253,7 +262,7 @@ mod tests { ); assert!(ca - .json_path_extract("$.b[:].c", None) + .json_path_extract("$.b[:].c", None, None) .unwrap() .into_series() .series_equal_missing(&c_series)); diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index 8a95149d0ee3..48beb822f2be 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -942,7 +942,9 @@ def starts_with(self, prefix: str | Expr) -> Expr: prefix = parse_as_expression(prefix, str_as_lit=True) return wrap_expr(self._pyexpr.str_starts_with(prefix)) - def json_extract(self, dtype: PolarsDataType | None = None) -> Expr: + def json_extract( + self, dtype: PolarsDataType | None = None, infer_schema_length: int | None = 100 + ) -> Expr: """ Parse string values as JSON. @@ -953,6 +955,9 @@ def json_extract(self, dtype: PolarsDataType | None = None) -> Expr: dtype The dtype to cast the extracted value to. If None, the dtype will be inferred from the JSON value. + infer_schema_length + How many rows to parse to determine the schema. + If ``None`` all rows are used. Examples -------- @@ -980,7 +985,7 @@ def json_extract(self, dtype: PolarsDataType | None = None) -> Expr: """ if dtype is not None: dtype = py_type_to_dtype(dtype) - return wrap_expr(self._pyexpr.str_json_extract(dtype)) + return wrap_expr(self._pyexpr.str_json_extract(dtype, infer_schema_length)) def json_path_match(self, json_path: str) -> Expr: """ diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index 1ae85216f03b..7826cf5a9e31 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -559,7 +559,9 @@ def encode(self, encoding: TransferEncoding) -> Series: """ - def json_extract(self, dtype: PolarsDataType | None = None) -> Series: + def json_extract( + self, dtype: PolarsDataType | None = None, infer_schema_length: int | None = 100 + ) -> Series: """ Parse string values as JSON. @@ -570,6 +572,9 @@ def json_extract(self, dtype: PolarsDataType | None = None) -> Series: dtype The dtype to cast the extracted value to. If None, the dtype will be inferred from the JSON value. + infer_schema_length + How many rows to parse to determine the schema. + If ``None`` all rows are used. Examples -------- diff --git a/py-polars/src/expr/string.rs b/py-polars/src/expr/string.rs index 0ce6e7aa15b0..033d686f6b1e 100644 --- a/py-polars/src/expr/string.rs +++ b/py-polars/src/expr/string.rs @@ -219,26 +219,16 @@ impl PyExpr { } #[cfg(feature = "extract_jsonpath")] - fn str_json_extract(&self, dtype: Option>) -> Self { + fn str_json_extract( + &self, + dtype: Option>, + infer_schema_len: Option, + ) -> Self { let dtype = dtype.map(|wrap| wrap.0); - - let output_type = match dtype.clone() { - Some(dtype) => GetOutput::from_type(dtype), - None => GetOutput::from_type(DataType::Unknown), - }; - - let function = move |s: Series| { - let ca = s.utf8()?; - match ca.json_extract(dtype.clone()) { - Ok(ca) => Ok(Some(ca.into_series())), - Err(e) => Err(PolarsError::ComputeError(format!("{e:?}").into())), - } - }; - - self.clone() - .inner - .map(function, output_type) - .with_fmt("str.json_extract") + self.inner + .clone() + .str() + .json_extract(dtype, infer_schema_len) .into() }