speed up and improve test_bin_df_cols
janosh committed Oct 9, 2023
1 parent e7136e1 commit 43fcb6d
Showing 1 changed file with 30 additions and 16 deletions.
46 changes: 30 additions & 16 deletions tests/test_utils.py
@@ -186,42 +186,56 @@ def test_save_fig(
 
 
 @pytest.mark.parametrize(
-    "bin_by_cols, group_by_cols, n_bins, expected_n_bins",
+    "bin_by_cols, group_by_cols, n_bins, expected_n_bins, "
+    "verbose, kde_col, expected_n_rows",
     [
-        (["col1"], [], 2, [2]),
-        (["col1", "col2"], [], 2, [2, 2]),
-        (["col1", "col2"], [], [2, 3], [2, 3]),
-        (["col1"], ["col2"], 2, [2]),
+        (["A"], [], 2, [2], True, "", 2),
+        (["A", "B"], [], 2, [2, 2], True, "kde", 4),
+        (["A", "B"], [], [2, 3], [2, 3], False, "kde", 5),
+        (["A"], ["B"], 2, [2], False, "", 30),
     ],
 )
-@pytest.mark.parametrize("verbose", [True, False])
 def test_bin_df_cols(
     bin_by_cols: list[str],
     group_by_cols: list[str],
     n_bins: int | list[int],
     expected_n_bins: list[int],
     verbose: bool,
+    kde_col: str,
+    expected_n_rows: int,
 ) -> None:
-    data = {"col1": [1, 2, 3, 4], "col2": [2, 3, 4, 5], "col3": [3, 4, 5, 6]}
-    df = pd.DataFrame(data)
+    df = pd._testing.makeDataFrame()  # random data
+    idx_col = "index"
+    df.index.name = idx_col
+    bin_counts_col = "bin_counts"
+    df_binned = bin_df_cols(
+        df,
+        bin_by_cols,
+        group_by_cols,
+        n_bins,
+        verbose=verbose,
+        bin_counts_col=bin_counts_col,
+        kde_col=kde_col,
+    )
 
-    df_binned = bin_df_cols(df, bin_by_cols, group_by_cols, n_bins, verbose=verbose)
+    # ensure binned DataFrame has a minimum set of expected columns
+    expected_cols = {bin_counts_col, *df, *(f"{col}_bins" for col in bin_by_cols)}
+    assert {*df_binned} >= expected_cols
+    assert len(df_binned) == expected_n_rows
 
+    # validate the number of unique bins for each binned column
     df_grouped = (
-        df.reset_index()
+        df.reset_index(names=idx_col)
         .groupby([*[f"{c}_bins" for c in bin_by_cols], *group_by_cols])
         .first()
         .dropna()
     )
 
-    for col, bins in zip(bin_by_cols, expected_n_bins):
+    for col, expected in zip(bin_by_cols, expected_n_bins):
         binned_col = f"{col}_bins"
         assert binned_col in df_grouped.index.names
 
-        unique_bins = df_grouped.index.get_level_values(binned_col).nunique()
-        assert unique_bins <= bins
-
-    assert not df_binned.empty
+        uniq_bins = df_grouped.index.get_level_values(binned_col).nunique()
+        assert uniq_bins == expected
 
 
 def test_bin_df_cols_raises_value_error() -> None:
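Note on the new test setup: pd._testing.makeDataFrame() returns a 30-row DataFrame with four random float columns named "A" through "D", which is why the parametrized cases now bin columns "A" and "B" and why the group-by case expects 30 rows. Below is a minimal usage sketch (not part of the commit) of the call pattern the updated test exercises; it assumes bin_df_cols is importable from pymatviz.utils (as this test module suggests) and that parameter semantics match the names visible in the diff.

import numpy as np
import pandas as pd

from pymatviz.utils import bin_df_cols  # assumed import path

rng = np.random.default_rng(seed=0)
df_demo = pd.DataFrame(rng.normal(size=(100, 2)), columns=["A", "B"])
df_demo.index.name = "index"

df_binned = bin_df_cols(
    df_demo,  # DataFrame to bin
    ["A", "B"],  # bin_by_cols: columns whose values get discretized
    [],  # group_by_cols: optional extra columns to group by
    [5, 5],  # n_bins: one bin count per binned column
    verbose=False,
    bin_counts_col="bin_counts",  # column holding the number of rows per bin
    kde_col="kde",  # name for an extra density-estimate column (inferred from the name)
)

# the binned frame should gain A_bins, B_bins and bin_counts columns
print(sorted(set(df_binned) - set(df_demo)))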
