diff --git a/rectools/models/base.py b/rectools/models/base.py index 79e528c8..a742ef41 100644 --- a/rectools/models/base.py +++ b/rectools/models/base.py @@ -154,7 +154,8 @@ def recommend( """ self._check_is_fitted() self._check_k(k) - + # `dataset.item_id_map.external_dtype` can change + original_item_type = dataset.item_id_map.external_dtype dataset = self._custom_transform_dataset_u2i(dataset, users, on_unsupported_targets) sorted_item_ids_to_recommend = self._get_sorted_item_ids_to_recommend(items_to_recommend, dataset) @@ -194,6 +195,14 @@ def recommend( reco_warm_final = self._reco_to_external(reco_warm, dataset.user_id_map, dataset.item_id_map) reco_cold_final = self._reco_items_to_external(reco_cold, dataset.item_id_map) + reco_hot_final = self._adjust_reco_types(reco_hot_final, dataset.user_id_map.external_dtype, original_item_type) + reco_warm_final = self._adjust_reco_types( + reco_warm_final, dataset.user_id_map.external_dtype, original_item_type + ) + reco_cold_final = self._adjust_reco_types( + reco_cold_final, dataset.user_id_map.external_dtype, original_item_type + ) + del reco_hot, reco_warm, reco_cold reco_all = self._concat_reco((reco_hot_final, reco_warm_final, reco_cold_final)) @@ -267,7 +276,8 @@ def recommend_to_items( # pylint: disable=too-many-branches """ self._check_is_fitted() self._check_k(k) - + # `dataset.item_id_map.external_dtype` can change + original_item_type = dataset.item_id_map.external_dtype dataset = self._custom_transform_dataset_i2i(dataset, target_items, on_unsupported_targets) sorted_item_ids_to_recommend = self._get_sorted_item_ids_to_recommend(items_to_recommend, dataset) @@ -318,6 +328,10 @@ def recommend_to_items( # pylint: disable=too-many-branches reco_cold_final = self._reco_items_to_external(reco_cold, dataset.item_id_map) del reco_hot, reco_warm, reco_cold + reco_hot_final = self._adjust_reco_types(reco_hot_final, original_item_type, original_item_type) + reco_warm_final = self._adjust_reco_types(reco_warm_final, original_item_type, original_item_type) + reco_cold_final = self._adjust_reco_types(reco_cold_final, original_item_type, original_item_type) + reco_all = self._concat_reco((reco_hot_final, reco_warm_final, reco_cold_final)) del reco_hot_final, reco_warm_final, reco_cold_final reco_df = self._make_reco_table(reco_all, Columns.TargetItem, add_rank_col) @@ -410,10 +424,12 @@ def _check_targets_are_valid( return hot_targets, warm_targets, cold_targets @classmethod - def _adjust_reco_types(cls, reco: RecoTriplet_T, target_type: tp.Type = np.int64) -> RecoTriplet_T: + def _adjust_reco_types( + cls, reco: RecoTriplet_T, target_type: tp.Type = np.int64, item_type: tp.Type = np.int64 + ) -> RecoTriplet_T: target_ids, item_ids, scores = reco target_ids = np.asarray(target_ids, dtype=target_type) - item_ids = np.asarray(item_ids, dtype=np.int64) + item_ids = np.asarray(item_ids, dtype=item_type) scores = np.asarray(scores, dtype=np.float32) return target_ids, item_ids, scores diff --git a/rectools/models/sasrec.py b/rectools/models/sasrec.py index 9bc2ba28..eb7d8d11 100644 --- a/rectools/models/sasrec.py +++ b/rectools/models/sasrec.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd import torch -import tqdm import typing_extensions as tpe from pytorch_lightning import LightningModule, Trainer from scipy import sparse @@ -24,7 +23,7 @@ PADDING_VALUE = "PAD" - +# pylint: disable=too-many-lines # #### -------------- Net blocks -------------- #### # @@ -32,16 +31,16 @@ class ItemNetBase(nn.Module): """TODO: use Protocol""" def 
forward(self, items: torch.Tensor) -> torch.Tensor: - """TODO""" + """Forward pass.""" raise NotImplementedError() @classmethod def from_dataset(cls, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> tpe.Self: - """TODO""" + """Construct ItemNet.""" raise NotImplementedError() def get_all_embeddings(self) -> torch.Tensor: - """TODO""" + """Return item embeddings.""" raise NotImplementedError() @property @@ -54,7 +53,7 @@ class TransformerLayersBase(nn.Module): """TODO: use Protocol""" def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: - """Forward""" + """Forward pass.""" raise NotImplementedError() @@ -62,7 +61,7 @@ class PositionalEncodingBase(torch.nn.Module): """TODO: use Protocol""" def forward(self, sessions: torch.Tensor, timeline_mask: torch.Tensor) -> torch.Tensor: - """TODO""" + """Forward pass.""" raise NotImplementedError() @@ -132,8 +131,16 @@ def from_dataset(cls, dataset: Dataset, n_factors: int, dropout_rate: float) -> class IdEmbeddingsItemNet(ItemNetBase): """ - Base class for item embeddings. To use more complicated logic then just id embeddings inherit - from this class and pass your custom ItemNet to your model params. + Network for item embeddings based only on item ids. + + Parameters + ---------- + n_factors: int + Latent embedding size of item embeddings. + n_items: int + Number of items in the dataset. + dropout_rate: float + Probability of a hidden unit to be zeroed. """ def __init__(self, n_factors: int, n_items: int, dropout_rate: float): @@ -148,7 +155,19 @@ def __init__(self, n_factors: int, n_items: int, dropout_rate: float): self.drop_layer = nn.Dropout(dropout_rate) def forward(self, items: torch.Tensor) -> torch.Tensor: - """TODO""" + """ + Forward pass to get item embeddings from item ids. + + Parameters + ---------- + items: torch.Tensor + Internal item ids. + + Returns + ------- + torch.Tensor + Item embeddings. + """ item_embs = self.ids_emb(items) item_embs = self.drop_layer(item_embs) return item_embs @@ -204,11 +223,11 @@ def device(self) -> torch.device: @property def catalogue(self) -> torch.Tensor: - """TODO""" + """Return tensor with elements in range [0, n_items).""" return torch.arange(0, self.n_items, device=self.device) def get_all_embeddings(self) -> torch.Tensor: - """TODO""" + """Return item embeddings.""" return self.forward(self.catalogue) @classmethod @@ -231,10 +250,21 @@ def from_dataset( class PointWiseFeedForward(nn.Module): - """TODO""" + """ + Feed-Forward network to introduce nonlinearity into the transformer model. + This implementation is the one used by SASRec authors. + + Parameters + ---------- + n_factors: int + Latent embeddings size. + n_factors_ff: int + How many hidden units to use in the network. + dropout_rate: float + Probability of a hidden unit to be zeroed. + """ def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float) -> None: - """TODO""" super().__init__() self.ff_linear1 = nn.Linear(n_factors, n_factors_ff) self.ff_dropout1 = torch.nn.Dropout(dropout_rate) @@ -243,14 +273,39 @@ def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float) -> No self.ff_dropout2 = torch.nn.Dropout(dropout_rate) def forward(self, seqs: torch.Tensor) -> torch.Tensor: - """TODO""" + """ + Forward pass. + + Parameters + ---------- + seqs: torch.Tensor + User sequences of item embeddings. + + Returns + ------- + torch.Tensor + User sequence that passed through all layers. 
+ """ output = self.ff_relu(self.ff_dropout1(self.ff_linear1(seqs))) fin = self.ff_dropout2(self.ff_linear2(output)) return fin class SASRecTransformerLayers(TransformerLayersBase): - """Exactly SASRec authors architecture but with torch MHA realisation""" + """ + Exactly SASRec author's transformer blocks architecture but with pytorch Multi-Head Attention realisation. + + Parameters + ---------- + n_blocks: int + Number of transformer blocks. + n_factors: int + Latent embeddings size. + n_heads: int + Number of attention heads. + dropout_rate: float + Probability of a hidden unit to be zeroed. + """ def __init__( self, @@ -272,7 +327,23 @@ def __init__( self.last_layernorm = torch.nn.LayerNorm(n_factors, eps=1e-8) def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: - """TODO""" + """ + Forward pass through transformer blocks. + + Parameters + ---------- + seqs: torch.Tensor + User sequences of item embeddings. + timeline_mask: torch.Tensor + Mask to zero out padding elements. + attn_mask: torch.Tensor + Mask to forbid model to use future interactions. + + Returns + ------- + torch.Tensor + User sequences passed through transformer layers. + """ for i in range(self.n_blocks): q = self.q_layer_norm[i](seqs) mha_output, _ = self.multi_head_attn[i](q, seqs, seqs, attn_mask=attn_mask, need_weights=False) @@ -289,9 +360,19 @@ def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: to class PreLNTransformerLayers(TransformerLayersBase): """ - Based on https://arxiv.org/pdf/2002.04745 - On Kion open dataset didn't change metrics, even got a bit worse - But let's keep it for now + Architecture of transformer blocks based on https://arxiv.org/pdf/2002.04745 + On Kion open dataset didn't change metrics, even got a bit worse. + + Parameters + ---------- + n_blocks: int + Number of transformer blocks. + n_factors: int + Latent embeddings size. + n_heads: int + Number of attention heads. + dropout_rate: float + Probability of a hidden unit to be zeroed. """ def __init__( @@ -314,7 +395,23 @@ def __init__( ) def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: - """TODO""" + """ + Forward pass through transformer blocks. + + Parameters + ---------- + seqs: torch.Tensor + User sequences of item embeddings. + timeline_mask: torch.Tensor + Mask to zero out padding elements. + attn_mask: torch.Tensor + Forbid model to use future interactions. + + Returns + ------- + torch.Tensor + User sequences passed through transformer layers. + """ for i in range(self.n_blocks): mha_input = self.mha_layer_norm[i](seqs) mha_output, _ = self.multi_head_attn[i]( @@ -331,14 +428,39 @@ def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: to class LearnableInversePositionalEncoding(PositionalEncodingBase): - """TODO""" + """ + Class to introduce learnable positional embeddings. + + Parameters + ---------- + use_pos_emb: bool + If ``True``, adds learnable positional encoding to session item embeddings. + session_max_len: int + Maximum length of user sequence. + n_factors: int + Latent embeddings size. 
+ """ def __init__(self, use_pos_emb: bool, session_max_len: int, n_factors: int): super().__init__() self.pos_emb = torch.nn.Embedding(session_max_len, n_factors) if use_pos_emb else None def forward(self, sessions: torch.Tensor, timeline_mask: torch.Tensor) -> torch.Tensor: - """TODO""" + """ + Forward pass to add learnable positional encoding to sessions and mask padding elements. + + Parameters + ---------- + sessions: torch.Tensor + User sessions in the form of sequences of items ids. + timeline_mask: torch.Tensor + Mask to zero out padding elements. + + Returns + ------- + torch.Tensor + Encoded user sessions with added positional encoding if `use_pos_emb` is ``True``. + """ batch_size, session_max_len, _ = sessions.shape if self.pos_emb is not None: @@ -360,7 +482,32 @@ def forward(self, sessions: torch.Tensor, timeline_mask: torch.Tensor) -> torch. class TransformerBasedSessionEncoder(torch.nn.Module): - """TODO""" + """ + Torch model for recommendations. + + Parameters + ---------- + n_blocks: int + Number of transformer blocks. + n_factors: int + Latent embeddings size. + n_heads: int + Number of attention heads. + session_max_len: int + Maximum length of user sequence. + dropout_rate: float + Probability of a hidden unit to be zeroed. + use_pos_emb: bool, default True + If ``True``, adds learnable positional encoding to session item embeddings. + use_causal_attn: bool, default True + If ``True``, causal mask is used in multi-head self-attention. + transformer_layers_type: Type(TransformerLayersBase), default `SasRecTransformerLayers` + Type of transformer layers architecture. + item_net_type: Type(ItemNetBase), default IdEmbeddingsItemNet + Type of network returning item embeddings. + pos_encoding_type: Type(PositionalEncodingBase), default LearnableInversePositionalEncoding + Type of positional encoding. + """ def __init__( self, @@ -393,19 +540,35 @@ def __init__( self.item_net_block_types = item_net_block_types def construct_item_net(self, dataset: Dataset) -> None: - """TODO""" + """ + Construct network for item embeddings from dataset. + + Parameters + ---------- + dataset: Dataset + RecTools dataset with user-item interactions. + """ self.item_model = ItemNetConstructor.from_dataset( dataset, self.n_factors, self.dropout_rate, self.item_net_block_types ) def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: """ - Pass user history through item embeddings and transformer blocks. + Pass user history through item embeddings. + Add positional encoding. + Pass history through transformer blocks. + + Parameters + ---------- + sessions: torch.Tensor + User sessions in the form of sequences of items ids. + item_embs: torch.Tensor + Item embeddings. Returns ------- - torch.Tensor. [batch_size, session_max_len, n_factors] - + torch.Tensor. [batch_size, session_max_len, n_factors] + Encoded session embeddings. """ session_max_len = sessions.shape[1] attn_mask = None @@ -424,7 +587,22 @@ def forward( self, sessions: torch.Tensor, # [batch_size, session_max_len] ) -> torch.Tensor: - """TODO""" + """ + Forward pass to get logits. + Get item embeddings. + Pass user sessions through transformer blocks. + Calculate logits. + + Parameters + ---------- + sessions: torch.Tensor + User sessions in the form of sequences of items ids. + + Returns + ------- + torch.Tensor + Logits. 
+ """ item_embs = self.item_model.get_all_embeddings() # [n_items + 1, n_factors] session_embs = self.encode_sessions(sessions, item_embs) # [batch_size, session_max_len, n_factors] logits = session_embs @ item_embs.T # [batch_size, session_max_len, n_items + 1] @@ -435,7 +613,16 @@ def forward( class SequenceDataset(TorchDataset): - """TODO""" + """ + Dataset for sequential data. + + Parameters + ---------- + sessions: List[List[int]] + User sessions in the form of sequences of items ids. + weights: List[List[float]] + Weight of each interaction from the session. + """ def __init__(self, sessions: List[List[int]], weights: List[List[float]]): self.sessions = sessions @@ -454,7 +641,15 @@ def from_interactions( cls, interactions: pd.DataFrame, ) -> "SequenceDataset": - """TODO""" + """ + Group interactions by user. + Construct SequenceDataset from grouped interactions. + + Parameters + ---------- + interactions: pd.DataFrame + User-item interactions. + """ sessions = ( interactions.sort_values(Columns.Datetime) .groupby(Columns.User, sort=True)[[Columns.Item, Columns.Weight]] @@ -469,7 +664,25 @@ def from_interactions( class SessionEncoderDataPreparatorBase: - """Base class for data preparator. Used only for type hinting.""" + """ + Base class for data preparator. To change train/recommend dataset processing, train/recommend dataloaders inherit + from this class and pass your custom data preparator to your model parameters. + + Parameters + ---------- + session_max_len: int + Maximum length of user sequence. + batch_size: int + How many samples per batch to load. + dataloader_num_workers: int + Number of loader worker processes. + item_extra_tokens: Sequence(Hashable), default (PADDING_VALUE,) = ("PAD",) + Which element to use for sequence padding. + shuffle_train: bool, default True + If ``True``, reshuffles data at each epoch. + train_min_user_interactions: int, default 2 + Minimum length of user sequence. Cannot be less than 2. 
+ """ def __init__( self, @@ -490,44 +703,58 @@ def __init__( # TODO: add SequenceDatasetType for fit and recommend def get_known_items_sorted_internal_ids(self) -> np.ndarray: - """TODO""" + """Return internal item ids from processed dataset in sorted order.""" return self.item_id_map.get_sorted_internal()[self.n_item_extra_tokens :] def get_known_item_ids(self) -> np.ndarray: - """TODO""" + """Return external item ids from processed dataset in sorted order.""" return self.item_id_map.get_external_sorted_by_internal()[self.n_item_extra_tokens :] @property def n_item_extra_tokens(self) -> int: - """TODO""" + """Return number of padding elements""" return len(self.item_extra_tokens) def process_dataset_train(self, dataset: Dataset) -> Dataset: - """TODO""" + """Process train dataset.""" raise NotImplementedError() def get_dataloader_train(self, processed_dataset: Dataset) -> DataLoader: - """TODO""" + """Return train dataloader.""" raise NotImplementedError() def get_dataloader_recommend(self, dataset: Dataset) -> DataLoader: - """TODO""" + """Return recommend dataloader.""" raise NotImplementedError() def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset: - """TODO""" + """Process dataset for u2i recommendations.""" raise NotImplementedError() def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: - """TODO""" + """Process dataset for i2i recommendations.""" raise NotImplementedError() class SASRecDataPreparator(SessionEncoderDataPreparatorBase): - """TODO""" + """Class to process train/recommend datasets and prepare train/recommend dataloaders.""" def process_dataset_train(self, dataset: Dataset) -> Dataset: - """TODO""" + """ + Remove sequences shorter than ``train_min_user_interactions``. + Leave ``session_max_len`` + 1 most recent interactions. + Create new RecTools dataset with processed interactions. + + Parameters + ---------- + dataset: Dataset + RecTools dataset with train interactions. + + Returns + ------- + Dataset + RecTools dataset with processed interactions. + """ interactions = dataset.get_raw_interactions() # Filter interactions @@ -577,8 +804,8 @@ def _collate_fn_train( batch: List[Tuple[List[int], List[float]]], ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.FloatTensor]: """ - Truncate each session from right to keep (session_max_len+1) last items. - Do left padding until (session_max_len+1) is reached. + Truncate each session from right to keep (``session_max_len`` + 1) last items. + Do left padding until (``session_max_len`` + 1) is reached. Split to `x`, `y`, and `yw`. """ batch_size = len(batch) @@ -592,7 +819,19 @@ def _collate_fn_train( return torch.LongTensor(x), torch.LongTensor(y), torch.FloatTensor(yw) def get_dataloader_train(self, processed_dataset: Dataset) -> DataLoader: - """TODO""" + """ + Construct train dataloader from processed dataset. + + Parameters + ---------- + processed_dataset: Dataset + RecTools dataset prepared for training. + + Returns + ------- + DataLoader + Train dataloader. + """ sequence_dataset = SequenceDataset.from_interactions(processed_dataset.interactions.df) train_dataloader = DataLoader( sequence_dataset, @@ -605,13 +844,26 @@ def get_dataloader_train(self, processed_dataset: Dataset) -> DataLoader: def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset: """ + Process dataset for u2i recommendations. Filter out interactions and adapt id maps. 
- Final dataset will consist only of model known items during fit and only of required - (and supported) target users for recommendations. All users beyond target users for recommendations are dropped. All target users that do not have at least one known item in interactions are dropped. - Final user_id_map is an enumerated list of supported (filtered) target users - Final item_id_map is model item_id_map constructed during training + + Parameters + ---------- + dataset: Dataset + RecTools dataset. + users: ExternalIds + Array of external user ids to recommend for. + + Returns + ------- + Dataset + Processed RecTools dataset. + Final dataset will consist only of model known items during fit and only of required + (and supported) target users for recommendations. + Final user_id_map is an enumerated list of supported (filtered) target users. + Final item_id_map is model item_id_map constructed during training. """ # Filter interactions in dataset internal ids interactions = dataset.interactions.df @@ -639,10 +891,21 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: """ + Process dataset for i2i recommendations. Filter out interactions and adapt id maps. - Final dataset will consist only of model known items during fit. - Final user_id_map is the same as dataset original - Final item_id_map is model item_id_map constructed during training + + Parameters + ---------- + dataset: Dataset + RecTools dataset. + + Returns + ------- + Dataset + Processed RecTools dataset. + Final dataset will consist only of model known items during fit. + Final user_id_map is the same as dataset original. + Final item_id_map is model item_id_map constructed during training. """ # TODO: optimize by filtering in internal ids # TODO: For now features are dropped because model doesn't support them @@ -660,7 +923,19 @@ def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> t return torch.LongTensor(x) def get_dataloader_recommend(self, dataset: Dataset) -> DataLoader: - """TODO""" + """ + Construct recommend dataloader from processed dataset. + + Parameters + ---------- + processed_dataset: Dataset + RecTools dataset. + + Returns + ------- + DataLoader + Recommend dataloader. + """ sequence_dataset = SequenceDataset.from_interactions(dataset.interactions.df) recommend_dataloader = DataLoader( sequence_dataset, @@ -676,7 +951,21 @@ def get_dataloader_recommend(self, dataset: Dataset) -> DataLoader: class SessionEncoderLightningModuleBase(LightningModule): - """Base class for lightning module. Used only for type hinting.""" + """ + Base class for lightning module. To change train procedure inherit + from this class and pass your custom LightningModule to your model parameters. + + Parameters + ---------- + torch_model: TransformerBasedSessionEncoder + Torch model to make recommendations. + lr: float + Learning rate. + loss: str, default "softmax" + Loss function. + adam_betas: Tuple[float, float], default (0.9, 0.98) + Coefficients for running averages of gradient and its square. 
+ """ def __init__( self, @@ -690,9 +979,10 @@ def __init__( self.loss = loss self.torch_model = torch_model self.adam_betas = adam_betas + self.item_embs: torch.Tensor def configure_optimizers(self) -> torch.optim.Adam: - """TODO""" + """Choose what optimizers and learning-rate schedulers to use in optimization""" optimizer = torch.optim.Adam(self.torch_model.parameters(), lr=self.lr, betas=self.adam_betas) return optimizer @@ -700,23 +990,38 @@ def forward( self, batch: torch.Tensor, ) -> torch.Tensor: - """TODO""" + """Forward pass. Propagate the batch through torch_model.""" return self.torch_model(batch) def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: - """TODO""" + """Training step.""" raise NotImplementedError() class SessionEncoderLightningModule(SessionEncoderLightningModuleBase): - """TODO""" + """Lightning module to train SASRec model.""" def on_train_start(self) -> None: - """TODO""" + """Initialize parameters with values from Xavier normal distribution.""" self._xavier_normal_init() def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: - """TODO""" + """ + Training step. + Compute logits by propagating torch network. + Compute loss. + + Parameters + ---------- + batch: torch.Tensor + Batch containing user interaction sequences, target interactions, interaction weights. + batch_idx: int + Index of a batch. + + Returns + ------- + Loss. + """ x, y, w = batch logits = self.forward(x) # [batch_size, session_max_len, n_items + 1] if self.loss == "softmax": @@ -740,8 +1045,20 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: return loss raise ValueError(f"loss {loss} is not supported") + def on_train_end(self) -> None: + """Save item embeddings""" + self.eval() + self.item_embs = self.torch_model.item_model.get_all_embeddings() + + def predict_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: + """ + Prediction step. + Encode user sessions. + """ + encoded_sessions = self.torch_model.encode_sessions(batch, self.item_embs)[:, -1, :] + return encoded_sessions + def _xavier_normal_init(self) -> None: - """TODO""" for _, param in self.torch_model.named_parameters(): try: torch.nn.init.xavier_normal_(param.data) @@ -753,7 +1070,52 @@ def _xavier_normal_init(self) -> None: class SASRecModel(ModelBase): - """TODO""" + """ + SASRec model for i2i and u2i recommendations. + + n_blocks: int, default 1 + Number of transformer blocks. + n_heads: int, default 1 + Number of attention heads. + n_factors: int, default 128 + Latent embeddings size. + use_pos_emb: bool, default ``True`` + If ``True``, adds learnable positional encoding to session item embeddings. + dropout_rate: float, default 0.2 + Probability of a hidden unit to be zeroed. + session_max_len: int, default 32 + Maximum length of user sequence. + dataloader_num_workers: int, default 0 + Number of loader worker processes. + batch_size: int, default 128 + How many samples per batch to load. + loss: str, default "softmax" + Loss function. + lr: float, default 0.01 + Learning rate. + epochs: int, default 3 + Number of training epochs. + verbose: int, default 0 + Verbosity level. + deterministic: bool, default ``False`` + If ``True``, sets deterministic algorithms for PyTorch operations. + Use `pytorch_lightning.seed_everything` together with this parameter to fix the random state. + cpu_n_threads: int, default 0 + Number of threads to use in ranker. + trainer: Optional(Trainer), default None + Which trainer to use for training. 
+ If trainer is None, default pytorch_lightning Trainer is created. + item_net_block_types: Sequence(Type(ItemNetBase)), default `(IdEmbeddingsItemNet, CatFeaturesItemNet)` + Types of networks returning item embeddings. + pos_encoding_type: Type(PositionalEncodingBase), default `LearnableInversePositionalEncoding` + Type of positional encoding. + transformer_layers_type: Type(TransformerLayersBase), default `SASRecTransformerLayers` + Type of transformer layers architecture. + data_preparator_type: Type(SessionEncoderDataPreparatorBase), default `SASRecDataPreparator` + Type of data preparator used for dataset processing and dataloader creation. + lightning_module_type: Type(SessionEncoderLightningModuleBase), default `SessionEncoderLightningModule` + Type of lightning module defining training procedure. + """ def __init__( # pylint: disable=too-many-arguments self, @@ -770,7 +1132,6 @@ def __init__( # pylint: disable=too-many-arguments epochs: int = 3, verbose: int = 0, deterministic: bool = False, - device: str = "cuda:1", cpu_n_threads: int = 0, trainer: tp.Optional[Trainer] = None, item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), @@ -780,9 +1141,7 @@ def __init__( # pylint: disable=too-many-arguments lightning_module_type: tp.Type[SessionEncoderLightningModuleBase] = SessionEncoderLightningModule, ): super().__init__(verbose=verbose) - self.device = torch.device(device) self.n_threads = cpu_n_threads - self.torch_model: TransformerBasedSessionEncoder self._torch_model = TransformerBasedSessionEncoder( n_blocks=n_blocks, n_factors=n_factors, @@ -795,6 +1154,7 @@ def __init__( # pylint: disable=too-many-arguments item_net_block_types=item_net_block_types, pos_encoding_type=pos_encoding_type, ) + self.lightning_model: SessionEncoderLightningModuleBase self.lightning_module_type = lightning_module_type self.trainer: Trainer if trainer is None: @@ -821,12 +1181,13 @@ def _fit( processed_dataset = self.data_preparator.process_dataset_train(dataset) train_dataloader = self.data_preparator.get_dataloader_train(processed_dataset) - self.torch_model = deepcopy(self._torch_model) # TODO: check that it works - self.torch_model.construct_item_net(processed_dataset) + torch_model = deepcopy(self._torch_model) # TODO: check that it works + torch_model.construct_item_net(processed_dataset) + + self.lightning_model = self.lightning_module_type(torch_model, self.lr, self.loss) - lightning_model = self.lightning_module_type(self.torch_model, self.lr, self.loss) self.trainer = deepcopy(self._trainer) - self.trainer.fit(lightning_model, train_dataloader) + self.trainer.fit(self.lightning_model, train_dataloader) def _custom_transform_dataset_u2i( self, dataset: Dataset, users: ExternalIds, on_unsupported_targets: ErrorBehaviour @@ -849,47 +1210,38 @@ def _recommend_u2i( if sorted_item_ids_to_recommend is None: # TODO: move to _get_sorted_item_ids_to_recommend sorted_item_ids_to_recommend = self.data_preparator.get_known_items_sorted_internal_ids() # model internal - self.torch_model = self.torch_model.eval() - self.torch_model.to(self.device) - - # Dataset has already been filtered and adapted to known item_id_map recommend_dataloader = self.data_preparator.get_dataloader_recommend(dataset) - - session_embs = [] - item_embs = self.torch_model.item_model.get_all_embeddings() # [n_items + 1, n_factors] - with torch.no_grad(): - for x_batch in tqdm.tqdm(recommend_dataloader): # TODO: from tqdm.auto import tqdm. 
Also check `verbose`` - x_batch = x_batch.to(self.device) # [batch_size, session_max_len] - encoded = self.torch_model.encode_sessions(x_batch, item_embs)[:, -1, :] # [batch_size, n_factors] - encoded = encoded.detach().cpu().numpy() - session_embs.append(encoded) - - user_embs = np.concatenate(session_embs, axis=0) - user_embs = user_embs[user_ids] - item_embs_np = item_embs.detach().cpu().numpy() - - ranker = ImplicitRanker( - self.u2i_dist, - user_embs, # [n_rec_users, n_factors] - item_embs_np, # [n_items + 1, n_factors] - ) - if filter_viewed: - user_items = dataset.get_user_item_matrix(include_weights=False) - ui_csr_for_filter = user_items[user_ids] + session_embs = self.trainer.predict(model=self.lightning_model, dataloaders=recommend_dataloader) + if session_embs is not None: + user_embs = np.concatenate(session_embs, axis=0) + user_embs = user_embs[user_ids] + item_embs = self.lightning_model.item_embs + item_embs_np = item_embs.detach().cpu().numpy() + + ranker = ImplicitRanker( + self.u2i_dist, + user_embs, # [n_rec_users, n_factors] + item_embs_np, # [n_items + 1, n_factors] + ) + if filter_viewed: + user_items = dataset.get_user_item_matrix(include_weights=False) + ui_csr_for_filter = user_items[user_ids] + else: + ui_csr_for_filter = None + + # TODO: When filter_viewed is not needed and user has GPU, torch DOT and topk should be faster + + user_ids_indices, all_reco_ids, all_scores = ranker.rank( + subject_ids=np.arange(user_embs.shape[0]), # n_rec_users + k=k, + filter_pairs_csr=ui_csr_for_filter, # [n_rec_users x n_items + 1] + sorted_object_whitelist=sorted_item_ids_to_recommend, # model_internal + num_threads=self.n_threads, + ) + all_target_ids = user_ids[user_ids_indices] else: - ui_csr_for_filter = None - - # TODO: When filter_viewed is not needed and user has GPU, torch DOT and topk should be faster - - user_ids_indices, all_reco_ids, all_scores = ranker.rank( - subject_ids=np.arange(user_embs.shape[0]), # n_rec_users - k=k, - filter_pairs_csr=ui_csr_for_filter, # [n_rec_users x n_items + 1] - sorted_object_whitelist=sorted_item_ids_to_recommend, # model_internal - num_threads=self.n_threads, - ) - all_target_ids = user_ids[user_ids_indices] - + explanation = """Received empty recommendations. Used for type-annotation""" + raise ValueError(explanation) return all_target_ids, all_reco_ids, all_scores # n_rec_users, model_internal, scores def _recommend_i2i( @@ -902,9 +1254,7 @@ def _recommend_i2i( if sorted_item_ids_to_recommend is None: sorted_item_ids_to_recommend = self.data_preparator.get_known_items_sorted_internal_ids() - self.torch_model = self.torch_model.eval() - item_embs = self.torch_model.item_model.get_all_embeddings().detach().cpu().numpy() # [n_items + 1, n_factors] - + item_embs = self.lightning_model.item_embs.detach().cpu().numpy() # TODO: i2i reco do not need filtering viewed. And user most of the times has GPU # Should we use torch dot and topk? 
Should be faster @@ -922,6 +1272,6 @@ def _recommend_i2i( ) @property - def lightning_model(self) -> SessionEncoderLightningModule: - """TODO""" - return self.trainer.lightning_module + def torch_model(self) -> TransformerBasedSessionEncoder: + """Return torch model.""" + return self.lightning_model.torch_model diff --git a/tests/models/test_sasrec.py b/tests/models/test_sasrec.py new file mode 100644 index 00000000..a7af7644 --- /dev/null +++ b/tests/models/test_sasrec.py @@ -0,0 +1,584 @@ +import typing as tp +from typing import List + +import numpy as np +import pandas as pd +import pytest +import torch +from pytorch_lightning import Trainer, seed_everything + +from rectools.columns import Columns +from rectools.dataset import Dataset, IdMap, Interactions +from rectools.models.sasrec import IdEmbeddingsItemNet, SASRecDataPreparator, SASRecModel, SequenceDataset +from tests.models.utils import assert_second_fit_refits_model +from tests.testing_utils import assert_id_map_equal, assert_interactions_set_equal + + +class TestSASRecModel: + def setup_method(self) -> None: + self._seed_everything() + + def _seed_everything(self) -> None: + torch.use_deterministic_algorithms(True) + seed_everything(32, workers=True) + + @pytest.fixture + def interactions_df(self) -> pd.DataFrame: + interactions_df = pd.DataFrame( + [ + [10, 13, 1, "2021-11-30"], + [10, 11, 1, "2021-11-29"], + [10, 12, 1, "2021-11-29"], + [30, 11, 1, "2021-11-27"], + [30, 12, 2, "2021-11-26"], + [30, 15, 1, "2021-11-25"], + [40, 11, 1, "2021-11-25"], + [40, 17, 1, "2021-11-26"], + [50, 16, 1, "2021-11-25"], + [10, 14, 1, "2021-11-28"], + [10, 16, 1, "2021-11-27"], + [20, 13, 9, "2021-11-28"], + ], + columns=Columns.Interactions, + ) + return interactions_df + + @pytest.fixture + def dataset(self, interactions_df: pd.DataFrame) -> Dataset: + return Dataset.construct(interactions_df) + + @pytest.fixture + def dataset_hot_users_items(self, interactions_df: pd.DataFrame) -> Dataset: + return Dataset.construct(interactions_df[:-4]) + + @pytest.fixture + def trainer(self) -> Trainer: + return Trainer( + max_epochs=2, + min_epochs=2, + deterministic=True, + accelerator="cpu", + ) + + @pytest.mark.parametrize( + "filter_viewed,expected", + ( + ( + True, + pd.DataFrame( + { + Columns.User: [10, 10, 30, 30, 30, 40, 40, 40], + Columns.Item: [17, 15, 14, 13, 17, 12, 14, 13], + Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3], + } + ), + ), + ( + False, + pd.DataFrame( + { + Columns.User: [10, 10, 10, 30, 30, 30, 40, 40, 40], + Columns.Item: [13, 12, 14, 12, 11, 14, 12, 17, 11], + Columns.Rank: [1, 2, 3, 1, 2, 3, 1, 2, 3], + } + ), + ), + ), + ) + # TODO: tests do not pass for multiple GPUs + def test_u2i(self, dataset: Dataset, trainer: Trainer, filter_viewed: bool, expected: pd.DataFrame) -> None: + model = SASRecModel( + n_factors=32, + n_blocks=2, + session_max_len=3, + lr=0.001, + batch_size=4, + epochs=2, + deterministic=True, + item_net_block_types=(IdEmbeddingsItemNet,), + trainer=trainer, + ) + model.fit(dataset=dataset) + users = np.array([10, 30, 40]) + actual = model.recommend(users=users, dataset=dataset, k=3, filter_viewed=filter_viewed) + pd.testing.assert_frame_equal(actual.drop(columns=Columns.Score), expected) + pd.testing.assert_frame_equal( + actual.sort_values([Columns.User, Columns.Score], ascending=[True, False]).reset_index(drop=True), + actual, + ) + + @pytest.mark.parametrize( + "filter_viewed,expected", + ( + ( + True, + pd.DataFrame( + { + Columns.User: [10, 30, 30, 40], + Columns.Item: [17, 13, 17, 13], + Columns.Rank: [1, 
1, 2, 1], + } + ), + ), + ( + False, + pd.DataFrame( + { + Columns.User: [10, 10, 10, 30, 30, 30, 40, 40, 40], + Columns.Item: [13, 17, 11, 11, 13, 17, 17, 11, 13], + Columns.Rank: [1, 2, 3, 1, 2, 3, 1, 2, 3], + } + ), + ), + ), + ) + def test_with_whitelist( + self, dataset: Dataset, trainer: Trainer, filter_viewed: bool, expected: pd.DataFrame + ) -> None: + model = SASRecModel( + n_factors=32, + n_blocks=2, + session_max_len=3, + lr=0.001, + batch_size=4, + epochs=2, + deterministic=True, + item_net_block_types=(IdEmbeddingsItemNet,), + trainer=trainer, + ) + model.fit(dataset=dataset) + users = np.array([10, 30, 40]) + items_to_recommend = np.array([11, 13, 17]) + actual = model.recommend( + users=users, + dataset=dataset, + k=3, + filter_viewed=filter_viewed, + items_to_recommend=items_to_recommend, + ) + pd.testing.assert_frame_equal(actual.drop(columns=Columns.Score), expected) + pd.testing.assert_frame_equal( + actual.sort_values([Columns.User, Columns.Score], ascending=[True, False]).reset_index(drop=True), + actual, + ) + + @pytest.mark.parametrize( + "filter_itself,whitelist,expected", + ( + ( + False, + None, + pd.DataFrame( + { + Columns.TargetItem: [12, 12, 12, 14, 14, 14, 17, 17, 17], + Columns.Item: [12, 17, 11, 14, 11, 13, 17, 12, 14], + Columns.Rank: [1, 2, 3, 1, 2, 3, 1, 2, 3], + } + ), + ), + ( + True, + None, + pd.DataFrame( + { + Columns.TargetItem: [12, 12, 12, 14, 14, 14, 17, 17, 17], + Columns.Item: [17, 11, 14, 11, 13, 17, 12, 14, 11], + Columns.Rank: [1, 2, 3, 1, 2, 3, 1, 2, 3], + } + ), + ), + ( + True, + np.array([15, 13, 14]), + pd.DataFrame( + { + Columns.TargetItem: [12, 12, 12, 14, 14, 17, 17, 17], + Columns.Item: [14, 13, 15, 13, 15, 14, 15, 13], + Columns.Rank: [1, 2, 3, 1, 2, 1, 2, 3], + } + ), + ), + ), + ) + def test_i2i( + self, + dataset: Dataset, + trainer: Trainer, + filter_itself: bool, + whitelist: tp.Optional[np.ndarray], + expected: pd.DataFrame, + ) -> None: + model = SASRecModel( + n_factors=32, + n_blocks=2, + session_max_len=3, + lr=0.001, + batch_size=4, + epochs=2, + deterministic=True, + item_net_block_types=(IdEmbeddingsItemNet,), + trainer=trainer, + ) + model.fit(dataset=dataset) + target_items = np.array([12, 14, 17]) + actual = model.recommend_to_items( + target_items=target_items, + dataset=dataset, + k=3, + filter_itself=filter_itself, + items_to_recommend=whitelist, + ) + pd.testing.assert_frame_equal(actual.drop(columns=Columns.Score), expected) + pd.testing.assert_frame_equal( + actual.sort_values([Columns.TargetItem, Columns.Score], ascending=[True, False]).reset_index(drop=True), + actual, + ) + + def test_second_fit_refits_model(self, dataset_hot_users_items: Dataset, trainer: Trainer) -> None: + model = SASRecModel( + n_factors=32, + n_blocks=2, + session_max_len=3, + lr=0.001, + batch_size=4, + deterministic=True, + item_net_block_types=(IdEmbeddingsItemNet,), + trainer=trainer, + ) + assert_second_fit_refits_model(model, dataset_hot_users_items, pre_fit_callback=self._seed_everything) + + @pytest.mark.parametrize( + "filter_viewed,expected", + ( + ( + True, + pd.DataFrame( + { + Columns.User: [20, 20, 20], + Columns.Item: [14, 12, 17], + Columns.Rank: [1, 2, 3], + } + ), + ), + ( + False, + pd.DataFrame( + { + Columns.User: [20, 20, 20], + Columns.Item: [13, 14, 12], + Columns.Rank: [1, 2, 3], + } + ), + ), + ), + ) + def test_recommend_for_cold_user_with_hot_item( + self, dataset: Dataset, trainer: Trainer, filter_viewed: bool, expected: pd.DataFrame + ) -> None: + model = SASRecModel( + n_factors=32, + n_blocks=2, + 
session_max_len=3, + lr=0.001, + batch_size=4, + epochs=2, + deterministic=True, + item_net_block_types=(IdEmbeddingsItemNet,), + trainer=trainer, + ) + model.fit(dataset=dataset) + users = np.array([20]) + actual = model.recommend( + users=users, + dataset=dataset, + k=3, + filter_viewed=filter_viewed, + ) + pd.testing.assert_frame_equal(actual.drop(columns=Columns.Score), expected) + pd.testing.assert_frame_equal( + actual.sort_values([Columns.User, Columns.Score], ascending=[True, False]).reset_index(drop=True), + actual, + ) + + @pytest.mark.parametrize( + "filter_viewed,expected", + ( + ( + True, + pd.DataFrame( + { + Columns.User: [10, 10, 20, 20, 20], + Columns.Item: [17, 15, 14, 12, 17], + Columns.Rank: [1, 2, 1, 2, 3], + } + ), + ), + ( + False, + pd.DataFrame( + { + Columns.User: [10, 10, 10, 20, 20, 20], + Columns.Item: [13, 12, 14, 13, 14, 12], + Columns.Rank: [1, 2, 3, 1, 2, 3], + } + ), + ), + ), + ) + def test_warn_when_hot_user_has_cold_items_in_recommend( + self, dataset: Dataset, trainer: Trainer, filter_viewed: bool, expected: pd.DataFrame + ) -> None: + model = SASRecModel( + n_factors=32, + n_blocks=2, + session_max_len=3, + lr=0.001, + batch_size=4, + epochs=2, + deterministic=True, + item_net_block_types=(IdEmbeddingsItemNet,), + trainer=trainer, + ) + model.fit(dataset=dataset) + users = np.array([10, 20, 50]) + with pytest.warns() as record: + actual = model.recommend( + users=users, + dataset=dataset, + k=3, + filter_viewed=filter_viewed, + on_unsupported_targets="warn", + ) + pd.testing.assert_frame_equal(actual.drop(columns=Columns.Score), expected) + pd.testing.assert_frame_equal( + actual.sort_values([Columns.User, Columns.Score], ascending=[True, False]).reset_index(drop=True), + actual, + ) + assert str(record[0].message) == "1 target users were considered cold because of missing known items" + assert ( + str(record[1].message) + == """ + Model `` doesn't support recommendations for cold users, + but some of given users are cold: they are not in the `dataset.user_id_map` + """ + ) + + +class TestSequenceDataset: + + @pytest.fixture + def interactions_df(self) -> pd.DataFrame: + interactions_df = pd.DataFrame( + [ + [10, 13, 1, "2021-11-30"], + [10, 11, 1, "2021-11-29"], + [10, 12, 4, "2021-11-29"], + [30, 11, 1, "2021-11-27"], + [30, 12, 2, "2021-11-26"], + [30, 15, 1, "2021-11-25"], + [40, 11, 1, "2021-11-25"], + [40, 17, 8, "2021-11-26"], + [50, 16, 1, "2021-11-25"], + [10, 14, 1, "2021-11-28"], + ], + columns=Columns.Interactions, + ) + return interactions_df + + @pytest.mark.parametrize( + "expected_sessions, expected_weights", + (([[14, 11, 12, 13], [15, 12, 11], [11, 17], [16]], [[1, 1, 4, 1], [1, 2, 1], [1, 8], [1]]),), + ) + def test_from_interactions( + self, interactions_df: pd.DataFrame, expected_sessions: List[List[int]], expected_weights: List[List[float]] + ) -> None: + actual = SequenceDataset.from_interactions(interactions_df) + assert len(actual.sessions) == len(expected_sessions) + assert all( + actual_list == expected_list for actual_list, expected_list in zip(actual.sessions, expected_sessions) + ) + assert len(actual.weights) == len(expected_weights) + assert all(actual_list == expected_list for actual_list, expected_list in zip(actual.weights, expected_weights)) + + +class TestSASRecDataPreparator: + + def setup_method(self) -> None: + self._seed_everything() + + def _seed_everything(self) -> None: + torch.use_deterministic_algorithms(True) + seed_everything(32, workers=True) + + @pytest.fixture + def dataset(self) -> Dataset: + 
interactions_df = pd.DataFrame( + [ + [10, 13, 1, "2021-11-30"], + [10, 11, 1, "2021-11-29"], + [10, 12, 1, "2021-11-29"], + [30, 11, 1, "2021-11-27"], + [30, 12, 2, "2021-11-26"], + [30, 15, 1, "2021-11-25"], + [40, 11, 1, "2021-11-25"], + [40, 17, 1, "2021-11-26"], + [50, 16, 1, "2021-11-25"], + [10, 14, 1, "2021-11-28"], + [10, 16, 1, "2021-11-27"], + [20, 13, 9, "2021-11-28"], + ], + columns=Columns.Interactions, + ) + return Dataset.construct(interactions_df) + + @pytest.fixture + def data_preparator(self) -> SASRecDataPreparator: + return SASRecDataPreparator(session_max_len=3, batch_size=4, dataloader_num_workers=0) + + @pytest.mark.parametrize( + "expected_user_id_map, expected_item_id_map, expected_interactions", + ( + ( + IdMap.from_values([30, 40, 10]), + IdMap.from_values(["PAD", 15, 11, 12, 17, 14, 13]), + Interactions( + pd.DataFrame( + [ + [0, 1, 1.0, "2021-11-25"], + [1, 2, 1.0, "2021-11-25"], + [0, 3, 2.0, "2021-11-26"], + [1, 4, 1.0, "2021-11-26"], + [0, 2, 1.0, "2021-11-27"], + [2, 5, 1.0, "2021-11-28"], + [2, 2, 1.0, "2021-11-29"], + [2, 3, 1.0, "2021-11-29"], + [2, 6, 1.0, "2021-11-30"], + ], + columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime], + ), + ), + ), + ), + ) + def test_process_dataset_train( + self, + dataset: Dataset, + data_preparator: SASRecDataPreparator, + expected_interactions: Interactions, + expected_item_id_map: IdMap, + expected_user_id_map: IdMap, + ) -> None: + actual = data_preparator.process_dataset_train(dataset) + assert_id_map_equal(actual.user_id_map, expected_user_id_map) + assert_id_map_equal(actual.item_id_map, expected_item_id_map) + assert_interactions_set_equal(actual.interactions, expected_interactions) + + @pytest.mark.parametrize( + "expected_user_id_map, expected_item_id_map, expected_interactions", + ( + ( + IdMap.from_values([10, 20]), + IdMap.from_values(["PAD", 15, 11, 12, 17, 14, 13]), + Interactions( + pd.DataFrame( + [ + [0, 6, 1.0, "2021-11-30"], + [0, 2, 1.0, "2021-11-29"], + [0, 3, 1.0, "2021-11-29"], + [0, 5, 1.0, "2021-11-28"], + [1, 6, 9.0, "2021-11-28"], + ], + columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime], + ), + ), + ), + ), + ) + def test_transform_dataset_u2i( + self, + dataset: Dataset, + data_preparator: SASRecDataPreparator, + expected_interactions: Interactions, + expected_item_id_map: IdMap, + expected_user_id_map: IdMap, + ) -> None: + data_preparator.process_dataset_train(dataset) + users = [10, 20] + actual = data_preparator.transform_dataset_u2i(dataset, users) + assert_id_map_equal(actual.user_id_map, expected_user_id_map) + assert_id_map_equal(actual.item_id_map, expected_item_id_map) + assert_interactions_set_equal(actual.interactions, expected_interactions) + + @pytest.mark.parametrize( + "expected_user_id_map, expected_item_id_map, expected_interactions", + ( + ( + IdMap.from_values([10, 30, 40, 50, 20]), + IdMap.from_values(["PAD", 15, 11, 12, 17, 14, 13]), + Interactions( + pd.DataFrame( + [ + [0, 6, 1.0, "2021-11-30"], + [0, 2, 1.0, "2021-11-29"], + [0, 3, 1.0, "2021-11-29"], + [1, 2, 1.0, "2021-11-27"], + [1, 3, 2.0, "2021-11-26"], + [1, 1, 1.0, "2021-11-25"], + [2, 2, 1.0, "2021-11-25"], + [2, 4, 1.0, "2021-11-26"], + [0, 5, 1.0, "2021-11-28"], + [4, 6, 9.0, "2021-11-28"], + ], + columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime], + ), + ), + ), + ), + ) + def test_tranform_dataset_i2i( + self, + dataset: Dataset, + data_preparator: SASRecDataPreparator, + expected_interactions: Interactions, + expected_item_id_map: IdMap, + 
expected_user_id_map: IdMap, + ) -> None: + data_preparator.process_dataset_train(dataset) + actual = data_preparator.transform_dataset_i2i(dataset) + assert_id_map_equal(actual.user_id_map, expected_user_id_map) + assert_id_map_equal(actual.item_id_map, expected_item_id_map) + assert_interactions_set_equal(actual.interactions, expected_interactions) + + @pytest.mark.parametrize( + "train_batch", + ( + ( + [ + torch.tensor([[5, 2, 3], [0, 1, 3], [0, 0, 2]]), + torch.tensor([[2, 3, 6], [0, 3, 2], [0, 0, 4]]), + torch.tensor([[1.0, 1.0, 1.0], [0.0, 2.0, 1.0], [0.0, 0.0, 1.0]]), + ] + ), + ), + ) + def test_get_dataloader_train( + self, dataset: Dataset, data_preparator: SASRecDataPreparator, train_batch: List + ) -> None: + dataset = data_preparator.process_dataset_train(dataset) + dataloader = data_preparator.get_dataloader_train(dataset) + actual = next(iter(dataloader)) + for i, value in enumerate(actual): + assert torch.equal(value, train_batch[i]) + + @pytest.mark.parametrize( + "recommend_batch", + ((torch.tensor([[2, 3, 6], [1, 3, 2], [0, 2, 4], [0, 0, 6]])),), + ) + def test_get_dataloader_recommend( + self, dataset: Dataset, data_preparator: SASRecDataPreparator, recommend_batch: torch.Tensor + ) -> None: + data_preparator.process_dataset_train(dataset) + dataset = data_preparator.transform_dataset_i2i(dataset) + dataloader = data_preparator.get_dataloader_recommend(dataset) + actual = next(iter(dataloader)) + assert torch.equal(actual, recommend_batch)
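
The two sketches below illustrate, in isolation, the behavioural ideas behind this patch. They are minimal standalone examples under stated assumptions, not the library's actual code paths: the helper names, the sample ids and the `SESSION_MAX_LEN` constant are chosen only for illustration.

First, the dtype-preservation idea behind the `_adjust_reco_types` change in `base.py`: the external item dtype is captured before the dataset is transformed, and the final recommendation triplet is cast back to it, so that for example string item ids do not end up as plain `object` arrays:

import numpy as np

# External item ids in the original dataset are strings; capture their dtype up front
# (mirrors storing `dataset.item_id_map.external_dtype` before the dataset is transformed).
original_item_dtype = np.array(["i1", "i2", "i3"]).dtype  # '<U2'

def adjust_reco_types(reco, target_type=np.int64, item_type=np.int64):
    """Cast a (targets, items, scores) triplet to stable dtypes (standalone sketch of the patched helper)."""
    target_ids, item_ids, scores = reco
    target_ids = np.asarray(target_ids, dtype=target_type)
    item_ids = np.asarray(item_ids, dtype=item_type)
    scores = np.asarray(scores, dtype=np.float32)
    return target_ids, item_ids, scores

# After mapping internal ids back to external ones the item column may arrive as `object`.
reco = (np.asarray([10, 10, 30]), np.asarray(["i2", "i3", "i1"], dtype=object), np.asarray([0.9, 0.5, 0.7]))
targets, items, scores = adjust_reco_types(reco, np.int64, original_item_dtype)
print(items.dtype, scores.dtype)  # <U2 float32 - item ids keep the original external dtype

Second, the left-padding logic that `_collate_fn_train` documents: keep the (`session_max_len` + 1) most recent interactions of each session, left-pad with zeros, and split into inputs `x`, shifted targets `y` and target weights `yw`. The sample batch is chosen to be consistent with the `test_get_dataloader_train` expectation above; sessions are assumed to have at least two interactions, as the data preparator enforces:

import torch

SESSION_MAX_LEN = 3  # assumption for the sketch

def collate_train(batch):
    """Left-pad sessions and split them into inputs, targets and target weights."""
    batch_size = len(batch)
    x = torch.zeros((batch_size, SESSION_MAX_LEN), dtype=torch.long)
    y = torch.zeros((batch_size, SESSION_MAX_LEN), dtype=torch.long)
    yw = torch.zeros((batch_size, SESSION_MAX_LEN), dtype=torch.float)
    for i, (session, weights) in enumerate(batch):
        session = session[-(SESSION_MAX_LEN + 1):]  # keep the most recent items only
        weights = weights[-(SESSION_MAX_LEN + 1):]
        x[i, -len(session) + 1:] = torch.tensor(session[:-1])  # inputs: all but the last item
        y[i, -len(session) + 1:] = torch.tensor(session[1:])   # targets: items shifted by one position
        yw[i, -len(session) + 1:] = torch.tensor(weights[1:])  # weights of the target interactions
    return x, y, yw

batch = [([5, 2, 3, 6], [1.0, 1.0, 1.0, 1.0]), ([1, 3, 2], [1.0, 2.0, 1.0]), ([2, 4], [1.0, 1.0])]
x, y, yw = collate_train(batch)
# x  -> [[5, 2, 3], [0, 1, 3], [0, 0, 2]]
# y  -> [[2, 3, 6], [0, 3, 2], [0, 0, 4]]
# yw -> [[1., 1., 1.], [0., 2., 1.], [0., 0., 1.]]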