Skip to content

Commit

Permalink
fix(rust, python): don't run hstack checks when using cached names (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 4, 2023
1 parent df5ae40 commit bc1c6ec
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 46 deletions.
14 changes: 9 additions & 5 deletions polars/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -736,10 +736,14 @@ impl DataFrame {
self.columns.is_empty()
}

pub(crate) fn hstack_mut_no_checks(&mut self, columns: &[Series]) -> &mut Self {
for col in columns {
self.columns.push(col.clone());
}
/// Add columns horizontally.
///
/// # Safety
/// The caller must ensure:
/// - the length of all [`Series`] is equal to the height of this [`DataFrame`]
/// - the columns names are unique
pub unsafe fn hstack_mut_unchecked(&mut self, columns: &[Series]) -> &mut Self {
self.columns.extend_from_slice(columns);
self
}

Expand Down Expand Up @@ -774,7 +778,7 @@ impl DataFrame {
);
}
drop(names);
Ok(self.hstack_mut_no_checks(columns))
Ok(unsafe { self.hstack_mut_unchecked(columns) })
}

/// Add multiple `Series` to a `DataFrame`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,36 @@ impl GenericJoinProbe {
unsafe { Ok(self.current_rows.borrow_array()) }
}

fn finish_join(
&mut self,
mut left_df: DataFrame,
right_df: DataFrame,
) -> PolarsResult<DataFrame> {
Ok(match &self.output_names {
None => {
let out = _finish_join(left_df, right_df, Some(self.suffix.as_ref()))?;
self.output_names = Some(out.get_column_names_owned());
out
}
Some(names) => unsafe {
// safety:
// if we have duplicate names, we overwrite
// them in the next snippet
left_df
.get_columns_mut()
.extend_from_slice(right_df.get_columns());
left_df
.get_columns_mut()
.iter_mut()
.zip(names)
.for_each(|(s, name)| {
s.rename(name);
});
left_df
},
})
}

fn execute_left(
&mut self,
context: &PExecutionContext,
Expand Down Expand Up @@ -196,30 +226,11 @@ impl GenericJoinProbe {
}
let right_df = self.df_a.as_ref();

let mut left_df = unsafe { chunk.data._take_unchecked_slice(&self.join_tuples_b, false) };
let left_df = unsafe { chunk.data._take_unchecked_slice(&self.join_tuples_b, false) };
let right_df =
unsafe { right_df._take_opt_chunked_unchecked_seq(&self.join_tuples_a_left_join) };

let out = match &self.output_names {
None => {
let out = _finish_join(left_df, right_df, Some(self.suffix.as_ref()))?;
self.output_names = Some(out.get_column_names_owned());
out
}
Some(names) => unsafe {
left_df
.get_columns_mut()
.extend_from_slice(right_df.get_columns());
left_df
.get_columns_mut()
.iter_mut()
.zip(names)
.for_each(|(s, name)| {
s.rename(name);
});
left_df
},
};
let out = self.finish_join(left_df, right_df)?;

// clear memory
self.join_columns.clear();
Expand Down Expand Up @@ -279,30 +290,12 @@ impl GenericJoinProbe {
df._take_unchecked_slice(&self.join_tuples_b, false)
};

let (mut a, b) = if self.swapped_or_left {
let (a, b) = if self.swapped_or_left {
(right_df, left_df)
} else {
(left_df, right_df)
};
let out = match &self.output_names {
None => {
let out = _finish_join(a, b, Some(self.suffix.as_ref()))?;
self.output_names = Some(out.get_column_names_owned());
out
}
Some(names) => {
a.hstack_mut(b.get_columns()).unwrap();
unsafe {
a.get_columns_mut()
.iter_mut()
.zip(names)
.for_each(|(s, name)| {
s.rename(name);
});
}
a
}
};
let out = self.finish_join(a, b)?;

// clear memory
self.join_columns.clear();
Expand Down

0 comments on commit bc1c6ec

Please sign in to comment.