Partial fit #29

Open · wants to merge 22 commits into main

Commits (22)
- bd1a4ec · Add ANN interface and implementations (netang, Mar 16, 2023)
- 8eeeff8 · Add dependencies (netang, Mar 16, 2023)
- 6051194 · Add `PartialFitMixin`. Implement partial fit for cluster and knn models (netang, Mar 16, 2023)
- 99f71dc · Implement partial fit for association rules, pop rec, random rec, ucb… (netang, Mar 17, 2023)
- 34b4f0f · Implement partial fit for wilson (netang, Mar 17, 2023)
- 6b22b20 · Add `_predict` to `ItemKNN` (netang, Mar 17, 2023)
- 9089d13 · Merge branch 'sb-main-ann' into sb-main-fit-partial (netang, Mar 17, 2023)
- c5f648c · Merge remote-tracking branch 'sb-repo/main' into sb-main-fit-partial (netang, Aug 10, 2023)
- d6db553 · Fix pylint warns. Fix test of refit UCB model. Remove old files. Add … (netang, Aug 13, 2023)
- d4749fd · Fix pylint warns. Temporarily disable some pylint warns. Fix index re… (netang, Aug 24, 2023)
- 71983d4 · Add `fit_partial` method for ThompsonSampling model (netang, Aug 24, 2023)
- cfe0f42 · Fix test of UCB model. Fix text of RandomRec model. (netang, Aug 24, 2023)
- bd3b570 · Comment out `_predict` method for `AssociationRulesItemRec` (netang, Aug 24, 2023)
- c80edce · Add check for None (netang, Aug 24, 2023)
- 0232734 · Return `sample = True`. Return previous predict result in doctest. (netang, Sep 6, 2023)
- baf8acc · Merge remote-tracking branch 'sb-repo/main' into sb-main-fit-partial (netang, Sep 6, 2023)
- ee624b4 · Fix imports (netang, Sep 6, 2023)
- 39d5078 · Return test params. Add new model creation for reproducibility test. (netang, Sep 6, 2023)
- 6cfbb17 · Fix (maybe) (netang, Sep 6, 2023)
- b6c3ad3 · Add `test_fit_partial` (netang, Sep 17, 2023)
- fda5ad6 · Add `test_fit_partial` for `AssociationRulesItemRec` (netang, Oct 1, 2023)
- 0a87366 · Add `test_fit_partial_with_ann` test (netang, Oct 1, 2023)
282 changes: 170 additions & 112 deletions replay/models/association_rules.py
@@ -7,10 +7,12 @@

from replay.models.extensions.ann.index_builders.base_index_builder import IndexBuilder
from replay.models.base_neighbour_rec import NeighbourRec
from replay.models.base_rec import PartialFitMixin
from replay.utils.spark_utils import unpersist_after, unionify


# pylint: disable=too-many-ancestors, too-many-instance-attributes
-class AssociationRulesItemRec(NeighbourRec):
+class AssociationRulesItemRec(NeighbourRec, PartialFitMixin):
"""
Item-to-item recommender based on association rules.
Calculate pairs confidence, lift and confidence_gain defined as
@@ -77,7 +79,7 @@ def _get_ann_infer_params(self) -> Dict[str, Any]:

can_predict_item_to_item = True
item_to_item_metrics: List[str] = ["lift", "confidence", "confidence_gain"]
-similarity: DataFrame
+similarity: Optional[DataFrame] = None
can_change_metric = True
_search_space = {
"min_item_count": {"type": "int", "args": [3, 10]},
@@ -129,6 +131,8 @@ def __init__(
self.index_builder = index_builder
elif isinstance(index_builder, dict):
self.init_builder_from_dict(index_builder)
+self.items_aggr: Optional[DataFrame] = None
+self.session_col_unique_vals: Optional[DataFrame] = None

@property
def _init_args(self):
@@ -141,141 +145,191 @@ def _init_args(self):
"similarity_metric": self.similarity_metric,
}

    def _fit_partial(
        self,
        log: DataFrame,
        user_features: Optional[DataFrame] = None,
        item_features: Optional[DataFrame] = None,
        previous_log: Optional[DataFrame] = None,
    ) -> None:
        """
        1) Filter log items by ``min_item_count`` threshold
        2) Calculate items support, pairs confidence, lift and confidence_gain defined as
        confidence(a, b)/confidence(!a, b).
        """
        with unpersist_after(self._dataframes):
            log = (
                log.withColumn(
                    "relevance",
                    sf.col("relevance") if self.use_relevance else sf.lit(1),
                )
                .select(self.session_col, "item_idx", "relevance")
                .distinct()
            )
            if previous_log:
                previous_log = (
                    previous_log.withColumn(
                        "relevance",
                        sf.col("relevance") if self.use_relevance else sf.lit(1),
                    )
                    .select(self.session_col, "item_idx", "relevance")
                    .distinct()
                )
            self.session_col_unique_vals = unionify(
                log.select(self.session_col), self.session_col_unique_vals
            ).distinct()
            self.session_col_unique_vals = self.session_col_unique_vals.cache()
            num_sessions = self.session_col_unique_vals.count()

            items_aggr = log.groupby("item_idx").agg(
                sf.count("item_idx").alias("item_count"),
                sf.sum("relevance").alias("item_relevance"),
            )
            self.items_aggr = (
                unionify(items_aggr, self.items_aggr)
                .groupBy("item_idx")
                .agg(
                    sf.sum("item_count").alias("item_count"),
                    sf.sum("item_relevance").alias("item_relevance"),
                )
            )

            frequent_items_cached = (
                self.items_aggr.filter(
                    sf.col("item_count") >= self.min_item_count
                ).drop("item_count")
            ).cache()

            frequent_items_log = unionify(log, previous_log).join(
                frequent_items_cached.select("item_idx"), on="item_idx"
            )

            frequent_item_pairs = (
                frequent_items_log.withColumnRenamed("item_idx", "antecedent")
                .withColumnRenamed("relevance", "antecedent_rel")
                .join(
                    frequent_items_log.withColumnRenamed(
                        self.session_col, self.session_col + "_cons"
                    )
                    .withColumnRenamed("item_idx", "consequent")
                    .withColumnRenamed("relevance", "consequent_rel"),
                    on=[
                        sf.col(self.session_col)
                        == sf.col(self.session_col + "_cons"),
                        sf.col("antecedent") < sf.col("consequent"),
                    ],
                )
                # taking minimal relevance of item for pair
                .withColumn(
                    "relevance",
                    sf.least(
                        sf.col("consequent_rel"), sf.col("antecedent_rel")
                    ),
                )
                .drop(
                    self.session_col + "_cons",
                    "consequent_rel",
                    "antecedent_rel",
                )
            )

            pairs_count = (
                frequent_item_pairs.groupBy("antecedent", "consequent")
                .agg(
                    sf.count("consequent").alias("pair_count"),
                    sf.sum("relevance").alias("pair_relevance"),
                )
                .filter(sf.col("pair_count") >= self.min_pair_count)
            ).drop("pair_count")

            pairs_metrics = pairs_count.unionByName(
                pairs_count.select(
                    sf.col("consequent").alias("antecedent"),
                    sf.col("antecedent").alias("consequent"),
                    sf.col("pair_relevance"),
                )
            )

            pairs_metrics = pairs_metrics.join(
                frequent_items_cached.withColumnRenamed(
                    "item_relevance", "antecedent_relevance"
                ),
                on=[sf.col("antecedent") == sf.col("item_idx")],
            ).drop("item_idx")

            pairs_metrics = pairs_metrics.join(
                frequent_items_cached.withColumnRenamed(
                    "item_relevance", "consequent_relevance"
                ),
                on=[sf.col("consequent") == sf.col("item_idx")],
            ).drop("item_idx")

            pairs_metrics = pairs_metrics.withColumn(
                "confidence",
                sf.col("pair_relevance") / sf.col("antecedent_relevance"),
            ).withColumn(
                "lift",
                num_sessions
                * sf.col("confidence")
                / sf.col("consequent_relevance"),
            )

            if self.num_neighbours is not None:
                pairs_metrics = (
                    pairs_metrics.withColumn(
                        "similarity_order",
                        sf.row_number().over(
                            Window.partitionBy("antecedent").orderBy(
                                sf.col("lift").desc(),
                                sf.col("consequent").desc(),
                            )
                        ),
                    )
                    .filter(sf.col("similarity_order") <= self.num_neighbours)
                    .drop("similarity_order")
                )

            self.similarity = pairs_metrics.withColumn(
                "confidence_gain",
                sf.when(
                    sf.col("consequent_relevance") - sf.col("pair_relevance")
                    == 0,
                    sf.lit(np.inf),
                ).otherwise(
                    sf.col("confidence")
                    * (num_sessions - sf.col("antecedent_relevance"))
                    / (
                        sf.col("consequent_relevance")
                        - sf.col("pair_relevance")
                    )
                ),
            ).select(
                sf.col("antecedent").alias("item_idx_one"),
                sf.col("consequent").alias("item_idx_two"),
                "confidence",
                "lift",
                "confidence_gain",
            )
            self.similarity.cache().count()
            frequent_items_cached.unpersist()

    # pylint: disable=too-many-arguments
    # def _predict(
    #     self,
    #     log: DataFrame,
    #     k: int,
    #     users: DataFrame,
    #     items: DataFrame,
    #     user_features: Optional[DataFrame] = None,
    #     item_features: Optional[DataFrame] = None,
    #     filter_seen_items: bool = True,
    # ) -> None:
    #     raise NotImplementedError(
    #         f"item-to-user predict is not implemented for {self}, "
    #         f"use get_nearest_items method to get item-to-item recommendations"
    #     )

@property
def get_similarity(self):
@@ -349,4 +403,8 @@ def _get_nearest_items(

@property
def _dataframes(self):
return {"similarity": self.similarity}
return {
"similarity": self.similarity,
"items_aggr": self.items_aggr,
"session_col_unique_vals": self.session_col_unique_vals,
}
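
Note on the `_fit_partial` rework above: the former `_fit` is renamed to `_fit_partial`, gains a `previous_log` argument, and now accumulates `items_aggr` and `session_col_unique_vals` across calls inside a `with unpersist_after(self._dataframes):` block, so lift and confidence are recomputed over every session seen so far; this is also why `_dataframes` now exposes the two new attributes. The helpers `unionify` and `unpersist_after` are imported from `replay.utils.spark_utils` but their code is not part of this diff; the sketch below shows only the behaviour the call sites appear to assume (signatures and semantics are inferred, not taken from the library source):

from contextlib import contextmanager
from typing import Dict, Optional

from pyspark.sql import DataFrame


def unionify(df: DataFrame, other: Optional[DataFrame] = None) -> DataFrame:
    # Assumed behaviour: union by column name, treating a missing
    # second frame as "nothing accumulated yet".
    return df if other is None else df.unionByName(other)


@contextmanager
def unpersist_after(dataframes: Dict[str, Optional[DataFrame]]):
    # Assumed behaviour: the dict is built at call time, so it holds the
    # DataFrames cached by the previous fit; after the body finishes,
    # those stale frames are released while the freshly cached ones stay.
    try:
        yield
    finally:
        for df in dataframes.values():
            if df is not None:
                df.unpersist()

Under that reading, each incremental call replaces the cached similarity matrix while the per-item and per-session counters keep growing between calls.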
3 changes: 2 additions & 1 deletion replay/models/base_neighbour_rec.py
@@ -37,7 +37,8 @@ def _dataframes(self):

def _clear_cache(self):
if hasattr(self, "similarity"):
-self.similarity.unpersist()
+if self.similarity:
+    self.similarity.unpersist()

# pylint: disable=missing-function-docstring
@property
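
`PartialFitMixin` itself lives in `replay.models.base_rec` (see the import in association_rules.py above) and is not shown in this diff. A rough sketch of the contract implied by the models that use it, assuming the mixin wires a public `fit_partial` to the protected `_fit_partial` and remembers the previously seen log; the attribute name and bookkeeping are guesses, not the actual implementation:

from typing import Optional

from pyspark.sql import DataFrame


class PartialFitMixin:
    # Hypothetical sketch of the mixin's contract, not the real code.
    _previous_log: Optional[DataFrame] = None

    def fit_partial(self, log: DataFrame) -> None:
        # Delegate the incremental update to the concrete model.
        self._fit_partial(log, previous_log=self._previous_log)
        # Illustrative bookkeeping of everything seen so far.
        self._previous_log = (
            log
            if self._previous_log is None
            else self._previous_log.unionByName(log)
        )

    def _fit_partial(
        self,
        log: DataFrame,
        user_features: Optional[DataFrame] = None,
        item_features: Optional[DataFrame] = None,
        previous_log: Optional[DataFrame] = None,
    ) -> None:
        raise NotImplementedError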