SAR+ docstring patch #1648

Merged
2 commits merged on Feb 17, 2022
contrib/sarplus/README.md (10 changes: 7 additions & 3 deletions)
@@ -77,7 +77,7 @@ train_df = spark.createDataFrame(
 # spark dataframe with user/item tuples
 test_df = spark.createDataFrame(
     [(1, 1, 1), (3, 3, 1)],
-    ["user_id", "item_id', "rating"],
+    ["user_id", "item_id", "rating"],
 )
 
 # To use C++ based fast prediction, a local cache directory needs to be
@@ -91,6 +91,7 @@ test_df = spark.createDataFrame(
 #     col_item="item_id",
 #     col_rating="rating",
 #     col_timestamp="timestamp",
+#     similarity_type="jaccard",
 #     cache_path="cache",
 # )
 # ```
@@ -104,6 +105,7 @@ test_df = spark.createDataFrame(
 #     col_item="item_id",
 #     col_rating="rating",
 #     col_timestamp="timestamp",
+#     similarity_type="jaccard",
 #     cache_path="dbfs:/mnt/sarpluscache/cache",
 # )
 # ```
@@ -118,6 +120,7 @@ test_df = spark.createDataFrame(
 #     col_item="item_id",
 #     col_rating="rating",
 #     col_timestamp="timestamp",
+#     similarity_type="jaccard",
 #     cache_path=f"synfs:/{job_id}/mnt/sarpluscache/cache",
 # )
 # ```
@@ -134,16 +137,17 @@ model = SARPlus(
     col_item="item_id",
     col_rating="rating",
     col_timestamp="timestamp",
+    similarity_type="jaccard",
 )
-model.fit(train_df, similarity_type="jaccard")
+model.fit(train_df)
 
 # To use C++ based fast prediction, the `use_cache` parameter of
 # `SARPlus.recommend_k_items()` also needs to be set to `True`.
 #
 # ```
 # model.recommend_k_items(test_df, top_k=3, use_cache=True).show()
 # ```
-model.recommend_k_items(test_df, top_k=3).show()
+model.recommend_k_items(test_df, top_k=3, remove_seen=False).show()
 ```
 
 ### Jupyter Notebook
contrib/sarplus/python/pysarplus/SARPlus.py (14 changes: 9 additions & 5 deletions)
@@ -446,11 +446,15 @@ def recommend_k_items(
         """Recommend top K items for all users which are in the test set.
 
         Args:
-            test: test Spark dataframe
-            top_k: top n items to return
-            remove_seen: remove items test users have already seen in the past from the recommended set
-            use_cache: use specified local directory stored in self.cache_path as cache for C++ based fast predictions
-            n_user_prediction_partitions: prediction partitions
+            test (pyspark.sql.DataFrame): test Spark dataframe.
+            top_k (int): top n items to return.
+            remove_seen (bool): remove items test users have already seen in the past from the recommended set.
+            use_cache (bool): use specified local directory stored in `self.cache_path` as cache for C++ based fast
+                predictions.
+            n_user_prediction_partitions (int): prediction partitions.
+
+        Returns:
+            pyspark.sql.DataFrame: Spark dataframe with recommended items
         """
         if not use_cache:
             return self._recommend_k_items_slow(test, top_k, remove_seen)
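For context, the docstring change above follows the Google docstring style: each argument name is followed by its type in parentheses, and a `Returns:` section names the return type. The sketch below is a hypothetical standalone stub (not the real `pysarplus` implementation, and the default values are illustrative) showing that the updated docstring is plain text retrievable at runtime, which is what documentation generators like Sphinx's napoleon extension parse:

```python
import inspect


def recommend_k_items(test, top_k=10, remove_seen=True, use_cache=False,
                      n_user_prediction_partitions=200):
    """Recommend top K items for all users which are in the test set.

    Args:
        test (pyspark.sql.DataFrame): test Spark dataframe.
        top_k (int): top n items to return.
        remove_seen (bool): remove items test users have already seen in the
            past from the recommended set.
        use_cache (bool): use specified local directory stored in
            `self.cache_path` as cache for C++ based fast predictions.
        n_user_prediction_partitions (int): prediction partitions.

    Returns:
        pyspark.sql.DataFrame: Spark dataframe with recommended items
    """
    # Stub body only; the real method dispatches to a slow or cached path.
    raise NotImplementedError


# The type annotations added by the patch are now part of the retrievable doc.
doc = inspect.getdoc(recommend_k_items)
print("pyspark.sql.DataFrame" in doc)
```

The benefit of the patch is purely for readers and doc tooling: argument types that previously had to be inferred from the code are now stated in the rendered API reference.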