diff --git a/polars/polars-core/src/chunked_array/ops/is_in.rs b/polars/polars-core/src/chunked_array/ops/is_in.rs index f09737cf201d..7c408f977ebf 100644 --- a/polars/polars-core/src/chunked_array/ops/is_in.rs +++ b/polars/polars-core/src/chunked_array/ops/is_in.rs @@ -81,8 +81,8 @@ where } _ => { // first make sure that the types are equal - let st = try_get_supertype(self.dtype(), other.dtype())?; if self.dtype() != other.dtype() { + let st = try_get_supertype(self.dtype(), other.dtype())?; let left = self.cast(&st)?; let right = other.cast(&st)?; return left.is_in(&right); @@ -328,6 +328,30 @@ impl IsIn for StructChunked { self.fields().len(), other.fields().len() ); + // first make sure that the types are equal + let self_dtypes: Vec<_> = self.fields().iter().map(|f| f.dtype()).collect(); + let other_dtypes: Vec<_> = other.fields().iter().map(|f| f.dtype()).collect(); + if self_dtypes != other_dtypes { + let self_names = self.fields().iter().map(|f| f.name()); + let other_names = other.fields().iter().map(|f| f.name()); + let supertypes = self_dtypes + .iter() + .zip(other_dtypes.iter()) + .map(|(dt1, dt2)| try_get_supertype(dt1, dt2)) + .collect::, _>>()?; + let self_supertype_fields = self_names + .zip(supertypes.iter()) + .map(|(name, st)| Field::new(name, st.clone())) + .collect(); + let self_super = self.cast(&DataType::Struct(self_supertype_fields))?; + let other_supertype_fields = other_names + .zip(supertypes.iter()) + .map(|(name, st)| Field::new(name, st.clone())) + .collect(); + let other_super = other.cast(&DataType::Struct(other_supertype_fields))?; + return self_super.is_in(&other_super); + } + let mut anyvalues = Vec::with_capacity(other.len() * other.fields().len()); // Safety: // the iterator is unsafe as the lifetime is tied to the iterator diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index 7e9965508e95..a8cb14515510 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -836,9 +836,10 @@ def test_struct_unique_df() -> None: def test_struct_is_in() -> None: + # The dtype casts below test that struct is_in upcasts dtypes. s1 = ( pl.DataFrame({"x": [4, 3, 4, 9], "y": [0, 4, 6, 2]}) - .select(pl.struct(["x", "y"])) + .select(pl.struct(schema={"x": pl.Int8, "y": pl.Float32})) .to_series() ) s2 = (