diff --git a/CHANGELOG.md b/CHANGELOG.md index e2968e84..e42b5885 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added +- Configs for Popular, PopularInCategory models ([#188](https://github.com/MobileTeleSystems/RecTools/pull/188)) - Configs for EASE, Random, PureSVD models ([#178](https://github.com/MobileTeleSystems/RecTools/pull/178)) - Configs for implicit models ([#167](https://github.com/MobileTeleSystems/RecTools/pull/167)) diff --git a/rectools/models/popular.py b/rectools/models/popular.py index 746fc9e4..29708b10 100644 --- a/rectools/models/popular.py +++ b/rectools/models/popular.py @@ -20,6 +20,8 @@ import numpy as np import pandas as pd +import typing_extensions as tpe +from pydantic import PlainSerializer, PlainValidator from tqdm.auto import tqdm from rectools import Columns, InternalIds @@ -41,17 +43,89 @@ class Popularity(Enum): SUM_WEIGHT = "sum_weight" +def _deserialize_timedelta(td: tp.Union[dict, timedelta]) -> timedelta: + if isinstance(td, dict): + return timedelta(**td) + return td + + +def _serialize_timedelta(td: timedelta) -> dict: + serialized_td = { + key: value + for key, value in {"days": td.days, "seconds": td.seconds, "microseconds": td.microseconds}.items() + if value != 0 + } + return serialized_td + + +TimeDelta = tpe.Annotated[ + timedelta, + PlainValidator(func=_deserialize_timedelta), + PlainSerializer(func=_serialize_timedelta), +] + + class PopularModelConfig(ModelConfig): """Config for `PopularModel`.""" popularity: Popularity = Popularity.N_USERS - period: tp.Optional[timedelta] = None + period: tp.Optional[TimeDelta] = None begin_from: tp.Optional[datetime] = None add_cold: bool = False inverse: bool = False -class PopularModel(FixedColdRecoModelMixin, ModelBase): +PopularityOptions = tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"] + + +class PopularModelMixin: + """Mixin for models based on popularity.""" + + @classmethod + def _validate_popularity( + cls, + popularity: PopularityOptions, + ) -> Popularity: + try: + return Popularity(popularity) + except ValueError: + possible_values = {item.value for item in Popularity.__members__.values()} + raise ValueError(f"`popularity` must be one of the {possible_values}. Got {popularity}.") + + @classmethod + def _validate_time_attributes( + cls, + period: tp.Optional[TimeDelta], + begin_from: tp.Optional[datetime], + ) -> None: + if period is not None and begin_from is not None: + raise ValueError("Only one of `period` and `begin_from` can be set") + + @classmethod + def _filter_interactions( + cls, interactions: pd.DataFrame, period: tp.Optional[TimeDelta], begin_from: tp.Optional[datetime] + ) -> pd.DataFrame: + if begin_from is not None: + interactions = interactions.loc[interactions[Columns.Datetime] >= begin_from] + elif period is not None: + begin_from = interactions[Columns.Datetime].max() - period + interactions = interactions.loc[interactions[Columns.Datetime] >= begin_from] + return interactions + + @classmethod + def _get_groupby_col_and_agg_func(cls, popularity: Popularity) -> tp.Tuple[str, str]: + if popularity == Popularity.N_USERS: + return Columns.User, "nunique" + if popularity == Popularity.N_INTERACTIONS: + return Columns.User, "count" + if popularity == Popularity.MEAN_WEIGHT: + return Columns.Weight, "mean" + if popularity == Popularity.SUM_WEIGHT: + return Columns.Weight, "sum" + raise ValueError(f"Unexpected popularity {popularity}") + + +class PopularModel(FixedColdRecoModelMixin, PopularModelMixin, ModelBase[PopularModelConfig]): """ Model generating recommendations based on popularity of items. @@ -87,25 +161,22 @@ class PopularModel(FixedColdRecoModelMixin, ModelBase): recommends_for_warm = False recommends_for_cold = True + config_class = PopularModelConfig + def __init__( self, - popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"] = "n_users", + popularity: PopularityOptions = "n_users", period: tp.Optional[timedelta] = None, begin_from: tp.Optional[datetime] = None, add_cold: bool = False, inverse: bool = False, verbose: int = 0, ): - super().__init__(verbose=verbose) - - try: - self.popularity = Popularity(popularity) - except ValueError: - possible_values = {item.value for item in Popularity.__members__.values()} - raise ValueError(f"`popularity` must be one of the {possible_values}. Got {popularity}.") - - if period is not None and begin_from is not None: - raise ValueError("Only one of `period` and `begin_from` can be set") + super().__init__( + verbose=verbose, + ) + self.popularity = self._validate_popularity(popularity) + self._validate_time_attributes(period, begin_from) self.period = period self.begin_from = begin_from @@ -114,16 +185,29 @@ def __init__( self.popularity_list: tp.Tuple[InternalIdsArray, ScoresArray] - def _filter_interactions(self, interactions: pd.DataFrame) -> pd.DataFrame: - if self.begin_from is not None: - interactions = interactions.loc[interactions[Columns.Datetime] >= self.begin_from] - elif self.period is not None: - begin_from = interactions[Columns.Datetime].max() - self.period - interactions = interactions.loc[interactions[Columns.Datetime] >= begin_from] - return interactions + def _get_config(self) -> PopularModelConfig: + return PopularModelConfig( + popularity=self.popularity, + period=self.period, + begin_from=self.begin_from, + add_cold=self.add_cold, + inverse=self.inverse, + verbose=self.verbose, + ) + + @classmethod + def _from_config(cls, config: PopularModelConfig) -> tpe.Self: + return cls( + popularity=config.popularity.value, + period=config.period, + begin_from=config.begin_from, + add_cold=config.add_cold, + inverse=config.inverse, + verbose=config.verbose, + ) def _fit(self, dataset: Dataset) -> None: # type: ignore - interactions = self._filter_interactions(dataset.interactions.df) + interactions = self._filter_interactions(dataset.interactions.df, self.period, self.begin_from) col, func = self._get_groupby_col_and_agg_func(self.popularity) items_scores = interactions.groupby(Columns.Item)[col].agg(func).sort_values(ascending=False) @@ -141,18 +225,6 @@ def _fit(self, dataset: Dataset) -> None: # type: ignore self.popularity_list = (items, scores) - @classmethod - def _get_groupby_col_and_agg_func(cls, popularity: Popularity) -> tp.Tuple[str, str]: - if popularity == Popularity.N_USERS: - return Columns.User, "nunique" - if popularity == Popularity.N_INTERACTIONS: - return Columns.User, "count" - if popularity == Popularity.MEAN_WEIGHT: - return Columns.Weight, "mean" - if popularity == Popularity.SUM_WEIGHT: - return Columns.Weight, "sum" - raise ValueError(f"Unexpected popularity {popularity}") - def _recommend_u2i( self, user_ids: InternalIdsArray, diff --git a/rectools/models/popular_in_category.py b/rectools/models/popular_in_category.py index e860295f..4f6416c4 100644 --- a/rectools/models/popular_in_category.py +++ b/rectools/models/popular_in_category.py @@ -21,13 +21,14 @@ import numpy as np import pandas as pd +import typing_extensions as tpe from rectools import Columns, InternalIds from rectools.dataset import Dataset, Interactions, features from rectools.types import InternalIdsArray -from .base import Scores -from .popular import PopularModel +from .base import ModelBase, Scores +from .popular import FixedColdRecoModelMixin, PopularModel, PopularModelConfig, PopularModelMixin, PopularityOptions class MixingStrategy(Enum): @@ -44,7 +45,18 @@ class RatioStrategy(Enum): PROPORTIONAL = "proportional" -class PopularInCategoryModel(PopularModel): +class PopularInCategoryModelConfig(PopularModelConfig): + """Config for `PopularInCategoryModel`.""" + + category_feature: str + n_categories: tp.Optional[int] = None + mixing_strategy: MixingStrategy = MixingStrategy.ROTATE + ratio_strategy: RatioStrategy = RatioStrategy.PROPORTIONAL + + +class PopularInCategoryModel( + FixedColdRecoModelMixin, PopularModelMixin, ModelBase[PopularInCategoryModelConfig] +): # pylint: disable=too-many-instance-attributes """ Model generating recommendations based on popularity of items. @@ -98,13 +110,15 @@ class PopularInCategoryModel(PopularModel): recommends_for_warm = False recommends_for_cold = True + config_class = PopularInCategoryModelConfig + def __init__( self, category_feature: str, n_categories: tp.Optional[int] = None, mixing_strategy: tp.Literal["rotate", "group"] = "rotate", ratio_strategy: tp.Literal["proportional", "equal"] = "proportional", - popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"] = "n_users", + popularity: PopularityOptions = "n_users", period: tp.Optional[timedelta] = None, begin_from: tp.Optional[datetime] = None, add_cold: bool = False, @@ -112,26 +126,18 @@ def __init__( verbose: int = 0, ): super().__init__( - popularity=popularity, - period=period, - begin_from=begin_from, - add_cold=add_cold, - inverse=inverse, verbose=verbose, ) - self.category_feature = category_feature - self.category_columns: tp.List[int] = [] - self.category_interactions: tp.Dict[int, pd.DataFrame] = {} - self.category_scores: pd.Series - self.models: tp.Dict[int, PopularModel] = {} - self.n_effective_categories: int + self.popularity = self._validate_popularity(popularity) + self._validate_time_attributes(period, begin_from) + self.period = period + self.begin_from = begin_from - if n_categories is None or n_categories > 0: - self.n_categories = n_categories - else: - raise ValueError(f"`n_categories` must be a positive number. Got {n_categories}") + self.add_cold = add_cold + self.inverse = inverse + self.category_feature = category_feature try: self.mixing_strategy = MixingStrategy(mixing_strategy) except ValueError: @@ -143,6 +149,45 @@ def __init__( except ValueError: possible_values = {item.value for item in RatioStrategy.__members__.values()} raise ValueError(f"`ratio_strategy` must be one of the {possible_values}. Got {ratio_strategy}.") + self.category_columns: tp.List[int] = [] + self.category_interactions: tp.Dict[int, pd.DataFrame] = {} + self.category_scores: pd.Series + self.models: tp.Dict[int, PopularModel] = {} + self.n_effective_categories: int + + if n_categories is None or n_categories > 0: + self.n_categories = n_categories + else: + raise ValueError(f"`n_categories` must be a positive number. Got {n_categories}") + + def _get_config(self) -> PopularInCategoryModelConfig: + return PopularInCategoryModelConfig( + category_feature=self.category_feature, + n_categories=self.n_categories, + mixing_strategy=self.mixing_strategy, + ratio_strategy=self.ratio_strategy, + popularity=self.popularity, + period=self.period, + begin_from=self.begin_from, + add_cold=self.add_cold, + inverse=self.inverse, + verbose=self.verbose, + ) + + @classmethod + def _from_config(cls, config: PopularInCategoryModelConfig) -> tpe.Self: + return cls( + category_feature=config.category_feature, + n_categories=config.n_categories, + mixing_strategy=config.mixing_strategy.value, + ratio_strategy=config.ratio_strategy.value, + popularity=config.popularity.value, + period=config.period, + begin_from=config.begin_from, + add_cold=config.add_cold, + inverse=config.inverse, + verbose=config.verbose, + ) def _check_category_feature(self, dataset: Dataset) -> None: if not dataset.item_features: @@ -200,7 +245,7 @@ def _fit(self, dataset: Dataset) -> None: # type: ignore self.n_effective_categories = 0 self._check_category_feature(dataset) - interactions = self._filter_interactions(dataset.interactions.df) + interactions = self._filter_interactions(dataset.interactions.df, self.period, self.begin_from) self._calc_category_scores(dataset, interactions) self._define_categories_for_analysis() diff --git a/tests/model_selection/test_cross_validate.py b/tests/model_selection/test_cross_validate.py index d5d9dd87..b7bc374d 100644 --- a/tests/model_selection/test_cross_validate.py +++ b/tests/model_selection/test_cross_validate.py @@ -168,7 +168,7 @@ def setup_method(self) -> None: "intersection": Intersection(1), } - self.models = { + self.models: tp.Dict[str, ModelBase] = { "popular": PopularModel(), "random": RandomModel(random_state=42), } diff --git a/tests/models/test_ease.py b/tests/models/test_ease.py index 0f90de38..9ea04f94 100644 --- a/tests/models/test_ease.py +++ b/tests/models/test_ease.py @@ -259,8 +259,7 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N "num_threads": 1, "verbose": 1, } - model = EASEModel() - assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) + assert_get_config_and_from_config_compatibility(EASEModel, DATASET, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, int] = {} diff --git a/tests/models/test_implicit_als.py b/tests/models/test_implicit_als.py index da5dd34b..3ee309a0 100644 --- a/tests/models/test_implicit_als.py +++ b/tests/models/test_implicit_als.py @@ -461,8 +461,7 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N }, "verbose": 1, } - model = ImplicitALSWrapperModel(model=AlternatingLeastSquares()) - assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) + assert_get_config_and_from_config_compatibility(ImplicitALSWrapperModel, DATASET, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, tp.Any] = {"model": {}} diff --git a/tests/models/test_implicit_knn.py b/tests/models/test_implicit_knn.py index 732e7808..db0efd53 100644 --- a/tests/models/test_implicit_knn.py +++ b/tests/models/test_implicit_knn.py @@ -320,8 +320,9 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N }, "verbose": 1, } - model = ImplicitItemKNNWrapperModel(model=ItemItemRecommender()) - assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) + assert_get_config_and_from_config_compatibility( + ImplicitItemKNNWrapperModel, DATASET, initial_config, simple_types + ) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, tp.Any] = {"model": {}} diff --git a/tests/models/test_popular.py b/tests/models/test_popular.py index e1ab4e8b..fd419c1a 100644 --- a/tests/models/test_popular.py +++ b/tests/models/test_popular.py @@ -22,7 +22,14 @@ from rectools import Columns from rectools.dataset import Dataset, IdMap, Interactions from rectools.models import PopularModel -from tests.models.utils import assert_second_fit_refits_model +from rectools.models.popular import Popularity +from tests.models.utils import ( + assert_default_config_and_default_model_params_are_the_same, + assert_get_config_and_from_config_compatibility, + assert_second_fit_refits_model, +) + +from .data import DATASET class TestPopularModel: @@ -212,3 +219,130 @@ def test_i2i( def test_second_fit_refits_model(self, dataset: Dataset) -> None: model = PopularModel() assert_second_fit_refits_model(model, dataset) + + +class TestPopularModelConfiguration: + @pytest.mark.parametrize( + "begin_from,period,expected_begin_from,expected_period", + ( + (None, timedelta(days=7), None, timedelta(days=7)), + (datetime(2021, 11, 23), None, datetime(2021, 11, 23), None), + ("2021-11-23T10:20:30.400", None, datetime(2021, 11, 23, 10, 20, 30, 400000), None), + ( + None, + { + "days": 7, + "seconds": 123, + "microseconds": 12345, + "milliseconds": 32, + "minutes": 2, + "weeks": 7, + }, + None, + timedelta(days=56, seconds=243, microseconds=44345), + ), + ), + ) + def test_from_config( + self, + period: tp.Optional[tp.Union[timedelta, dict]], + begin_from: tp.Optional[tp.Union[datetime, str]], + expected_begin_from: tp.Optional[datetime], + expected_period: tp.Optional[dict], + ) -> None: + config = { + "popularity": "n_interactions", + "period": period, + "begin_from": begin_from, + "add_cold": True, + "inverse": True, + "verbose": 0, + } + model = PopularModel.from_config(config) + assert model.popularity.value == "n_interactions" + assert model.period == expected_period + assert model.begin_from == expected_begin_from + assert model.add_cold is True + assert model.inverse is True + assert model.verbose == 0 + + @pytest.mark.parametrize( + "begin_from,period,expected_period", + ( + ( + None, + timedelta(weeks=2, days=7, hours=23, milliseconds=12345), + {"days": 21, "microseconds": 345000, "seconds": 82812}, + ), + (datetime(2021, 11, 23, 10, 20, 30, 400000), None, None), + ), + ) + def test_get_config( + self, + period: tp.Optional[timedelta], + begin_from: tp.Optional[datetime], + expected_period: tp.Optional[timedelta], + ) -> None: + model = PopularModel( + popularity="n_users", + period=period, + begin_from=begin_from, + add_cold=False, + inverse=False, + verbose=1, + ) + config = model.get_config() + expected = { + "popularity": Popularity("n_users"), + "period": expected_period, + "begin_from": begin_from, + "add_cold": False, + "inverse": False, + "verbose": 1, + } + assert config == expected + + @pytest.mark.parametrize( + "begin_from,period,simple_types", + ( + ( + None, + timedelta(weeks=1, days=2, hours=3, minutes=4, seconds=5, milliseconds=6000, microseconds=70000), + True, + ), + (datetime(2021, 11, 23), None, False), + ("2021-11-23T10:20:30.400", None, True), + ( + None, + { + "days": 7, + "seconds": 123, + "microseconds": 12345, + "milliseconds": 32, + "minutes": 2, + "weeks": 7, + }, + False, + ), + ), + ) + def test_get_config_and_from_config_compatibility( + self, + period: tp.Optional[timedelta], + begin_from: tp.Optional[datetime], + simple_types: bool, + ) -> None: + initial_config = { + "popularity": "n_users", + "period": period, + "begin_from": begin_from, + "add_cold": True, + "inverse": False, + "verbose": 0, + } + assert_get_config_and_from_config_compatibility(PopularModel, DATASET, initial_config, simple_types) + + def test_default_config_and_default_model_params_are_the_same(self) -> None: + default_config: tp.Dict[str, int] = {} + model = PopularModel() + assert_default_config_and_default_model_params_are_the_same(model, default_config) diff --git a/tests/models/test_popular_in_category.py b/tests/models/test_popular_in_category.py index 59f3b02d..3d0a6ffa 100644 --- a/tests/models/test_popular_in_category.py +++ b/tests/models/test_popular_in_category.py @@ -22,69 +22,78 @@ from rectools import Columns from rectools.dataset import Dataset from rectools.models import PopularInCategoryModel -from tests.models.utils import assert_second_fit_refits_model +from rectools.models.popular import Popularity +from rectools.models.popular_in_category import MixingStrategy, RatioStrategy +from tests.models.utils import ( + assert_default_config_and_default_model_params_are_the_same, + assert_get_config_and_from_config_compatibility, + assert_second_fit_refits_model, +) -@pytest.mark.filterwarnings("ignore") -class TestPopularInCategoryModel: - @pytest.fixture - def interactions_df(self) -> pd.DataFrame: - interactions_df = pd.DataFrame( - [ - [70, 11, 1, "2021-11-30"], - [70, 12, 1, "2021-11-30"], - [10, 11, 1, "2021-11-30"], - [10, 12, 1, "2021-11-29"], - [10, 13, 9, "2021-11-28"], - [20, 11, 1, "2021-11-27"], - [20, 14, 2, "2021-11-26"], - [20, 14, 1, "2021-11-25"], - [20, 14, 1, "2021-11-25"], - [20, 14, 1, "2021-11-25"], - [20, 14, 1, "2021-11-25"], - [20, 14, 1, "2021-11-25"], - [30, 11, 1, "2021-11-24"], - [30, 12, 1, "2021-11-23"], - [30, 14, 1, "2021-11-23"], - [30, 15, 5, "2021-11-21"], - [30, 15, 5, "2021-11-21"], - [40, 11, 1, "2021-11-20"], - [40, 12, 1, "2021-11-19"], - [50, 12, 1, "2021-11-19"], - [60, 12, 1, "2021-11-19"], - ], - columns=Columns.Interactions, - ) - return interactions_df +@pytest.fixture(name="interactions_df") # https://github.com/pylint-dev/pylint/issues/6531 +def _interactions_df() -> pd.DataFrame: + interactions_df = pd.DataFrame( + [ + [70, 11, 1, "2021-11-30"], + [70, 12, 1, "2021-11-30"], + [10, 11, 1, "2021-11-30"], + [10, 12, 1, "2021-11-29"], + [10, 13, 9, "2021-11-28"], + [20, 11, 1, "2021-11-27"], + [20, 14, 2, "2021-11-26"], + [20, 14, 1, "2021-11-25"], + [20, 14, 1, "2021-11-25"], + [20, 14, 1, "2021-11-25"], + [20, 14, 1, "2021-11-25"], + [20, 14, 1, "2021-11-25"], + [30, 11, 1, "2021-11-24"], + [30, 12, 1, "2021-11-23"], + [30, 14, 1, "2021-11-23"], + [30, 15, 5, "2021-11-21"], + [30, 15, 5, "2021-11-21"], + [40, 11, 1, "2021-11-20"], + [40, 12, 1, "2021-11-19"], + [50, 12, 1, "2021-11-19"], + [60, 12, 1, "2021-11-19"], + ], + columns=Columns.Interactions, + ) + return interactions_df - @pytest.fixture - def item_features_df(self) -> pd.DataFrame: - item_features_df = pd.DataFrame( - { - "id": [11, 11, 12, 12, 13, 13, 14, 14, 14], - "feature": ["f1", "f2", "f1", "f2", "f1", "f2", "f1", "f2", "f3"], - "value": [100, "a", 100, "b", 100, "b", 200, "c", 1], - } - ) - return item_features_df - @pytest.fixture - def dataset(self, interactions_df: pd.DataFrame, item_features_df: pd.DataFrame) -> Dataset: - user_features_df = pd.DataFrame( - { - "id": [10, 50], - "feature": ["f1", "f1"], - "value": [1, 1], - } - ) - dataset = Dataset.construct( - interactions_df=interactions_df, - user_features_df=user_features_df, - item_features_df=item_features_df, - cat_item_features=["f2", "f1"], - ) - return dataset +@pytest.fixture(name="item_features_df") +def _item_features_df() -> pd.DataFrame: + item_features_df = pd.DataFrame( + { + "id": [11, 11, 12, 12, 13, 13, 14, 14, 14], + "feature": ["f1", "f2", "f1", "f2", "f1", "f2", "f1", "f2", "f3"], + "value": [100, "a", 100, "b", 100, "b", 200, "c", 1], + } + ) + return item_features_df + +@pytest.fixture(name="dataset") +def _dataset(interactions_df: pd.DataFrame, item_features_df: pd.DataFrame) -> Dataset: + user_features_df = pd.DataFrame( + { + "id": [10, 50], + "feature": ["f1", "f1"], + "value": [1, 1], + } + ) + dataset = Dataset.construct( + interactions_df=interactions_df, + user_features_df=user_features_df, + item_features_df=item_features_df, + cat_item_features=["f2", "f1"], + ) + return dataset + + +@pytest.mark.filterwarnings("ignore") +class TestPopularInCategoryModel: @classmethod def assert_reco( cls, @@ -444,3 +453,151 @@ def test_second_fit_refits_model( n_categories=n_categories, ) assert_second_fit_refits_model(model, dataset) + + +class TestPopularInCategoryModelConfiguration: + @pytest.mark.parametrize( + "begin_from,period,expected_begin_from,expected_period", + ( + (None, timedelta(days=7), None, timedelta(days=7)), + (datetime(2021, 11, 23), None, datetime(2021, 11, 23), None), + ("2021-11-23T10:20:30.400", None, datetime(2021, 11, 23, 10, 20, 30, 400000), None), + ( + None, + { + "days": 7, + "seconds": 123, + "microseconds": 12345, + "milliseconds": 32, + "minutes": 2, + "weeks": 7, + }, + None, + timedelta(days=56, seconds=243, microseconds=44345), + ), + ), + ) + def test_from_config( + self, + period: tp.Optional[tp.Union[timedelta, dict]], + begin_from: tp.Optional[tp.Union[datetime, str]], + expected_begin_from: tp.Optional[datetime], + expected_period: tp.Optional[dict], + ) -> None: + config = { + "category_feature": "f1", + "n_categories": 2, + "mixing_strategy": "group", + "ratio_strategy": "equal", + "popularity": "n_interactions", + "period": period, + "begin_from": begin_from, + "add_cold": True, + "inverse": True, + "verbose": 0, + } + model = PopularInCategoryModel.from_config(config) + assert model.category_feature == "f1" + assert model.n_categories == 2 + assert model.mixing_strategy == MixingStrategy("group") + assert model.ratio_strategy == RatioStrategy("equal") + assert model.popularity == Popularity("n_interactions") + assert model.period == expected_period + assert model.begin_from == expected_begin_from + assert model.add_cold is True + assert model.inverse is True + assert model.verbose == 0 + + @pytest.mark.parametrize( + "begin_from,period,expected_period", + ( + ( + None, + timedelta(weeks=2, days=7, hours=23, milliseconds=12345), + {"days": 21, "microseconds": 345000, "seconds": 82812}, + ), + (datetime(2021, 11, 23, 10, 20, 30, 400000), None, None), + ), + ) + def test_get_config( + self, + period: tp.Optional[timedelta], + begin_from: tp.Optional[datetime], + expected_period: tp.Optional[timedelta], + ) -> None: + model = PopularInCategoryModel( + category_feature="f2", + n_categories=3, + mixing_strategy="rotate", + ratio_strategy="proportional", + popularity="n_users", + period=period, + begin_from=begin_from, + add_cold=False, + inverse=False, + verbose=1, + ) + config = model.get_config() + expected = { + "category_feature": "f2", + "n_categories": 3, + "mixing_strategy": MixingStrategy("rotate"), + "ratio_strategy": RatioStrategy("proportional"), + "popularity": Popularity("n_users"), + "period": expected_period, + "begin_from": begin_from, + "add_cold": False, + "inverse": False, + "verbose": 1, + } + assert config == expected + + @pytest.mark.parametrize( + "begin_from,period,simple_types", + ( + ( + None, + timedelta(weeks=1, days=2, hours=3, minutes=4, seconds=5, milliseconds=6000, microseconds=70000), + True, + ), + (datetime(2021, 11, 23), None, False), + ("2021-11-23T10:20:30.400", None, True), + ( + None, + { + "days": 7, + "seconds": 123, + "microseconds": 12345, + "milliseconds": 32, + "minutes": 2, + "weeks": 7, + }, + False, + ), + ), + ) + def test_get_config_and_from_config_compatibility( + self, + dataset: Dataset, + period: tp.Optional[timedelta], + begin_from: tp.Optional[datetime], + simple_types: bool, + ) -> None: + initial_config = { + "category_feature": "f1", + "n_categories": 2, + "mixing_strategy": "group", + "ratio_strategy": "equal", + "popularity": "n_users", + "period": period, + "begin_from": begin_from, + "add_cold": True, + "inverse": False, + "verbose": 0, + } + assert_get_config_and_from_config_compatibility(PopularInCategoryModel, dataset, initial_config, simple_types) + + def test_default_config_and_default_model_params_are_the_same(self) -> None: + default_config: tp.Dict[str, str] = {"category_feature": "f2"} + model = PopularInCategoryModel(category_feature="f2") + assert_default_config_and_default_model_params_are_the_same(model, default_config) diff --git a/tests/models/test_pure_svd.py b/tests/models/test_pure_svd.py index 14598ad5..7842c131 100644 --- a/tests/models/test_pure_svd.py +++ b/tests/models/test_pure_svd.py @@ -304,8 +304,7 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N "random_state": 32, "verbose": 0, } - model = PureSVDModel() - assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) + assert_get_config_and_from_config_compatibility(PureSVDModel, DATASET, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, int] = {} diff --git a/tests/models/test_random.py b/tests/models/test_random.py index f55fa6b5..373ee9fe 100644 --- a/tests/models/test_random.py +++ b/tests/models/test_random.py @@ -214,8 +214,7 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N "random_state": 32, "verbose": 0, } - model = RandomModel() - assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) + assert_get_config_and_from_config_compatibility(RandomModel, DATASET, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, int] = {} diff --git a/tests/models/utils.py b/tests/models/utils.py index ec531b55..92f2757d 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -58,7 +58,7 @@ def assert_default_config_and_default_model_params_are_the_same( def assert_get_config_and_from_config_compatibility( - model: ModelBase, dataset: Dataset, initial_config: tp.Dict[str, tp.Any], simple_types: bool + model: tp.Type[ModelBase], dataset: Dataset, initial_config: tp.Dict[str, tp.Any], simple_types: bool ) -> None: def get_reco(model: ModelBase) -> pd.DataFrame: return model.fit(dataset).recommend(users=np.array([10, 20]), dataset=dataset, k=2, filter_viewed=False)