Skip to content

Commit

Permalink
fix(rust, python): increment seed between samples (#9694)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 4, 2023
1 parent 227c850 commit d40e403
Show file tree
Hide file tree
Showing 13 changed files with 293 additions and 140 deletions.
29 changes: 29 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ mod list;
mod log;
mod nan;
mod pow;
#[cfg(feature = "random")]
mod random;
#[cfg(feature = "arange")]
mod range;
#[cfg(all(feature = "rolling_window", feature = "moment"))]
Expand All @@ -51,6 +53,8 @@ mod trigonometry;
mod unique;

use std::fmt::{Display, Formatter};
#[cfg(feature = "random")]
use std::sync::atomic::AtomicU64;

#[cfg(feature = "dtype-array")]
pub(super) use array::ArrayFunction;
Expand All @@ -59,6 +63,8 @@ pub(crate) use correlation::CorrelationMethod;
pub(crate) use fused::FusedOperator;
pub(super) use list::ListFunction;
use polars_core::prelude::*;
#[cfg(feature = "random")]
pub(crate) use random::RandomMethod;
use schema::FieldsMapper;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -193,6 +199,14 @@ pub enum FunctionExpr {
ddof: u8,
},
ToPhysical,
#[cfg(feature = "random")]
Random {
method: random::RandomMethod,
#[cfg_attr(feature = "serde", serde(skip))]
atomic_seed: Option<SpecialEq<Arc<AtomicU64>>>,
seed: Option<u64>,
fixed_seed: bool,
},
}

impl Display for FunctionExpr {
Expand Down Expand Up @@ -288,6 +302,8 @@ impl Display for FunctionExpr {
ConcatExpr(_) => "concat_expr",
Correlation { method, .. } => return Display::fmt(method, f),
ToPhysical => "to_physical",
#[cfg(feature = "random")]
Random { method, .. } => method.into(),
};
write!(f, "{s}")
}
Expand Down Expand Up @@ -515,6 +531,19 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
ConcatExpr(rechunk) => map_as_slice!(concat::concat_expr, rechunk),
Correlation { method, ddof } => map_as_slice!(correlation::corr, ddof, method),
ToPhysical => map!(dispatch::to_physical),
#[cfg(feature = "random")]
Random {
method,
seed,
atomic_seed,
fixed_seed,
} => map!(
random::random,
method,
atomic_seed.as_deref(),
seed,
fixed_seed
),
}
}
}
Expand Down
55 changes: 55 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/random.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use std::sync::atomic::Ordering;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use strum_macros::IntoStaticStr;

use super::*;

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Copy, Clone, PartialEq, Debug, IntoStaticStr)]
#[strum(serialize_all = "snake_case")]
pub enum RandomMethod {
Shuffle,
SampleN {
n: usize,
with_replacement: bool,
shuffle: bool,
},
SampleFrac {
frac: f64,
with_replacement: bool,
shuffle: bool,
},
}

pub(super) fn random(
s: &Series,
method: RandomMethod,
atomic_seed: Option<&Arc<AtomicU64>>,
seed: Option<u64>,
fixed_seed: bool,
) -> PolarsResult<Series> {
let seed = if fixed_seed {
seed
} else {
// ensure seeds differ between groupby groups
// otherwise all groups would be sampled the same
atomic_seed
.as_ref()
.map(|atomic| atomic.fetch_add(1, Ordering::Relaxed))
};
match method {
RandomMethod::Shuffle => Ok(s.shuffle(seed)),
RandomMethod::SampleFrac {
frac,
with_replacement,
shuffle,
} => s.sample_frac(frac, with_replacement, shuffle, seed),
RandomMethod::SampleN {
n,
with_replacement,
shuffle,
} => s.sample_n(n, with_replacement, shuffle, seed),
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ impl FunctionExpr {
ConcatExpr(_) => mapper.map_to_supertype(),
Correlation { .. } => mapper.map_to_float_dtype(),
ToPhysical => mapper.to_physical_type(),
#[cfg(feature = "random")]
Random { .. } => mapper.with_same_dtype(),
}
}
}
Expand Down
41 changes: 2 additions & 39 deletions polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub(crate) mod names;
mod options;
#[cfg(all(feature = "python", feature = "serde"))]
pub mod python_udf;
#[cfg(feature = "random")]
mod random;
mod selector;
#[cfg(feature = "strings")]
pub mod string;
Expand Down Expand Up @@ -1559,45 +1561,6 @@ impl Expr {
.with_fmt("reshape")
}

#[cfg(feature = "random")]
pub fn shuffle(self, seed: Option<u64>) -> Self {
self.apply(move |s| Ok(Some(s.shuffle(seed))), GetOutput::same_type())
.with_fmt("shuffle")
}

#[cfg(feature = "random")]
pub fn sample_n(
self,
n: usize,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> Self {
self.apply(
move |s| s.sample_n(n, with_replacement, shuffle, seed).map(Some),
GetOutput::same_type(),
)
.with_fmt("sample_n")
}

#[cfg(feature = "random")]
pub fn sample_frac(
self,
frac: f64,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> Self {
self.apply(
move |s| {
s.sample_frac(frac, with_replacement, shuffle, seed)
.map(Some)
},
GetOutput::same_type(),
)
.with_fmt("sample_frac")
}

#[cfg(feature = "ewma")]
pub fn ewm_mean(self, options: EWMOptions) -> Self {
use DataType::*;
Expand Down
58 changes: 58 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/random.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use std::sync::atomic::AtomicU64;

use super::*;

fn get_atomic_seed(seed: Option<u64>) -> Option<SpecialEq<Arc<AtomicU64>>> {
seed.map(|v| SpecialEq::new(Arc::new(AtomicU64::new(v))))
}

impl Expr {
pub fn shuffle(self, seed: Option<u64>, fixed_seed: bool) -> Self {
self.apply_private(FunctionExpr::Random {
method: RandomMethod::Shuffle,
atomic_seed: get_atomic_seed(seed),
seed,
fixed_seed,
})
}

pub fn sample_n(
self,
n: usize,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
fixed_seed: bool,
) -> Self {
self.apply_private(FunctionExpr::Random {
method: RandomMethod::SampleN {
n,
with_replacement,
shuffle,
},
atomic_seed: get_atomic_seed(seed),
seed,
fixed_seed,
})
}

pub fn sample_frac(
self,
frac: f64,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
fixed_seed: bool,
) -> Self {
self.apply_private(FunctionExpr::Random {
method: RandomMethod::SampleFrac {
frac,
with_replacement,
shuffle,
},
atomic_seed: get_atomic_seed(seed),
seed,
fixed_seed,
})
}
}
20 changes: 16 additions & 4 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7497,7 +7497,7 @@ def reshape(self, dimensions: tuple[int, ...]) -> Self:
"""
return self._from_pyexpr(self._pyexpr.reshape(dimensions))

def shuffle(self, seed: int | None = None) -> Self:
def shuffle(self, seed: int | None = None, fixed_seed: bool = False) -> Self:
"""
Shuffle the contents of this expression.
Expand All @@ -7506,6 +7506,10 @@ def shuffle(self, seed: int | None = None) -> Self:
seed
Seed for the random number generator. If set to None (default), a random
seed is generated using the ``random`` module.
fixed_seed
If True, The seed will not be incremented between draws.
This can make output predictable because draw ordering can
change due to threads being scheduled in a different order.
Examples
--------
Expand All @@ -7523,9 +7527,10 @@ def shuffle(self, seed: int | None = None) -> Self:
└─────┘
"""
# we seed from python so that we respect ``random.seed``
if seed is None:
seed = random.randint(0, 10000)
return self._from_pyexpr(self._pyexpr.shuffle(seed))
return self._from_pyexpr(self._pyexpr.shuffle(seed, fixed_seed))

@deprecated_alias(frac="fraction")
def sample(
Expand All @@ -7536,6 +7541,7 @@ def sample(
with_replacement: bool = False,
shuffle: bool = False,
seed: int | None = None,
fixed_seed: bool = False,
) -> Self:
"""
Sample from this expression.
Expand All @@ -7554,6 +7560,10 @@ def sample(
seed
Seed for the random number generator. If set to None (default), a random
seed is generated using the ``random`` module.
fixed_seed
If True, The seed will not be incremented between draws.
This can make output predictable because draw ordering can
change due to threads being scheduled in a different order.
Examples
--------
Expand All @@ -7579,13 +7589,15 @@ def sample(

if fraction is not None:
return self._from_pyexpr(
self._pyexpr.sample_frac(fraction, with_replacement, shuffle, seed)
self._pyexpr.sample_frac(
fraction, with_replacement, shuffle, seed, fixed_seed
)
)

if n is None:
n = 1
return self._from_pyexpr(
self._pyexpr.sample_n(n, with_replacement, shuffle, seed)
self._pyexpr.sample_n(n, with_replacement, shuffle, seed, fixed_seed)
)

def ewm_mean(
Expand Down
9 changes: 9 additions & 0 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5609,6 +5609,15 @@ def shuffle(self, seed: int | None = None) -> Series:
]
"""
return (
self.to_frame()
.select(
F.col(self.name).shuffle(
seed=seed,
)
)
.to_series()
)

def ewm_mean(
self,
Expand Down
21 changes: 16 additions & 5 deletions py-polars/src/expr/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -972,27 +972,38 @@ impl PyExpr {
self.inner.clone().to_physical().into()
}

fn shuffle(&self, seed: Option<u64>) -> Self {
self.inner.clone().shuffle(seed).into()
#[pyo3(signature = (seed, fixed_seed))]
fn shuffle(&self, seed: Option<u64>, fixed_seed: bool) -> Self {
self.inner.clone().shuffle(seed, fixed_seed).into()
}

fn sample_n(&self, n: usize, with_replacement: bool, shuffle: bool, seed: Option<u64>) -> Self {
#[pyo3(signature = (n, with_replacement, shuffle, seed, fixed_seed))]
fn sample_n(
&self,
n: usize,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
fixed_seed: bool,
) -> Self {
self.inner
.clone()
.sample_n(n, with_replacement, shuffle, seed)
.sample_n(n, with_replacement, shuffle, seed, fixed_seed)
.into()
}

#[pyo3(signature = (frac, with_replacement, shuffle, seed, fixed_seed))]
fn sample_frac(
&self,
frac: f64,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
fixed_seed: bool,
) -> Self {
self.inner
.clone()
.sample_frac(frac, with_replacement, shuffle, seed)
.sample_frac(frac, with_replacement, shuffle, seed, fixed_seed)
.into()
}

Expand Down
Loading

0 comments on commit d40e403

Please sign in to comment.