Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust, python): increment seed between samples #9694

Merged
merged 3 commits into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ mod list;
mod log;
mod nan;
mod pow;
#[cfg(feature = "random")]
mod random;
#[cfg(feature = "arange")]
mod range;
#[cfg(all(feature = "rolling_window", feature = "moment"))]
Expand All @@ -51,6 +53,8 @@ mod trigonometry;
mod unique;

use std::fmt::{Display, Formatter};
#[cfg(feature = "random")]
use std::sync::atomic::AtomicU64;

#[cfg(feature = "dtype-array")]
pub(super) use array::ArrayFunction;
Expand All @@ -59,6 +63,8 @@ pub(crate) use correlation::CorrelationMethod;
pub(crate) use fused::FusedOperator;
pub(super) use list::ListFunction;
use polars_core::prelude::*;
#[cfg(feature = "random")]
pub(crate) use random::RandomMethod;
use schema::FieldsMapper;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -193,6 +199,14 @@ pub enum FunctionExpr {
ddof: u8,
},
ToPhysical,
#[cfg(feature = "random")]
Random {
method: random::RandomMethod,
#[cfg_attr(feature = "serde", serde(skip))]
atomic_seed: Option<SpecialEq<Arc<AtomicU64>>>,
seed: Option<u64>,
fixed_seed: bool,
},
}

impl Display for FunctionExpr {
Expand Down Expand Up @@ -288,6 +302,8 @@ impl Display for FunctionExpr {
ConcatExpr(_) => "concat_expr",
Correlation { method, .. } => return Display::fmt(method, f),
ToPhysical => "to_physical",
#[cfg(feature = "random")]
Random { method, .. } => method.into(),
};
write!(f, "{s}")
}
Expand Down Expand Up @@ -515,6 +531,19 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
ConcatExpr(rechunk) => map_as_slice!(concat::concat_expr, rechunk),
Correlation { method, ddof } => map_as_slice!(correlation::corr, ddof, method),
ToPhysical => map!(dispatch::to_physical),
#[cfg(feature = "random")]
Random {
method,
seed,
atomic_seed,
fixed_seed,
} => map!(
random::random,
method,
atomic_seed.as_deref(),
seed,
fixed_seed
),
}
}
}
Expand Down
55 changes: 55 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/random.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use std::sync::atomic::Ordering;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use strum_macros::IntoStaticStr;

use super::*;

/// Selects which random operation the `random` function in this module
/// applies to a series.
///
/// The `IntoStaticStr` derive together with `serialize_all = "snake_case"`
/// lets a variant be converted into a static string name (used when the
/// enclosing function expression is displayed).
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Copy, Clone, PartialEq, Debug, IntoStaticStr)]
#[strum(serialize_all = "snake_case")]
pub enum RandomMethod {
/// Dispatches to `Series::shuffle`.
Shuffle,
/// Dispatches to `Series::sample_n`; all fields are forwarded unchanged.
SampleN {
n: usize,
with_replacement: bool,
shuffle: bool,
},
/// Dispatches to `Series::sample_frac`; all fields are forwarded unchanged.
SampleFrac {
frac: f64,
with_replacement: bool,
shuffle: bool,
},
}

/// Applies the random operation selected by `method` to `s`.
///
/// # Seed selection
///
/// * `fixed_seed == true`: `seed` is used unchanged on every call.
/// * `fixed_seed == false`: the shared `atomic_seed` counter is read and
///   incremented by one, so successive calls (e.g. one per groupby group)
///   each draw with a different seed.
/// * `fixed_seed == false` but no counter is available: fall back to `seed`.
///   The counter is not serialized (`serde(skip)` on the field), so a
///   deserialized expression arrives here with `atomic_seed == None` while
///   `seed` is still set. Falling back keeps the user-supplied seed in effect
///   instead of silently sampling unseeded.
///
/// # Errors
///
/// Returns whatever error `Series::sample_frac` / `Series::sample_n` return.
pub(super) fn random(
    s: &Series,
    method: RandomMethod,
    atomic_seed: Option<&Arc<AtomicU64>>,
    seed: Option<u64>,
    fixed_seed: bool,
) -> PolarsResult<Series> {
    let seed = if fixed_seed {
        seed
    } else {
        // Ensure seeds differ between groupby groups, otherwise all groups
        // would be sampled the same. `fetch_add` hands every caller the
        // current value and bumps the counter atomically.
        atomic_seed
            .map(|counter| counter.fetch_add(1, Ordering::Relaxed))
            // No counter (e.g. dropped by serde): keep honoring the plain seed.
            .or(seed)
    };
    match method {
        RandomMethod::Shuffle => Ok(s.shuffle(seed)),
        RandomMethod::SampleFrac {
            frac,
            with_replacement,
            shuffle,
        } => s.sample_frac(frac, with_replacement, shuffle, seed),
        RandomMethod::SampleN {
            n,
            with_replacement,
            shuffle,
        } => s.sample_n(n, with_replacement, shuffle, seed),
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ impl FunctionExpr {
ConcatExpr(_) => mapper.map_to_supertype(),
Correlation { .. } => mapper.map_to_float_dtype(),
ToPhysical => mapper.to_physical_type(),
#[cfg(feature = "random")]
Random { .. } => mapper.with_same_dtype(),
}
}
}
Expand Down
41 changes: 2 additions & 39 deletions polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub(crate) mod names;
mod options;
#[cfg(all(feature = "python", feature = "serde"))]
pub mod python_udf;
#[cfg(feature = "random")]
mod random;
mod selector;
#[cfg(feature = "strings")]
pub mod string;
Expand Down Expand Up @@ -1559,45 +1561,6 @@ impl Expr {
.with_fmt("reshape")
}

/// Shuffles the values of this expression, passing `seed` to `Series::shuffle`.
#[cfg(feature = "random")]
pub fn shuffle(self, seed: Option<u64>) -> Self {
    let shuffled = self.apply(
        move |series| {
            let out = series.shuffle(seed);
            Ok(Some(out))
        },
        GetOutput::same_type(),
    );
    shuffled.with_fmt("shuffle")
}

/// Samples `n` values from this expression via `Series::sample_n`; every
/// argument is forwarded unchanged and the output keeps the input dtype.
#[cfg(feature = "random")]
pub fn sample_n(
    self,
    n: usize,
    with_replacement: bool,
    shuffle: bool,
    seed: Option<u64>,
) -> Self {
    let sampled = self.apply(
        move |series| {
            let out = series.sample_n(n, with_replacement, shuffle, seed)?;
            Ok(Some(out))
        },
        GetOutput::same_type(),
    );
    sampled.with_fmt("sample_n")
}

/// Samples a fraction `frac` of this expression via `Series::sample_frac`;
/// every argument is forwarded unchanged and the output keeps the input dtype.
#[cfg(feature = "random")]
pub fn sample_frac(
    self,
    frac: f64,
    with_replacement: bool,
    shuffle: bool,
    seed: Option<u64>,
) -> Self {
    let sampled = self.apply(
        move |series| {
            let out = series.sample_frac(frac, with_replacement, shuffle, seed)?;
            Ok(Some(out))
        },
        GetOutput::same_type(),
    );
    sampled.with_fmt("sample_frac")
}

#[cfg(feature = "ewma")]
pub fn ewm_mean(self, options: EWMOptions) -> Self {
use DataType::*;
Expand Down
58 changes: 58 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/random.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use std::sync::atomic::AtomicU64;

use super::*;

/// Wraps `seed`, when present, in a shared atomic counter starting at that
/// value; returns `None` when no seed was given.
fn get_atomic_seed(seed: Option<u64>) -> Option<SpecialEq<Arc<AtomicU64>>> {
    let start = seed?;
    let counter = Arc::new(AtomicU64::new(start));
    Some(SpecialEq::new(counter))
}

impl Expr {
pub fn shuffle(self, seed: Option<u64>, fixed_seed: bool) -> Self {
self.apply_private(FunctionExpr::Random {
method: RandomMethod::Shuffle,
atomic_seed: get_atomic_seed(seed),
seed,
fixed_seed,
})
}

pub fn sample_n(
self,
n: usize,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
fixed_seed: bool,
) -> Self {
self.apply_private(FunctionExpr::Random {
method: RandomMethod::SampleN {
n,
with_replacement,
shuffle,
},
atomic_seed: get_atomic_seed(seed),
seed,
fixed_seed,
})
}

pub fn sample_frac(
self,
frac: f64,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
fixed_seed: bool,
) -> Self {
self.apply_private(FunctionExpr::Random {
method: RandomMethod::SampleFrac {
frac,
with_replacement,
shuffle,
},
atomic_seed: get_atomic_seed(seed),
seed,
fixed_seed,
})
}
}
20 changes: 16 additions & 4 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7497,7 +7497,7 @@ def reshape(self, dimensions: tuple[int, ...]) -> Self:
"""
return self._from_pyexpr(self._pyexpr.reshape(dimensions))

def shuffle(self, seed: int | None = None) -> Self:
def shuffle(self, seed: int | None = None, fixed_seed: bool = False) -> Self:
"""
Shuffle the contents of this expression.

Expand All @@ -7506,6 +7506,10 @@ def shuffle(self, seed: int | None = None) -> Self:
seed
Seed for the random number generator. If set to None (default), a random
seed is generated using the ``random`` module.
fixed_seed
If True, the seed will not be incremented between draws.
This makes the output predictable, because otherwise the draw ordering
can change due to threads being scheduled in a different order.

Examples
--------
Expand All @@ -7523,9 +7527,10 @@ def shuffle(self, seed: int | None = None) -> Self:
└─────┘

"""
# we seed from python so that we respect ``random.seed``
if seed is None:
seed = random.randint(0, 10000)
return self._from_pyexpr(self._pyexpr.shuffle(seed))
return self._from_pyexpr(self._pyexpr.shuffle(seed, fixed_seed))

@deprecated_alias(frac="fraction")
def sample(
Expand All @@ -7536,6 +7541,7 @@ def sample(
with_replacement: bool = False,
shuffle: bool = False,
seed: int | None = None,
fixed_seed: bool = False,
) -> Self:
"""
Sample from this expression.
Expand All @@ -7554,6 +7560,10 @@ def sample(
seed
Seed for the random number generator. If set to None (default), a random
seed is generated using the ``random`` module.
fixed_seed
If True, the seed will not be incremented between draws.
This makes the output predictable, because otherwise the draw ordering
can change due to threads being scheduled in a different order.

Examples
--------
Expand All @@ -7579,13 +7589,15 @@ def sample(

if fraction is not None:
return self._from_pyexpr(
self._pyexpr.sample_frac(fraction, with_replacement, shuffle, seed)
self._pyexpr.sample_frac(
fraction, with_replacement, shuffle, seed, fixed_seed
)
)

if n is None:
n = 1
return self._from_pyexpr(
self._pyexpr.sample_n(n, with_replacement, shuffle, seed)
self._pyexpr.sample_n(n, with_replacement, shuffle, seed, fixed_seed)
)

def ewm_mean(
Expand Down
9 changes: 9 additions & 0 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5603,6 +5603,15 @@ def shuffle(self, seed: int | None = None) -> Series:
]

"""
return (
self.to_frame()
.select(
F.col(self.name).shuffle(
seed=seed,
)
)
.to_series()
)

def ewm_mean(
self,
Expand Down
21 changes: 16 additions & 5 deletions py-polars/src/expr/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -972,27 +972,38 @@ impl PyExpr {
self.inner.clone().to_physical().into()
}

fn shuffle(&self, seed: Option<u64>) -> Self {
self.inner.clone().shuffle(seed).into()
#[pyo3(signature = (seed, fixed_seed))]
fn shuffle(&self, seed: Option<u64>, fixed_seed: bool) -> Self {
self.inner.clone().shuffle(seed, fixed_seed).into()
}

fn sample_n(&self, n: usize, with_replacement: bool, shuffle: bool, seed: Option<u64>) -> Self {
#[pyo3(signature = (n, with_replacement, shuffle, seed, fixed_seed))]
fn sample_n(
&self,
n: usize,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
fixed_seed: bool,
) -> Self {
self.inner
.clone()
.sample_n(n, with_replacement, shuffle, seed)
.sample_n(n, with_replacement, shuffle, seed, fixed_seed)
.into()
}

#[pyo3(signature = (frac, with_replacement, shuffle, seed, fixed_seed))]
fn sample_frac(
&self,
frac: f64,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
fixed_seed: bool,
) -> Self {
self.inner
.clone()
.sample_frac(frac, with_replacement, shuffle, seed)
.sample_frac(frac, with_replacement, shuffle, seed, fixed_seed)
.into()
}

Expand Down
Loading