Persist intermediate data to avoid non-determinism caused by Spark lazy random evaluation (#1676)

* Generate random numbers independent of the data to be split

* Persist intermediate Spark data frame to avoid errors caused by lazy evaluation

* Typos: e.g -> e.g.

* Compact code

* Add intersection tests for spark_stratified_split

* Add comments

* Drop duplicate data in tests
simonzhaoms authored Mar 18, 2022
1 parent 17204b1 commit d43e2c1
Showing 2 changed files with 25 additions and 7 deletions.
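Background on the change: `F.rand` adds a column that Spark only computes when an action runs, so every action on the unpersisted DataFrame can re-evaluate the random values, and once the lineage involves shuffles or windows the value a given row receives can change between evaluations. Splits filtered from such a frame can then overlap or lose rows. A minimal sketch of the failure mode and of the fix applied here (illustrative only; with a simple deterministic lineage the mismatch may not actually reproduce):

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F
    from pyspark.storagelevel import StorageLevel

    spark = SparkSession.builder.master("local[2]").getOrCreate()

    # A shuffled lineage plus a lazily evaluated random column.
    df = spark.range(1000).repartition(8).withColumn("_random", F.rand(seed=42))

    # Each count() below is a separate action, so Spark may re-evaluate F.rand for
    # each one; the two "halves" are then not guaranteed to be complementary.
    n_left = df.filter(F.col("_random") <= 0.7).count()
    n_right = df.filter(F.col("_random") > 0.7).count()
    print(n_left + n_right)  # not guaranteed to equal 1000

    # Persisting materializes the random column once; both filters then read the
    # same cached values and the halves are complementary.
    df = spark.range(1000).repartition(8).withColumn("_random", F.rand(seed=42))
    df.persist(StorageLevel.MEMORY_AND_DISK_2).count()
    n_left = df.filter(F.col("_random") <= 0.7).count()
    n_right = df.filter(F.col("_random") > 0.7).count()
    print(n_left + n_right)  # 1000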
14 changes: 7 additions & 7 deletions recommenders/datasets/spark_splitters.py
@@ -5,6 +5,7 @@
 
 try:
     from pyspark.sql import functions as F, Window
+    from pyspark.storagelevel import StorageLevel
 except ImportError:
     pass  # skip this import if we are in pure python environment
 
@@ -112,8 +113,8 @@ def _do_stratification_spark(
     split_by = col_user if filter_by == "user" else col_item
     partition_by = split_by if is_partitioned else []
 
+    col_random = "_random"
     if is_random:
-        col_random = "_random"
         data = data.withColumn(col_random, F.rand(seed=seed))
         order_by = F.col(col_random)
     else:
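Hoisting `col_random` out of the `if` lets the later `.drop("_count", col_random)` run unconditionally: PySpark's `DataFrame.drop` silently ignores column names that are not in the schema, so dropping `_random` is a no-op when the split is not random. A quick illustration (toy DataFrame, assuming a live `spark` session):

    df = spark.createDataFrame([(1, 0.42)], ["id", "_random"])

    # drop() ignores names that are not in the schema, so this works whether or
    # not the _random column was ever added.
    print(df.drop("_count", "_random").columns)  # ['id']
    print(df.drop("_count").columns)             # ['id', '_random']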
@@ -125,11 +126,10 @@
     data = (
         data.withColumn("_count", F.count(split_by).over(window_count))
         .withColumn("_rank", F.row_number().over(window_spec) / F.col("_count"))
-        .drop("_count")
+        .drop("_count", col_random)
     )
-
-    if is_random:
-        data = data.drop(col_random)
+    # Persist to avoid duplicate rows in splits caused by lazy evaluation
+    data.persist(StorageLevel.MEMORY_AND_DISK_2).count()
 
     multi_split, ratio = process_split_ratio(ratio)
     ratio = ratio if multi_split else [ratio, 1 - ratio]
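`persist` on its own is lazy: it only records the requested storage level, and the immediate `count()` is the action that actually materializes and caches the ranked rows, so the per-ratio filters that follow all read one consistent snapshot instead of re-running the random ranking. `StorageLevel.MEMORY_AND_DISK_2` keeps partitions in memory, spills to disk when memory is short, and replicates each partition on two nodes. A minimal sketch of the same idiom, assuming `data` with a `_rank` column as built above (the threshold filtering illustrates the idea, not the exact code that follows in this file):

    from pyspark.sql import functions as F
    from pyspark.storagelevel import StorageLevel

    # Mark the DataFrame for caching, then force materialization with an action.
    data = data.persist(StorageLevel.MEMORY_AND_DISK_2)
    data.count()

    # Every split now filters the same cached _rank values, so the splits are
    # disjoint and together cover the whole DataFrame.
    train = data.filter(F.col("_rank") <= 0.75)
    test = data.filter(F.col("_rank") > 0.75)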
@@ -215,7 +215,7 @@ def spark_stratified_split(
             data into several portions corresponding to the split ratios. If a list is
             provided and the ratios are not summed to 1, they will be normalized.
             Earlier indexed splits will have earlier times
-            (e.g the latest time per user or item in split[0] <= the earliest time per user or item in split[1])
+            (e.g. the latest time per user or item in split[0] <= the earliest time per user or item in split[1])
         seed (int): Seed.
         min_rating (int): minimum number of ratings for user or item.
         filter_by (str): either "user" or "item", depending on which of the two is to filter
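For reference, a call mirroring the one in the unit test below; `ratings` is a placeholder for any Spark DataFrame with user, item, and rating columns:

    from recommenders.datasets.spark_splitters import spark_stratified_split

    # Two-way split: 75% of each user's ratings go to splits[0], the rest to splits[1].
    splits = spark_stratified_split(
        ratings, ratio=0.75, filter_by="user", min_rating=10, seed=42
    )
    train, test = splits[0], splits[1]

    # A list of ratios yields one split per ratio (normalized if they do not sum to 1).
    splits = spark_stratified_split(ratings, ratio=[0.7, 0.2, 0.1], seed=42)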
@@ -257,7 +257,7 @@ def spark_timestamp_split(
             data into several portions corresponding to the split ratios. If a list is
             provided and the ratios are not summed to 1, they will be normalized.
             Earlier indexed splits will have earlier times
-            (e.g the latest time in split[0] <= the earliest time in split[1])
+            (e.g. the latest time in split[0] <= the earliest time in split[1])
         col_user (str): column name of user IDs.
         col_item (str): column name of item IDs.
         col_timestamp (str): column name of timestamps. Float number represented in
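And an illustrative call to the chronological splitter; the positional `data` argument and the column name used here are assumptions based on the docstring above, with `ratings` again a placeholder DataFrame:

    from recommenders.datasets.spark_splitters import spark_timestamp_split

    # Earlier-indexed splits hold earlier timestamps: splits[0] is the oldest 75%,
    # splits[1] the most recent 25%.
    splits = spark_timestamp_split(ratings, ratio=0.75, col_timestamp="timestamp")
    train, test = splits[0], splits[1]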
18 changes: 18 additions & 0 deletions tests/unit/recommenders/datasets/test_spark_splitter.py
@@ -117,13 +117,22 @@ def test_chrono_splitter(spark_dataset):
 
 @pytest.mark.spark
 def test_stratified_splitter(spark_dataset):
+    spark_dataset = spark_dataset.dropDuplicates()
+
     splits = spark_stratified_split(
         spark_dataset, ratio=RATIOS[0], filter_by="user", min_rating=10
     )
 
     assert splits[0].count() / NUM_ROWS == pytest.approx(RATIOS[0], TOL)
     assert splits[1].count() / NUM_ROWS == pytest.approx(1 - RATIOS[0], TOL)
 
+    # Test that there is no intersection between the splits
+    assert splits[0].intersect(splits[1]).count() == 0
+    splits = spark_stratified_split(
+        spark_dataset.repartition(4), ratio=RATIOS[0], filter_by="user", min_rating=10
+    )
+    assert splits[0].intersect(splits[1]).count() == 0
+
     # Test that both splits contain the same users, because the split is stratified by user
     users_train = (
         splits[0].select(DEFAULT_USER_COL).distinct().rdd.map(lambda r: r[0]).collect()
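`DataFrame.intersect` returns the distinct rows present in both inputs, which is why the test first calls `dropDuplicates()`: with duplicate input rows, one copy could legitimately end up in each split and the intersection would be non-empty even with a correct splitter. The `repartition(4)` (and `repartition(9)` below) re-runs the split under a different physical layout, the kind of condition under which the lazy `F.rand` evaluation previously produced overlapping splits. A toy illustration of the `intersect` semantics the assertions rely on (assuming a live `spark` session):

    a = spark.createDataFrame([(1, "x"), (1, "x"), (2, "y")], ["id", "val"])
    b = spark.createDataFrame([(1, "x"), (3, "z")], ["id", "val"])

    # intersect() deduplicates and keeps only the rows found in both DataFrames.
    print(a.intersect(b).count())  # 1 -- the single distinct row (1, "x")

    # Hence, after dropDuplicates(), an empty intersection between two splits of
    # the same data means no row was assigned to both splits.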
@@ -140,6 +149,15 @@ def test_stratified_splitter(spark_dataset):
     assert splits[1].count() / NUM_ROWS == pytest.approx(RATIOS[1], TOL)
     assert splits[2].count() / NUM_ROWS == pytest.approx(RATIOS[2], TOL)
 
+    # Test that there is no intersection between the splits
+    assert splits[0].intersect(splits[1]).count() == 0
+    assert splits[0].intersect(splits[2]).count() == 0
+    assert splits[1].intersect(splits[2]).count() == 0
+    splits = spark_stratified_split(spark_dataset.repartition(9), ratio=RATIOS)
+    assert splits[0].intersect(splits[1]).count() == 0
+    assert splits[0].intersect(splits[2]).count() == 0
+    assert splits[1].intersect(splits[2]).count() == 0
+
 
 @pytest.mark.spark
 def test_timestamp_splitter(spark_dataset):
