diff --git a/README.md b/README.md index 6d02554e..d4966821 100644 --- a/README.md +++ b/README.md @@ -38,10 +38,14 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St | AutoInt | [CIKM 2019][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/abs/1810.11921) | | ONN | [arxiv 2019][Operation-aware Neural Networks for User Response Prediction](https://arxiv.org/pdf/1904.12579.pdf) | | FiBiNET | [RecSys 2019][FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction](https://arxiv.org/pdf/1905.09433.pdf) | -| IFM | [IJCAI 2019][An Input-aware Factorization Machine for Sparse Prediction](https://www.ijcai.org/Proceedings/2019/0203.pdf) | -| DCN V2 | [arxiv 2020][DCN V2: Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535) | -| DIFM | [IJCAI 2020][A Dual Input-aware Factorization Machine for CTR Prediction](https://www.ijcai.org/Proceedings/2020/0434.pdf) | -| AFN | [AAAI 2020][Adaptive Factorization Network: Learning Adaptive-Order Feature Interactions](https://arxiv.org/pdf/1909.03276) | +| IFM | [IJCAI 2019][An Input-aware Factorization Machine for Sparse Prediction](https://www.ijcai.org/Proceedings/2019/0203.pdf) | +| DCN V2 | [arxiv 2020][DCN V2: Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535) | +| DIFM | [IJCAI 2020][A Dual Input-aware Factorization Machine for CTR Prediction](https://www.ijcai.org/Proceedings/2020/0434.pdf) | +| AFN | [AAAI 2020][Adaptive Factorization Network: Learning Adaptive-Order Feature Interactions](https://arxiv.org/pdf/1909.03276) | +| SharedBottom | [arxiv 2017][An Overview of Multi-Task Learning in Deep Neural Networks](https://arxiv.org/pdf/1706.05098.pdf) | +| ESMM | [SIGIR 2018][Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate](https://dl.acm.org/doi/10.1145/3209978.3210104) | +| MMOE | [KDD 2018][Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts](https://dl.acm.org/doi/abs/10.1145/3219819.3220007) | +| PLE | [RecSys 2020][Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/10.1145/3383313.3412236) | diff --git a/deepctr_torch/models/__init__.py b/deepctr_torch/models/__init__.py index e72de07a..784134b5 100644 --- a/deepctr_torch/models/__init__.py +++ b/deepctr_torch/models/__init__.py @@ -15,4 +15,5 @@ from .ccpm import CCPM from .dien import DIEN from .din import DIN -from .afn import AFN \ No newline at end of file +from .afn import AFN +from .multitask import SharedBottom, ESMM, MMOE, PLE diff --git a/deepctr_torch/models/basemodel.py b/deepctr_torch/models/basemodel.py index 17e57b90..cd36340a 100644 --- a/deepctr_torch/models/basemodel.py +++ b/deepctr_torch/models/basemodel.py @@ -245,7 +245,13 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc y_pred = model(x).squeeze() optim.zero_grad() - loss = loss_func(y_pred, y.squeeze(), reduction='sum') + if isinstance(loss_func, list): + assert len(loss_func) == self.num_tasks,\ + "the length of `loss_func` should be equal with `self.num_tasks`" + loss = sum( + [loss_func[i](y_pred[:, i], y[:, i], reduction='sum') for i in range(self.num_tasks)]) + else: + loss = loss_func(y_pred, y.squeeze(), reduction='sum') reg_loss = self.get_regularization_loss() total_loss = loss + reg_loss + self.aux_loss @@ -456,18 +462,24 @@ def _get_optim(self, optimizer): def _get_loss_func(self, loss): if isinstance(loss, str): - if loss == "binary_crossentropy": - loss_func = F.binary_cross_entropy - elif loss == "mse": - loss_func = F.mse_loss - elif loss == "mae": - loss_func = F.l1_loss - else: - raise NotImplementedError + loss_func = self._get_loss_func_single(loss) + elif isinstance(loss, list): + loss_func = [self._get_loss_func_single(loss_single) for loss_single in loss] else: loss_func = loss return loss_func + def _get_loss_func_single(self, loss): + if loss == "binary_crossentropy": + loss_func = F.binary_cross_entropy + elif loss == "mse": + loss_func = F.mse_loss + elif loss == "mae": + loss_func = F.l1_loss + else: + raise NotImplementedError + return loss_func + def _log_loss(self, y_true, y_pred, eps=1e-7, normalize=True, sample_weight=None, labels=None): # change eps to improve calculation accuracy return log_loss(y_true, diff --git a/deepctr_torch/models/dcnmix.py b/deepctr_torch/models/dcnmix.py index 9b0e97d4..c819a42c 100644 --- a/deepctr_torch/models/dcnmix.py +++ b/deepctr_torch/models/dcnmix.py @@ -71,9 +71,10 @@ def __init__(self, linear_feature_columns, self.add_regularization_weight( filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn) self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_linear) - self.add_regularization_weight(self.crossnet.U_list, l2=l2_reg_cross) - self.add_regularization_weight(self.crossnet.V_list, l2=l2_reg_cross) - self.add_regularization_weight(self.crossnet.C_list, l2=l2_reg_cross) + regularization_modules = [self.crossnet.U_list, self.crossnet.V_list, self.crossnet.C_list] + for module in regularization_modules: + self.add_regularization_weight(module, l2=l2_reg_cross) + self.to(device) def forward(self, X): diff --git a/deepctr_torch/models/dien.py b/deepctr_torch/models/dien.py index 917777f9..c31c0c9d 100644 --- a/deepctr_torch/models/dien.py +++ b/deepctr_torch/models/dien.py @@ -217,7 +217,7 @@ def forward(self, keys, keys_length, neg_keys=None): masked_keys = torch.masked_select(keys, mask.view(-1, 1, 1)).view(-1, max_length, dim) - packed_keys = pack_padded_sequence(masked_keys, lengths=masked_keys_length, batch_first=True, + packed_keys = pack_padded_sequence(masked_keys, lengths=masked_keys_length.cpu(), batch_first=True, enforce_sorted=False) packed_interests, _ = self.gru(packed_keys) interests, _ = pad_packed_sequence(packed_interests, batch_first=True, padding_value=0.0, @@ -353,7 +353,7 @@ def forward(self, query, keys, keys_length, mask=None): query = torch.masked_select(query, mask.view(-1, 1)).view(-1, dim).unsqueeze(1) if self.gru_type == 'GRU': - packed_keys = pack_padded_sequence(keys, lengths=keys_length, batch_first=True, enforce_sorted=False) + packed_keys = pack_padded_sequence(keys, lengths=keys_length.cpu(), batch_first=True, enforce_sorted=False) packed_interests, _ = self.interest_evolution(packed_keys) interests, _ = pad_packed_sequence(packed_interests, batch_first=True, padding_value=0.0, total_length=max_length) @@ -362,15 +362,15 @@ def forward(self, query, keys, keys_length, mask=None): elif self.gru_type == 'AIGRU': att_scores = self.attention(query, keys, keys_length.unsqueeze(1)) # [b, 1, T] interests = keys * att_scores.transpose(1, 2) # [b, T, H] - packed_interests = pack_padded_sequence(interests, lengths=keys_length, batch_first=True, + packed_interests = pack_padded_sequence(interests, lengths=keys_length.cpu(), batch_first=True, enforce_sorted=False) _, outputs = self.interest_evolution(packed_interests) outputs = outputs.squeeze(0) # [b, H] elif self.gru_type == 'AGRU' or self.gru_type == 'AUGRU': att_scores = self.attention(query, keys, keys_length.unsqueeze(1)).squeeze(1) # [b, T] - packed_interests = pack_padded_sequence(keys, lengths=keys_length, batch_first=True, + packed_interests = pack_padded_sequence(keys, lengths=keys_length.cpu(), batch_first=True, enforce_sorted=False) - packed_scores = pack_padded_sequence(att_scores, lengths=keys_length, batch_first=True, + packed_scores = pack_padded_sequence(att_scores, lengths=keys_length.cpu(), batch_first=True, enforce_sorted=False) outputs = self.interest_evolution(packed_interests, packed_scores) outputs, _ = pad_packed_sequence(outputs, batch_first=True, padding_value=0.0, total_length=max_length) diff --git a/deepctr_torch/models/multitask/__init__.py b/deepctr_torch/models/multitask/__init__.py new file mode 100644 index 00000000..55d7eb00 --- /dev/null +++ b/deepctr_torch/models/multitask/__init__.py @@ -0,0 +1,4 @@ +from .sharedbottom import SharedBottom +from .esmm import ESMM +from .mmoe import MMOE +from .ple import PLE diff --git a/deepctr_torch/models/multitask/esmm.py b/deepctr_torch/models/multitask/esmm.py new file mode 100644 index 00000000..4a0d2fe2 --- /dev/null +++ b/deepctr_torch/models/multitask/esmm.py @@ -0,0 +1,94 @@ +# -*- coding:utf-8 -*- +""" +Author: + zanshuxun, zanshuxun@aliyun.com + +Reference: + [1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach for estimating post-click conversion rate[C]//The 41st International ACM SIGIR Conference on Research & Development in Information Retrieval. 2018.(https://dl.acm.org/doi/10.1145/3209978.3210104) +""" +import torch +import torch.nn as nn + +from ..basemodel import BaseModel +from ...inputs import combined_dnn_input +from ...layers import DNN + + +class ESMM(BaseModel): + """Instantiates the Entire Space Multi-Task Model architecture. + + :param dnn_feature_columns: An iterable containing all the features used by deep part of the model. + :param tower_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of task-specific DNN. + :param l2_reg_linear: float, L2 regularizer strength applied to linear part. + :param l2_reg_embedding: float, L2 regularizer strength applied to embedding vector. + :param l2_reg_dnn: float, L2 regularizer strength applied to DNN. + :param init_std: float, to use as the initialize std of embedding vector. + :param seed: integer, to use as random seed. + :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. + :param dnn_activation: Activation function to use in DNN. + :param dnn_use_bn: bool, Whether use BatchNormalization before activation or not in DNN. + :param task_types: list of str, indicating the loss of each tasks, ``"binary"`` for binary logloss or ``"regression"`` for regression loss. e.g. ['binary', 'regression']. + :param task_names: list of str, indicating the predict target of each tasks. + :param device: str, ``"cpu"`` or ``"cuda:0"``. + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. + + :return: A PyTorch model instance. + """ + + def __init__(self, dnn_feature_columns, tower_dnn_hidden_units=(256, 128), + l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024, + dnn_dropout=0, dnn_activation='relu', dnn_use_bn=False, task_types=('binary', 'binary'), + task_names=('ctr', 'ctcvr'), device='cpu', gpus=None): + super(ESMM, self).__init__(linear_feature_columns=[], dnn_feature_columns=dnn_feature_columns, + l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding, init_std=init_std, + seed=seed, task='binary', device=device, gpus=gpus) + self.num_tasks = len(task_names) + if self.num_tasks != 2: + raise ValueError("the length of task_names must be equal to 2") + if len(dnn_feature_columns) == 0: + raise ValueError("dnn_feature_columns is null!") + if len(task_types) != self.num_tasks: + raise ValueError("num_tasks must be equal to the length of task_types") + + for task_type in task_types: + if task_type != 'binary': + raise ValueError("task must be binary in ESMM, {} is illegal".format(task_type)) + + input_dim = self.compute_input_dim(dnn_feature_columns) + + self.ctr_dnn = DNN(input_dim, tower_dnn_hidden_units, activation=dnn_activation, + dropout_rate=dnn_dropout, use_bn=dnn_use_bn, + init_std=init_std, device=device) + self.cvr_dnn = DNN(input_dim, tower_dnn_hidden_units, activation=dnn_activation, + dropout_rate=dnn_dropout, use_bn=dnn_use_bn, + init_std=init_std, device=device) + + self.ctr_dnn_final_layer = nn.Linear(tower_dnn_hidden_units[-1], 1, bias=False) + self.cvr_dnn_final_layer = nn.Linear(tower_dnn_hidden_units[-1], 1, bias=False) + + self.add_regularization_weight( + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.ctr_dnn.named_parameters()), l2=l2_reg_dnn) + self.add_regularization_weight( + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.cvr_dnn.named_parameters()), l2=l2_reg_dnn) + self.add_regularization_weight(self.ctr_dnn_final_layer.weight, l2=l2_reg_dnn) + self.add_regularization_weight(self.cvr_dnn_final_layer.weight, l2=l2_reg_dnn) + self.to(device) + + def forward(self, X): + sparse_embedding_list, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns, + self.embedding_dict) + dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list) + + ctr_output = self.ctr_dnn(dnn_input) + cvr_output = self.cvr_dnn(dnn_input) + + ctr_logit = self.ctr_dnn_final_layer(ctr_output) + cvr_logit = self.cvr_dnn_final_layer(cvr_output) + + ctr_pred = self.out(ctr_logit) + cvr_pred = self.out(cvr_logit) + + ctcvr_pred = ctr_pred * cvr_pred # CTCVR = CTR * CVR + + task_outs = torch.cat([ctr_pred, ctcvr_pred], -1) + return task_outs diff --git a/deepctr_torch/models/multitask/mmoe.py b/deepctr_torch/models/multitask/mmoe.py new file mode 100644 index 00000000..c0401eb7 --- /dev/null +++ b/deepctr_torch/models/multitask/mmoe.py @@ -0,0 +1,143 @@ +# -*- coding:utf-8 -*- +""" +Author: + zanshuxun, zanshuxun@aliyun.com + +Reference: + [1] Jiaqi Ma, Zhe Zhao, Xinyang Yi, et al. Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts[C] (https://dl.acm.org/doi/10.1145/3219819.3220007) +""" +import torch +import torch.nn as nn + +from ..basemodel import BaseModel +from ...inputs import combined_dnn_input +from ...layers import DNN, PredictionLayer + + +class MMOE(BaseModel): + """Instantiates the Multi-gate Mixture-of-Experts architecture. + + :param dnn_feature_columns: An iterable containing all the features used by deep part of the model. + :param num_experts: integer, number of experts. + :param expert_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of expert DNN. + :param gate_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of gate DNN. + :param tower_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of task-specific DNN. + :param l2_reg_linear: float, L2 regularizer strength applied to linear part. + :param l2_reg_embedding: float, L2 regularizer strength applied to embedding vector. + :param l2_reg_dnn: float, L2 regularizer strength applied to DNN. + :param init_std: float, to use as the initialize std of embedding vector. + :param seed: integer, to use as random seed. + :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. + :param dnn_activation: Activation function to use in DNN. + :param dnn_use_bn: bool, Whether use BatchNormalization before activation or not in DNN. + :param task_types: list of str, indicating the loss of each tasks, ``"binary"`` for binary logloss, ``"regression"`` for regression loss. e.g. ['binary', 'regression']. + :param task_names: list of str, indicating the predict target of each tasks. + :param device: str, ``"cpu"`` or ``"cuda:0"``. + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. + + :return: A PyTorch model instance. + """ + + def __init__(self, dnn_feature_columns, num_experts=3, expert_dnn_hidden_units=(256, 128), + gate_dnn_hidden_units=(64,), tower_dnn_hidden_units=(64,), l2_reg_linear=0.00001, + l2_reg_embedding=0.00001, l2_reg_dnn=0, + init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation='relu', dnn_use_bn=False, + task_types=('binary', 'binary'), task_names=('ctr', 'ctcvr'), device='cpu', gpus=None): + super(MMOE, self).__init__(linear_feature_columns=[], dnn_feature_columns=dnn_feature_columns, + l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding, init_std=init_std, + seed=seed, device=device, gpus=gpus) + self.num_tasks = len(task_names) + if self.num_tasks <= 1: + raise ValueError("num_tasks must be greater than 1") + if num_experts <= 1: + raise ValueError("num_experts must be greater than 1") + if len(dnn_feature_columns) == 0: + raise ValueError("dnn_feature_columns is null!") + if len(task_types) != self.num_tasks: + raise ValueError("num_tasks must be equal to the length of task_types") + + for task_type in task_types: + if task_type not in ['binary', 'regression']: + raise ValueError("task must be binary or regression, {} is illegal".format(task_type)) + + self.num_experts = num_experts + self.task_names = task_names + self.input_dim = self.compute_input_dim(dnn_feature_columns) + self.expert_dnn_hidden_units = expert_dnn_hidden_units + self.gate_dnn_hidden_units = gate_dnn_hidden_units + self.tower_dnn_hidden_units = tower_dnn_hidden_units + + # expert dnn + self.expert_dnn = nn.ModuleList([DNN(self.input_dim, expert_dnn_hidden_units, activation=dnn_activation, + l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, use_bn=dnn_use_bn, + init_std=init_std, device=device) for _ in range(self.num_experts)]) + + # gate dnn + if len(gate_dnn_hidden_units) > 0: + self.gate_dnn = nn.ModuleList([DNN(self.input_dim, gate_dnn_hidden_units, activation=dnn_activation, + l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, use_bn=dnn_use_bn, + init_std=init_std, device=device) for _ in range(self.num_tasks)]) + self.add_regularization_weight( + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.gate_dnn.named_parameters()), + l2=l2_reg_dnn) + self.gate_dnn_final_layer = nn.ModuleList( + [nn.Linear(gate_dnn_hidden_units[-1] if len(gate_dnn_hidden_units) > 0 else self.input_dim, + self.num_experts, bias=False) for _ in range(self.num_tasks)]) + + # tower dnn (task-specific) + if len(tower_dnn_hidden_units) > 0: + self.tower_dnn = nn.ModuleList( + [DNN(expert_dnn_hidden_units[-1], tower_dnn_hidden_units, activation=dnn_activation, + l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, use_bn=dnn_use_bn, + init_std=init_std, device=device) for _ in range(self.num_tasks)]) + self.add_regularization_weight( + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.tower_dnn.named_parameters()), + l2=l2_reg_dnn) + self.tower_dnn_final_layer = nn.ModuleList([nn.Linear( + tower_dnn_hidden_units[-1] if len(tower_dnn_hidden_units) > 0 else expert_dnn_hidden_units[-1], 1, + bias=False) + for _ in range(self.num_tasks)]) + + self.out = nn.ModuleList([PredictionLayer(task) for task in task_types]) + + regularization_modules = [self.expert_dnn, self.gate_dnn_final_layer, self.tower_dnn_final_layer] + for module in regularization_modules: + self.add_regularization_weight( + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], module.named_parameters()), l2=l2_reg_dnn) + self.to(device) + + def forward(self, X): + sparse_embedding_list, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns, + self.embedding_dict) + dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list) + + # expert dnn + expert_outs = [] + for i in range(self.num_experts): + expert_out = self.expert_dnn[i](dnn_input) + expert_outs.append(expert_out) + expert_outs = torch.stack(expert_outs, 1) # (bs, num_experts, dim) + + # gate dnn + mmoe_outs = [] + for i in range(self.num_tasks): + if len(self.gate_dnn_hidden_units) > 0: + gate_dnn_out = self.gate_dnn[i](dnn_input) + gate_dnn_out = self.gate_dnn_final_layer[i](gate_dnn_out) + else: + gate_dnn_out = self.gate_dnn_final_layer[i](dnn_input) + gate_mul_expert = torch.matmul(gate_dnn_out.softmax(1).unsqueeze(1), expert_outs) # (bs, 1, dim) + mmoe_outs.append(gate_mul_expert.squeeze()) + + # tower dnn (task-specific) + task_outs = [] + for i in range(self.num_tasks): + if len(self.tower_dnn_hidden_units) > 0: + tower_dnn_out = self.tower_dnn[i](mmoe_outs[i]) + tower_dnn_logit = self.tower_dnn_final_layer[i](tower_dnn_out) + else: + tower_dnn_logit = self.tower_dnn_final_layer[i](mmoe_outs[i]) + output = self.out[i](tower_dnn_logit) + task_outs.append(output) + task_outs = torch.cat(task_outs, -1) + return task_outs diff --git a/deepctr_torch/models/multitask/ple.py b/deepctr_torch/models/multitask/ple.py new file mode 100644 index 00000000..bc8a06fb --- /dev/null +++ b/deepctr_torch/models/multitask/ple.py @@ -0,0 +1,219 @@ +# -*- coding:utf-8 -*- +""" +Author: + zanshuxun, zanshuxun@aliyun.com + +Reference: + [1] Tang H, Liu J, Zhao M, et al. Progressive layered extraction (ple): A novel multi-task learning (mtl) model for personalized recommendations[C]//Fourteenth ACM Conference on Recommender Systems. 2020.(https://dl.acm.org/doi/10.1145/3383313.3412236) +""" +import torch +import torch.nn as nn + +from ..basemodel import BaseModel +from ...inputs import combined_dnn_input +from ...layers import DNN, PredictionLayer + + +class PLE(BaseModel): + """Instantiates the multi level of Customized Gate Control of Progressive Layered Extraction architecture. + + :param dnn_feature_columns: An iterable containing all the features used by deep part of the model. + :param shared_expert_num: integer, number of task-shared experts. + :param specific_expert_num: integer, number of task-specific experts. + :param num_levels: integer, number of CGC levels. + :param expert_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of expert DNN. + :param gate_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of gate DNN. + :param tower_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of task-specific DNN. + :param l2_reg_linear: float, L2 regularizer strength applied to linear part. + :param l2_reg_embedding: float, L2 regularizer strength applied to embedding vector. + :param l2_reg_dnn: float, L2 regularizer strength applied to DNN. + :param init_std: float, to use as the initialize std of embedding vector. + :param seed: integer, to use as random seed. + :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. + :param dnn_activation: Activation function to use in DNN. + :param dnn_use_bn: bool, Whether use BatchNormalization before activation or not in DNN. + :param task_types: list of str, indicating the loss of each tasks, ``"binary"`` for binary logloss, ``"regression"`` for regression loss. e.g. ['binary', 'regression'] + :param task_names: list of str, indicating the predict target of each tasks. + :param device: str, ``"cpu"`` or ``"cuda:0"``. + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. + + :return: A PyTorch model instance. + """ + + def __init__(self, dnn_feature_columns, shared_expert_num=1, specific_expert_num=1, num_levels=2, + expert_dnn_hidden_units=(256, 128), gate_dnn_hidden_units=(64,), tower_dnn_hidden_units=(64,), + l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024, + dnn_dropout=0, dnn_activation='relu', dnn_use_bn=False, task_types=('binary', 'binary'), + task_names=('ctr', 'ctcvr'), device='cpu', gpus=None): + super(PLE, self).__init__(linear_feature_columns=[], dnn_feature_columns=dnn_feature_columns, + l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding, init_std=init_std, + seed=seed, device=device, gpus=gpus) + self.num_tasks = len(task_names) + if self.num_tasks <= 1: + raise ValueError("num_tasks must be greater than 1!") + if len(dnn_feature_columns) == 0: + raise ValueError("dnn_feature_columns is null!") + if len(task_types) != self.num_tasks: + raise ValueError("num_tasks must be equal to the length of task_types") + + for task_type in task_types: + if task_type not in ['binary', 'regression']: + raise ValueError("task must be binary or regression, {} is illegal".format(task_type)) + + self.specific_expert_num = specific_expert_num + self.shared_expert_num = shared_expert_num + self.num_levels = num_levels + self.task_names = task_names + self.input_dim = self.compute_input_dim(dnn_feature_columns) + self.expert_dnn_hidden_units = expert_dnn_hidden_units + self.gate_dnn_hidden_units = gate_dnn_hidden_units + self.tower_dnn_hidden_units = tower_dnn_hidden_units + + def multi_module_list(num_level, num_tasks, expert_num, inputs_dim_level0, inputs_dim_not_level0, hidden_units): + return nn.ModuleList( + [nn.ModuleList([nn.ModuleList([DNN(inputs_dim_level0 if level_num == 0 else inputs_dim_not_level0, + hidden_units, activation=dnn_activation, + l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, use_bn=dnn_use_bn, + init_std=init_std, device=device) for _ in + range(expert_num)]) + for _ in range(num_tasks)]) for level_num in range(num_level)]) + + # 1. experts + # task-specific experts + self.specific_experts = multi_module_list(self.num_levels, self.num_tasks, self.specific_expert_num, + self.input_dim, expert_dnn_hidden_units[-1], expert_dnn_hidden_units) + + # shared experts + self.shared_experts = multi_module_list(self.num_levels, 1, self.specific_expert_num, + self.input_dim, expert_dnn_hidden_units[-1], expert_dnn_hidden_units) + + # 2. gates + # gates for task-specific experts + specific_gate_output_dim = self.specific_expert_num + self.shared_expert_num + if len(gate_dnn_hidden_units) > 0: + self.specific_gate_dnn = multi_module_list(self.num_levels, self.num_tasks, 1, + self.input_dim, expert_dnn_hidden_units[-1], + gate_dnn_hidden_units) + self.add_regularization_weight( + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.specific_gate_dnn.named_parameters()), + l2=l2_reg_dnn) + self.specific_gate_dnn_final_layer = nn.ModuleList( + [nn.ModuleList([nn.Linear( + gate_dnn_hidden_units[-1] if len(gate_dnn_hidden_units) > 0 else self.input_dim if level_num == 0 else + expert_dnn_hidden_units[-1], specific_gate_output_dim, bias=False) + for _ in range(self.num_tasks)]) for level_num in range(self.num_levels)]) + + # gates for shared experts + shared_gate_output_dim = self.num_tasks * self.specific_expert_num + self.shared_expert_num + if len(gate_dnn_hidden_units) > 0: + self.shared_gate_dnn = nn.ModuleList([DNN(self.input_dim if level_num == 0 else expert_dnn_hidden_units[-1], + gate_dnn_hidden_units, activation=dnn_activation, + l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, use_bn=dnn_use_bn, + init_std=init_std, device=device) for level_num in + range(self.num_levels)]) + self.add_regularization_weight( + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.shared_gate_dnn.named_parameters()), + l2=l2_reg_dnn) + self.shared_gate_dnn_final_layer = nn.ModuleList( + [nn.Linear( + gate_dnn_hidden_units[-1] if len(gate_dnn_hidden_units) > 0 else self.input_dim if level_num == 0 else + expert_dnn_hidden_units[-1], shared_gate_output_dim, bias=False) + for level_num in range(self.num_levels)]) + + # 3. tower dnn (task-specific) + if len(tower_dnn_hidden_units) > 0: + self.tower_dnn = nn.ModuleList( + [DNN(expert_dnn_hidden_units[-1], tower_dnn_hidden_units, activation=dnn_activation, + l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, use_bn=dnn_use_bn, + init_std=init_std, device=device) for _ in range(self.num_tasks)]) + self.add_regularization_weight( + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.tower_dnn.named_parameters()), + l2=l2_reg_dnn) + self.tower_dnn_final_layer = nn.ModuleList([nn.Linear( + tower_dnn_hidden_units[-1] if len(tower_dnn_hidden_units) > 0 else expert_dnn_hidden_units[-1], 1, + bias=False) + for _ in range(self.num_tasks)]) + + self.out = nn.ModuleList([PredictionLayer(task) for task in task_types]) + + regularization_modules = [self.specific_experts, self.shared_experts, self.specific_gate_dnn_final_layer, + self.shared_gate_dnn_final_layer, self.tower_dnn_final_layer] + for module in regularization_modules: + self.add_regularization_weight( + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], module.named_parameters()), l2=l2_reg_dnn) + self.to(device) + + # a single cgc Layer + def cgc_net(self, inputs, level_num): + # inputs: [task1, task2, ... taskn, shared task] + + # 1. experts + # task-specific experts + specific_expert_outputs = [] + for i in range(self.num_tasks): + for j in range(self.specific_expert_num): + specific_expert_output = self.specific_experts[level_num][i][j](inputs[i]) + specific_expert_outputs.append(specific_expert_output) + + # shared experts + shared_expert_outputs = [] + for k in range(self.shared_expert_num): + shared_expert_output = self.shared_experts[level_num][0][k](inputs[-1]) + shared_expert_outputs.append(shared_expert_output) + + # 2. gates + # gates for task-specific experts + cgc_outs = [] + for i in range(self.num_tasks): + # concat task-specific expert and task-shared expert + cur_experts_outputs = specific_expert_outputs[ + i * self.specific_expert_num:(i + 1) * self.specific_expert_num] + shared_expert_outputs + cur_experts_outputs = torch.stack(cur_experts_outputs, 1) + + # gate dnn + if len(self.gate_dnn_hidden_units) > 0: + gate_dnn_out = self.specific_gate_dnn[level_num][i][0](inputs[i]) + gate_dnn_out = self.specific_gate_dnn_final_layer[level_num][i](gate_dnn_out) + else: + gate_dnn_out = self.specific_gate_dnn_final_layer[level_num][i](inputs[i]) + gate_mul_expert = torch.matmul(gate_dnn_out.softmax(1).unsqueeze(1), cur_experts_outputs) # (bs, 1, dim) + cgc_outs.append(gate_mul_expert.squeeze()) + + # gates for shared experts + cur_experts_outputs = specific_expert_outputs + shared_expert_outputs + cur_experts_outputs = torch.stack(cur_experts_outputs, 1) + + if len(self.gate_dnn_hidden_units) > 0: + gate_dnn_out = self.shared_gate_dnn[level_num](inputs[-1]) + gate_dnn_out = self.shared_gate_dnn_final_layer[level_num](gate_dnn_out) + else: + gate_dnn_out = self.shared_gate_dnn_final_layer[level_num](inputs[-1]) + gate_mul_expert = torch.matmul(gate_dnn_out.softmax(1).unsqueeze(1), cur_experts_outputs) # (bs, 1, dim) + cgc_outs.append(gate_mul_expert.squeeze()) + + return cgc_outs + + def forward(self, X): + sparse_embedding_list, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns, + self.embedding_dict) + dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list) + + # repeat `dnn_input` for several times to generate cgc input + ple_inputs = [dnn_input] * (self.num_tasks + 1) # [task1, task2, ... taskn, shared task] + ple_outputs = [] + for i in range(self.num_levels): + ple_outputs = self.cgc_net(inputs=ple_inputs, level_num=i) + ple_inputs = ple_outputs + + # tower dnn (task-specific) + task_outs = [] + for i in range(self.num_tasks): + if len(self.tower_dnn_hidden_units) > 0: + tower_dnn_out = self.tower_dnn[i](ple_outputs[i]) + tower_dnn_logit = self.tower_dnn_final_layer[i](tower_dnn_out) + else: + tower_dnn_logit = self.tower_dnn_final_layer[i](ple_outputs[i]) + output = self.out[i](tower_dnn_logit) + task_outs.append(output) + task_outs = torch.cat(task_outs, -1) + return task_outs diff --git a/deepctr_torch/models/multitask/sharedbottom.py b/deepctr_torch/models/multitask/sharedbottom.py new file mode 100644 index 00000000..9a8f7de4 --- /dev/null +++ b/deepctr_torch/models/multitask/sharedbottom.py @@ -0,0 +1,104 @@ +# -*- coding:utf-8 -*- +""" +Author: + zanshuxun, zanshuxun@aliyun.com + +Reference: + [1] Ruder S. An overview of multi-task learning in deep neural networks[J]. arXiv preprint arXiv:1706.05098, 2017.(https://arxiv.org/pdf/1706.05098.pdf) +""" +import torch +import torch.nn as nn + +from ..basemodel import BaseModel +from ...inputs import combined_dnn_input +from ...layers import DNN, PredictionLayer + + +class SharedBottom(BaseModel): + """Instantiates the SharedBottom multi-task learning Network architecture. + + :param dnn_feature_columns: An iterable containing all the features used by deep part of the model. + :param bottom_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of shared bottom DNN. + :param tower_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of task-specific DNN. + :param l2_reg_linear: float, L2 regularizer strength applied to linear part + :param l2_reg_embedding: float, L2 regularizer strength applied to embedding vector + :param l2_reg_dnn: float, L2 regularizer strength applied to DNN + :param init_std: float, to use as the initialize std of embedding vector + :param seed: integer, to use as random seed. + :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. + :param dnn_activation: Activation function to use in DNN + :param dnn_use_bn: bool, Whether use BatchNormalization before activation or not in DNN + :param task_types: list of str, indicating the loss of each tasks, ``"binary"`` for binary logloss or ``"regression"`` for regression loss. e.g. ['binary', 'regression'] + :param task_names: list of str, indicating the predict target of each tasks + :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. + + :return: A PyTorch model instance. + """ + + def __init__(self, dnn_feature_columns, bottom_dnn_hidden_units=(256, 128), tower_dnn_hidden_units=(64,), + l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024, + dnn_dropout=0, dnn_activation='relu', dnn_use_bn=False, task_types=('binary', 'binary'), + task_names=('ctr', 'ctcvr'), device='cpu', gpus=None): + super(SharedBottom, self).__init__(linear_feature_columns=[], dnn_feature_columns=dnn_feature_columns, + l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding, + init_std=init_std, seed=seed, device=device, gpus=gpus) + self.num_tasks = len(task_names) + if self.num_tasks <= 1: + raise ValueError("num_tasks must be greater than 1") + if len(dnn_feature_columns) == 0: + raise ValueError("dnn_feature_columns is null!") + if len(task_types) != self.num_tasks: + raise ValueError("num_tasks must be equal to the length of task_types") + + for task_type in task_types: + if task_type not in ['binary', 'regression']: + raise ValueError("task must be binary or regression, {} is illegal".format(task_type)) + + self.task_names = task_names + self.input_dim = self.compute_input_dim(dnn_feature_columns) + self.bottom_dnn_hidden_units = bottom_dnn_hidden_units + self.tower_dnn_hidden_units = tower_dnn_hidden_units + + self.bottom_dnn = DNN(self.input_dim, bottom_dnn_hidden_units, activation=dnn_activation, + dropout_rate=dnn_dropout, use_bn=dnn_use_bn, + init_std=init_std, device=device) + if len(self.tower_dnn_hidden_units) > 0: + self.tower_dnn = nn.ModuleList( + [DNN(bottom_dnn_hidden_units[-1], tower_dnn_hidden_units, activation=dnn_activation, + dropout_rate=dnn_dropout, use_bn=dnn_use_bn, + init_std=init_std, device=device) for _ in range(self.num_tasks)]) + self.add_regularization_weight( + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.tower_dnn.named_parameters()), + l2=l2_reg_dnn) + self.tower_dnn_final_layer = nn.ModuleList([nn.Linear( + tower_dnn_hidden_units[-1] if len(self.tower_dnn_hidden_units) > 0 else bottom_dnn_hidden_units[-1], 1, + bias=False) for _ in range(self.num_tasks)]) + + self.out = nn.ModuleList([PredictionLayer(task) for task in task_types]) + + self.add_regularization_weight( + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.bottom_dnn.named_parameters()), l2=l2_reg_dnn) + self.add_regularization_weight( + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.tower_dnn_final_layer.named_parameters()), + l2=l2_reg_dnn) + self.to(device) + + def forward(self, X): + sparse_embedding_list, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns, + self.embedding_dict) + dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list) + shared_bottom_output = self.bottom_dnn(dnn_input) + + # tower dnn (task-specific) + task_outs = [] + for i in range(self.num_tasks): + if len(self.tower_dnn_hidden_units) > 0: + tower_dnn_out = self.tower_dnn[i](shared_bottom_output) + tower_dnn_logit = self.tower_dnn_final_layer[i](tower_dnn_out) + else: + tower_dnn_logit = self.tower_dnn_final_layer[i](shared_bottom_output) + output = self.out[i](tower_dnn_logit) + task_outs.append(output) + task_outs = torch.cat(task_outs, -1) + return task_outs diff --git a/docs/pics/multitaskmodels/ESMM.png b/docs/pics/multitaskmodels/ESMM.png new file mode 100644 index 00000000..49f4819a Binary files /dev/null and b/docs/pics/multitaskmodels/ESMM.png differ diff --git a/docs/pics/multitaskmodels/MMOE.png b/docs/pics/multitaskmodels/MMOE.png new file mode 100644 index 00000000..80566f7a Binary files /dev/null and b/docs/pics/multitaskmodels/MMOE.png differ diff --git a/docs/pics/multitaskmodels/PLE.png b/docs/pics/multitaskmodels/PLE.png new file mode 100644 index 00000000..41cc0c0b Binary files /dev/null and b/docs/pics/multitaskmodels/PLE.png differ diff --git a/docs/pics/multitaskmodels/SharedBottom.png b/docs/pics/multitaskmodels/SharedBottom.png new file mode 100644 index 00000000..38d811a2 Binary files /dev/null and b/docs/pics/multitaskmodels/SharedBottom.png differ diff --git a/docs/source/Features.md b/docs/source/Features.md index f7bc9827..fc521726 100644 --- a/docs/source/Features.md +++ b/docs/source/Features.md @@ -271,6 +271,55 @@ Adaptive Factorization Network (AFN) can learn arbitrary-order cross features ad [Cheng, W., Shen, Y. and Huang, L. 2020. Adaptive Factorization Network: Learning Adaptive-Order Feature Interactions. Proceedings of the AAAI Conference on Artificial Intelligence. 34, 04 (Apr. 2020), 3609-3616.](https://arxiv.org/pdf/1909.03276) +## MultiTask Models + +### SharedBottom + +Hard parameter sharing is the most commonly used approach to MTL in neural networks. It is generally applied by sharing the hidden layers between all tasks, while keeping several task-specific output layers. + +[**SharedBottom Model API**](./deepctr_torch.models.multitask.sharedbottom.html) + +![SharedBottom](../pics/multitaskmodels/SharedBottom.png) + +[Ruder S. An overview of multi-task learning in deep neural networks[J]. arXiv preprint arXiv:1706.05098, 2017.](https://arxiv.org/pdf/1706.05098.pdf) + + +### ESMM(Entire Space Multi-task Model) + +ESMM models CVR in a brand-new perspective by making good use of sequential pattern of user actions, i.e., impression → +click → conversion. The proposed Entire Space Multi-task Model (ESMM) can eliminate the two problems simultaneously by +i) modeling CVR directly over the entire space, ii) employing a feature representation transfer learning strategy. + +[**ESMM Model API**](./deepctr_torch.models.multitask.esmm.html) + +![ESMM](../pics/multitaskmodels/ESMM.png) + +[Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach for estimating post-click conversion rate[C]//The 41st International ACM SIGIR Conference on Research & Development in Information Retrieval. 2018.](https://dl.acm.org/doi/10.1145/3209978.3210104) + +### MMOE(Multi-gate Mixture-of-Experts) + +Multi-gate Mixture-of-Experts (MMoE) explicitly learns to model task relationships from data. We adapt the Mixture-of- +Experts (MoE) structure to multi-task learning by sharing the expert submodels across all tasks, while also having a +gating network trained to optimize each task. + +[**MMOE Model API**](./deepctr_torch.models.multitask.mmoe.html) + +![MMOE](../pics/multitaskmodels/MMOE.png) + +[Ma J, Zhao Z, Yi X, et al. Modeling task relationships in multi-task learning with multi-gate mixture-of-experts[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 2018.](https://dl.acm.org/doi/abs/10.1145/3219819.3220007) + +### PLE(Progressive Layered Extraction) + +PLE separates shared components and task-specific components explicitly and adopts a progressive rout- ing mechanism to +extract and separate deeper semantic knowledge gradually, improving efficiency of joint representation learning and +information routing across tasks in a general setup. + +[**PLE Model API**](./deepctr_torch.models.multitask.ple.html) + +![PLE](../pics/multitaskmodels/PLE.png) + +[Tang H, Liu J, Zhao M, et al. Progressive layered extraction (ple): A novel multi-task learning (mtl) model for personalized recommendations[C]//Fourteenth ACM Conference on Recommender Systems. 2020.](https://dl.acm.org/doi/10.1145/3383313.3412236) + ## Layers The models of deepctr are modular, diff --git a/docs/source/deepctr_torch.models.multitask.esmm.rst b/docs/source/deepctr_torch.models.multitask.esmm.rst new file mode 100644 index 00000000..b8e09bad --- /dev/null +++ b/docs/source/deepctr_torch.models.multitask.esmm.rst @@ -0,0 +1,7 @@ +deepctr\_torch.models.multitask.esmm module +============================= + +.. automodule:: deepctr_torch.models.multitask.esmm + :members: + :no-undoc-members: + :no-show-inheritance: diff --git a/docs/source/deepctr_torch.models.multitask.mmoe.rst b/docs/source/deepctr_torch.models.multitask.mmoe.rst new file mode 100644 index 00000000..385a082a --- /dev/null +++ b/docs/source/deepctr_torch.models.multitask.mmoe.rst @@ -0,0 +1,7 @@ +deepctr\_torch.models.multitask.mmoe module +============================= + +.. automodule:: deepctr_torch.models.multitask.mmoe + :members: + :no-undoc-members: + :no-show-inheritance: diff --git a/docs/source/deepctr_torch.models.multitask.ple.rst b/docs/source/deepctr_torch.models.multitask.ple.rst new file mode 100644 index 00000000..a8a8a843 --- /dev/null +++ b/docs/source/deepctr_torch.models.multitask.ple.rst @@ -0,0 +1,7 @@ +deepctr\_torch.models.multitask.ple module +============================= + +.. automodule:: deepctr_torch.models.multitask.ple + :members: + :no-undoc-members: + :no-show-inheritance: diff --git a/docs/source/deepctr_torch.models.multitask.sharedbottom.rst b/docs/source/deepctr_torch.models.multitask.sharedbottom.rst new file mode 100644 index 00000000..4977c75b --- /dev/null +++ b/docs/source/deepctr_torch.models.multitask.sharedbottom.rst @@ -0,0 +1,7 @@ +deepctr\_torch.models.multitask.sharedbottom module +============================= + +.. automodule:: deepctr_torch.models.multitask.sharedbottom + :members: + :no-undoc-members: + :no-show-inheritance: diff --git a/examples/byterec_sample.txt b/examples/byterec_sample.txt new file mode 100644 index 00000000..d27740ea --- /dev/null +++ b/examples/byterec_sample.txt @@ -0,0 +1,200 @@ +37448 115 567569 44888 42 0 0 0 1699 43981 53085738314 9 +8623 82 1209192 10098 106 0 1 0 -1 11996 53086444998 8 +9629 31 1209193 184752 109 0 1 0 -1 32093 53085591140 5 +52799 175 1209194 109629 101 0 1 0 -1 33106 53085915481 6 +38008 -1 1209195 456237 11 1 0 1 56 18558 53085805030 9 +51750 154 234989 33229 9 0 1 0 -1 28234 53085437678 8 +57406 226 71963 34917 110 0 1 0 30276 36314 53086303187 28 +39056 59 1209207 456240 18 0 0 0 43 7471 53086422631 9 +37584 68 75987 1028 14 0 0 0 -1 19348 53086345061 18 +47304 41 16118 573 7 0 0 0 -1 15230 53085412854 11 +43834 100 353335 122030 115 0 0 0 315 1098 53086461802 8 +26244 -1 1209208 327068 86 1 0 0 -1 1623 53069870239 10 +34398 116 277319 123194 1 0 0 0 3605 4950 53086083610 14 +39727 -1 533992 28924 113 1 1 0 8506 7506 53085835593 6 +35164 150 284645 34379 52 0 0 0 5866 24579 53086330078 9 +53164 185 1209297 22618 68 0 0 0 -1 35159 53086017683 20 +53211 -1 1209298 6210 4 1 0 0 13655 31974 53086329526 20 +18179 304 118269 2589 7 0 1 0 -1 10068 53086247022 2 +8439 98 9635 297 157 0 0 0 -1 12773 53086336481 18 +3218 133 3985 47023 158 0 1 1 -1 24017 53086416149 7 +57427 123 1209299 456252 159 0 0 0 -1 44018 53086468209 21 +985 10 1209300 92339 6 0 0 0 82 147 53085562522 21 +16888 45 1209301 456253 160 0 0 0 312 17884 53086425402 21 +11410 185 1209302 456254 161 3 0 0 -1 6100 53084215934 10 +17404 6 1209303 4562 9 0 1 0 -1 19947 53086205945 10 +7838 -1 1209375 62468 10 1 0 0 1917 7483 53083790953 10 +57439 92 1209376 456266 7 0 0 0 -1 44045 53086449335 24 +30928 30 635107 32806 71 0 1 0 649 34737 53085422653 10 +12871 210 1058207 52058 35 0 0 0 -1 14134 53086198009 10 +29445 83 1209377 14537 172 0 1 0 -1 44046 53086272167 10 +19287 274 1209378 84914 23 0 1 0 -1 35997 53086437639 10 +11262 105 1209379 88368 88 0 1 0 272 44047 53084643988 6 +32398 116 1209380 456267 25 0 0 0 342 231 53085257783 9 +55836 91 9392 14327 39 0 1 0 -1 33478 53086353626 10 +23647 -1 1209381 76231 4 1 0 0 -1 923 53085252027 9 +2102 25 472014 6344 175 0 0 0 -1 5462 53086419257 10 +9208 -1 1209382 12718 -1 1 0 0 -1 19610 53071237795 10 +31364 335 1209383 6099 46 0 1 0 -1 44048 53086179673 9 +31037 317 7585 43365 35 0 0 0 -1 12371 53086439000 20 +9994 16 307411 290635 35 0 0 0 550 11794 53086424944 10 +17365 99 28335 32396 23 0 1 0 -1 44049 53086266890 10 +15067 225 1209384 16575 162 0 0 0 307 9127 53086468766 9 +20723 272 1209385 79008 31 0 1 0 -1 8266 53086015674 7 +32775 115 1209386 143408 28 0 0 0 -1 44050 53085300815 21 +42384 -1 1209387 5668 -1 1 1 0 -1 44051 53055849125 3 +4533 138 60281 2542 136 0 0 0 -1 17066 53086090800 24 +19065 15 31236 1956 68 0 0 0 -1 12938 53085858393 10 +19499 101 199901 73569 133 0 0 0 -1 23877 53084771534 9 +30729 137 1209388 3889 12 0 0 1 -1 14808 53085471904 9 +31837 108 750117 8782 17 0 0 0 12802 15492 53085208553 10 +12125 84 23090 60697 31 0 1 0 1270 2612 53085415597 11 +54922 83 968506 28232 73 0 0 0 -1 44042 53086368206 11 +6635 96 215812 49835 137 0 0 0 1033 567 53085778100 7 +49255 143 1209389 8130 159 0 0 0 -1 21682 53086298887 9 +1218 83 1209390 40809 39 0 1 0 -1 44052 53086303117 9 +1775 28 301403 44233 176 0 1 0 -1 30899 53086010336 10 +36105 315 1209391 456268 2 0 1 0 298 38662 53085493004 9 +647 -1 212199 286 32 2 0 0 -1 5444 53086350863 10 +19106 -1 1209392 75256 64 1 0 0 -1 6116 53086205060 9 +1978 21 287768 50266 8 0 0 0 -1 32762 53086446588 11 +13574 -1 1209393 167773 33 2 1 1 -1 11752 53086427692 9 +57391 252 776590 63565 35 0 1 0 -1 43954 53086261284 10 +38830 41 922462 138593 34 0 0 0 1273 7141 53086335332 9 +30928 30 1209394 456269 9 0 0 0 -1 34737 53086362499 9 +12753 167 1209395 91389 61 0 1 0 -1 11354 53086384603 9 +42526 49 3108 1395 79 0 0 0 -1 44053 53086113483 7 +7661 57 428746 254584 27 0 0 0 1731 44054 53085248033 10 +21624 52 1209396 39066 166 0 1 0 -1 36739 53084342763 9 +57440 14 337334 88583 177 0 0 0 100 44055 53084815045 8 +1365 24 1209397 3381 113 0 0 0 -1 18304 53086356080 20 +80 77 1209398 456270 58 3 0 0 110 911 53086446912 21 +18841 97 1209399 456271 20 0 0 0 43 44056 53086439064 9 +32036 26 1209400 456272 63 3 0 0 -1 9514 53086469429 9 +32568 -1 1209401 32294 6 1 1 0 -1 21635 53084290196 10 +22844 -1 1209402 12802 19 1 1 0 -1 27448 53086452714 10 +38500 160 1209403 208618 75 3 0 0 3618 29488 53085128111 21 +31273 -1 1209404 63604 29 1 0 0 307 9157 53083482331 10 +2019 106 247632 2892 123 0 1 0 -1 37121 53086367808 7 +10864 299 291769 80265 36 0 1 0 -1 6388 53085389657 10 +42720 256 1096011 7390 22 0 0 0 -1 1449 53084879625 5 +21369 13 652088 122701 140 0 1 0 -1 19027 53086365579 14 +24798 5 1209405 409220 178 0 1 0 -1 25775 53086334056 11 +49011 8 1209406 33843 116 0 1 0 -1 20926 53086369811 11 +20705 45 47309 4986 41 0 0 0 -1 17161 53086282125 4 +1312 6 629875 15298 42 0 0 0 -1 3962 53085909148 9 +31592 89 1209407 23327 6 0 0 0 -1 152 53086431091 9 +53006 113 196989 456273 42 0 0 0 25 34626 53086263205 6 +45282 199 1040419 201862 70 0 0 0 -1 7409 53086088571 9 +44935 138 456399 62422 41 0 1 0 -1 5988 53086343663 2 +31634 140 1209408 42712 9 0 1 0 5735 36610 53086116637 9 +5485 83 74911 14829 15 0 0 0 -1 7969 53085559541 9 +2517 -1 1209409 28180 0 1 0 0 -1 17998 53086006038 10 +45018 -1 237328 5756 179 1 0 0 -1 6337 53085468650 3 +39485 134 1209410 456274 44 0 1 0 46817 7480 53084601824 41 +3697 30 294287 1487 -1 0 0 0 -1 8636 53086432193 9 +31549 258 1209424 62106 137 0 0 0 -1 9790 53086460133 9 +31419 -1 1209425 633 44 1 0 0 422 20312 53085475408 13 +8492 -1 525550 247228 39 1 0 0 -1 498 53085506870 9 +57441 -1 1209426 9443 1 1 0 0 -1 44058 53083399810 10 +20852 -1 1209427 36389 125 1 0 0 -1 1097 53085425555 10 +3207 19 1209428 11987 137 0 0 0 59 30580 53086439316 10 +14829 253 188983 56353 33 0 1 0 -1 9791 53086345485 4 +31097 144 16615 11947 133 0 0 0 -1 34702 53086291770 10 +1514 -1 1209429 27816 -1 1 0 0 -1 12960 53081783582 10 +6756 16 182372 146825 54 0 0 0 3146 41922 53086263069 10 +34245 114 8637 8679 15 0 0 0 2121 5325 53086436720 10 +57442 129 840935 133609 85 0 1 0 9680 44059 53085682272 5 +20781 271 626395 16190 144 0 1 0 1348 7821 53085210580 10 +57443 -1 1209430 10433 185 1 1 0 216 44060 53082870030 10 +16703 167 71277 6969 3 0 1 0 2438 5518 53085304024 8 +22709 24 1209431 155935 85 0 0 0 -1 2237 53086425906 10 +32562 156 1209432 23293 186 0 0 0 1866 25093 53086423410 42 +35065 115 1209433 377867 187 0 1 0 -1 31413 53085859250 10 +30691 221 1209434 88230 -1 0 0 0 -1 228 53084110855 10 +34689 100 649949 273510 86 0 0 0 202 41375 53085843068 10 +30906 100 773087 3242 52 0 0 0 36 31605 53086437860 7 +24714 28 1209435 402771 188 0 1 0 930 44061 53085636881 10 +14763 -1 1209436 38473 85 1 0 0 280 7445 53085228070 9 +28664 19 187548 7514 32 0 1 0 20 41854 53086416011 19 +36484 99 309958 55783 189 0 0 0 -1 43600 53085921061 20 +10194 233 402256 70750 85 0 0 0 -1 9603 53084990495 8 +24385 -1 1209469 415992 8 1 0 0 -1 29287 53086409465 10 +32141 180 1209470 456275 112 0 0 0 25 7493 53086420014 10 +46654 73 79505 6443 39 0 0 0 416 12870 53086447179 11 +48872 -1 1209471 143 74 1 0 0 -1 20498 53086190882 20 +3851 54 1209472 54584 197 0 0 0 222 34862 53086449220 10 +16642 30 1209473 21168 3 0 0 0 1692 33912 53084085100 10 +57447 217 868717 213177 88 0 1 0 16814 33754 53086253173 10 +37650 51 288580 3310 177 0 0 0 -1 18942 53086433092 9 +670 -1 571891 16897 7 1 1 0 -1 44074 53086451941 9 +16698 -1 1209474 456276 83 1 0 1 1590 39965 53084815349 9 +8255 89 1209475 51415 22 0 1 0 -1 5807 53085049022 10 +72 51 1196155 13315 29 0 1 0 3565 1087 53086433711 4 +48647 159 137161 59611 114 0 1 0 11588 19767 53084540124 10 +57448 66 1209476 224774 35 0 0 0 317 34873 53086433953 10 +22835 106 686629 67037 64 0 0 0 -1 30409 53085752685 10 +2327 8 53968 5312 35 0 1 0 -1 32615 53085730895 9 +5045 100 340148 4866 108 0 1 0 273 7220 53086335739 10 +18658 45 215302 900 0 0 0 0 -1 19487 53085418824 10 +24993 109 1209477 116919 5 0 1 0 12 13470 53085684437 21 +57449 214 8317 7935 79 0 0 0 -1 44075 53085912160 9 +6055 217 68275 33192 103 0 0 0 -1 38530 53086282873 21 +57450 48 157990 546 135 0 0 0 -1 44076 53085865033 10 +57451 129 176945 27406 28 0 0 0 -1 44077 53085395234 10 +8541 12 1209478 31095 42 0 1 0 -1 12080 53086436156 10 +3406 -1 1209479 46849 68 2 0 0 -1 19356 53086375995 2 +5798 129 493064 41821 94 0 0 0 595 33618 53086212382 19 +21624 52 94301 11708 22 0 1 0 5291 36739 53085832387 10 +2083 173 1209480 40611 122 0 0 0 -1 31475 53085078625 10 +57417 99 143241 13263 118 0 0 0 -1 43998 53086091383 9 +7918 207 12620 31965 42 0 1 0 -1 37035 53085818765 7 +10235 269 1209481 78850 96 0 1 0 -1 6199 53086431251 7 +57452 297 1209482 2738 112 0 0 0 -1 44078 53084402300 10 +12986 266 114237 6756 20 0 0 0 404 3438 53086373539 9 +48307 115 1209483 456277 162 0 0 0 -1 18643 53084914127 4 +11130 12 2633 42363 102 0 0 0 -1 6636 53086421291 9 +12940 268 1137846 69512 28 0 0 0 -1 7066 53086347355 8 +57227 129 541042 24282 198 0 1 0 -1 44079 53084948734 10 +3588 259 161900 52048 73 0 1 0 -1 2449 53086445895 9 +57453 5 1209484 60297 -1 0 0 0 -1 31390 53086372063 10 +32933 20 1209485 56832 137 0 0 0 22 2838 53086456303 9 +2125 24 1209486 41659 52 0 0 0 -1 8226 53086452225 9 +1615 99 1534 12106 69 0 1 0 2261 21629 53085829649 14 +33863 68 1209487 1283 0 0 1 0 -1 28277 53085988775 9 +5933 134 2771 27411 64 0 1 0 -1 850 53086434217 9 +9379 113 12803 13736 182 0 0 0 -1 15031 53086335760 10 +29831 -1 1209488 231377 -1 1 0 0 -1 16942 53078504296 9 +49183 150 23929 36830 41 0 0 0 273 21457 53085651152 21 +49211 7 1209489 33166 199 0 1 0 -1 21543 53086381916 7 +48289 -1 1209490 8925 128 1 0 0 8509 18585 53086178859 9 +907 45 643932 33139 9 0 0 0 -1 32141 53086420908 42 +32183 232 81546 19108 73 0 0 0 206 2648 53086376751 7 +31997 80 1209491 3709 49 0 0 0 -1 11096 53085741857 10 +1396 -1 1209492 3962 -1 1 0 0 -1 44080 53081166934 10 +31956 70 584468 51539 106 0 1 0 33213 1824 53085312711 10 +5732 -1 1209493 456278 200 1 0 0 -1 5100 53086341408 9 +10496 22 1209494 290120 113 0 1 0 -1 26483 53086357296 9 +40411 -1 60281 2542 136 1 0 0 -1 22135 53086090800 24 +2945 -1 1209495 65069 7 1 1 0 -1 4727 53086394412 14 +57454 123 496 28162 201 0 0 0 -1 44081 53086458298 9 +5296 89 1209496 22908 178 0 0 0 -1 2538 53084054634 9 +2434 97 837808 20224 177 0 0 0 404 1219 53086439675 10 +14333 120 15924 2813 112 0 1 0 -1 1997 53086136768 9 +2062 41 87999 313825 39 0 1 0 120 10512 53085654064 21 +50351 70 1209497 456279 202 0 1 0 -1 24714 53086350722 4 +35535 81 1209502 456281 176 0 0 0 622 34148 53086446596 10 +26225 6 801784 168777 203 0 1 0 3454 18863 53084813379 10 +48951 293 39466 38503 83 0 1 0 12778 20749 53085121817 13 +1234 6 1209503 3547 22 0 1 1 468 7242 53086415297 10 +41026 -1 1209504 54978 204 1 0 0 64952 31155 53086354237 10 +17717 73 4474 29502 4 0 1 0 -1 7246 53085654373 39 +35604 57 151486 21718 45 0 1 0 -1 13567 53085652609 10 +7520 6 1209505 178973 103 0 0 0 282 33342 53086446872 10 +26835 185 14168 5813 4 0 1 0 -1 21933 53086423678 6 +53167 40 1209506 216910 2 0 1 0 -1 44039 53086366390 10 +31106 59 954053 83106 39 0 0 0 -1 44096 53085930569 19 +7359 -1 1209531 52228 33 1 0 0 3190 44097 53085811152 9 +57460 91 36211 54195 135 0 1 1 1059 44098 53085060377 9 +26948 360 1209532 80933 4 0 0 0 63 11424 53086463275 21 +31412 -1 1209533 135891 128 1 0 0 -1 44099 53081206772 34 diff --git a/examples/run_classification_criteo.py b/examples/run_classification_criteo.py index 881fdfbb..67fb3d9a 100644 --- a/examples/run_classification_criteo.py +++ b/examples/run_classification_criteo.py @@ -27,7 +27,7 @@ # 2.count #unique features for each sparse field,and record dense feature field name - fixlen_feature_columns = [SparseFeat(feat, data[feat].nunique()) + fixlen_feature_columns = [SparseFeat(feat, vocabulary_size=data[feat].max() + 1, embedding_dim=4) for feat in sparse_features] + [DenseFeat(feat, 1, ) for feat in dense_features] diff --git a/examples/run_multitask_learning.py b/examples/run_multitask_learning.py new file mode 100644 index 00000000..567037a5 --- /dev/null +++ b/examples/run_multitask_learning.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +import pandas as pd +import torch +from sklearn.metrics import log_loss, roc_auc_score +from sklearn.preprocessing import LabelEncoder, MinMaxScaler + +from deepctr_torch.inputs import SparseFeat, DenseFeat, get_feature_names +from deepctr_torch.models import * + +if __name__ == "__main__": + # data description can be found in https://www.biendata.xyz/competition/icmechallenge2019/ + data = pd.read_csv('./byterec_sample.txt', sep='\t', + names=["uid", "user_city", "item_id", "author_id", "item_city", "channel", "finish", "like", + "music_id", "device", "time", "duration_time"]) + + sparse_features = ["uid", "user_city", "item_id", "author_id", "item_city", "channel", "music_id", "device"] + dense_features = ["duration_time"] + + target = ['finish', 'like'] + + # 1.Label Encoding for sparse features,and do simple Transformation for dense features + for feat in sparse_features: + lbe = LabelEncoder() + data[feat] = lbe.fit_transform(data[feat]) + mms = MinMaxScaler(feature_range=(0, 1)) + data[dense_features] = mms.fit_transform(data[dense_features]) + + # 2.count #unique features for each sparse field,and record dense feature field name + + fixlen_feature_columns = [SparseFeat(feat, vocabulary_size=data[feat].max() + 1, embedding_dim=4) + for feat in sparse_features] + [DenseFeat(feat, 1, ) + for feat in dense_features] + + dnn_feature_columns = fixlen_feature_columns + linear_feature_columns = fixlen_feature_columns + + feature_names = get_feature_names( + linear_feature_columns + dnn_feature_columns) + + # 3.generate input data for model + + split_boundary = int(data.shape[0] * 0.8) + train, test = data[:split_boundary], data[split_boundary:] + train_model_input = {name: train[name] for name in feature_names} + test_model_input = {name: test[name] for name in feature_names} + + # 4.Define Model,train,predict and evaluate + device = 'cpu' + use_cuda = True + if use_cuda and torch.cuda.is_available(): + print('cuda ready...') + device = 'cuda:0' + + model = MMOE(dnn_feature_columns, task_types=['binary', 'binary'], + l2_reg_embedding=1e-5, task_names=target, device=device) + model.compile("adagrad", loss=["binary_crossentropy", "binary_crossentropy"], + metrics=['binary_crossentropy'], ) + + history = model.fit(train_model_input, train[target].values, batch_size=32, epochs=10, verbose=2) + pred_ans = model.predict(test_model_input, 256) + print("") + for i, target_name in enumerate(target): + print("%s test LogLoss" % target_name, round(log_loss(test[target[i]].values, pred_ans[:, i]), 4)) + print("%s test AUC" % target_name, round(roc_auc_score(test[target[i]].values, pred_ans[:, i]), 4)) diff --git a/tests/models/AFN_test.py b/tests/models/AFN_test.py index dce5b207..b7f9ef0a 100644 --- a/tests/models/AFN_test.py +++ b/tests/models/AFN_test.py @@ -7,9 +7,9 @@ @pytest.mark.parametrize( 'afn_dnn_hidden_units, sparse_feature_num, dense_feature_num', - [((256, 128), 3, 0), - ((256, 128), 3, 3), - ((256, 128), 0, 3)] + [((32, 16), 3, 0), + ((32, 16), 3, 3), + ((32, 16), 0, 3)] ) def test_AFN(afn_dnn_hidden_units, sparse_feature_num, dense_feature_num): model_name = 'AFN' diff --git a/tests/models/multitask/ESMM_test.py b/tests/models/multitask/ESMM_test.py new file mode 100644 index 00000000..a091f791 --- /dev/null +++ b/tests/models/multitask/ESMM_test.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +import pytest + +from deepctr_torch.models import ESMM +from ...utils_mtl import get_mtl_test_data, SAMPLE_SIZE, check_mtl_model, get_device + + +@pytest.mark.parametrize( + 'num_experts, tower_dnn_hidden_units, task_types, sparse_feature_num, dense_feature_num', + [ + (3, (32, 16), ['binary', 'binary'], 3, 3) + ] +) +def test_ESMM(num_experts, tower_dnn_hidden_units, task_types, + sparse_feature_num, dense_feature_num): + model_name = "ESMM" + sample_size = SAMPLE_SIZE + x, y_list, feature_columns = get_mtl_test_data( + sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=dense_feature_num) + + model = ESMM(feature_columns, tower_dnn_hidden_units=tower_dnn_hidden_units, + task_types=task_types, device=get_device()) + check_mtl_model(model, model_name, x, y_list, task_types) + + +if __name__ == "__main__": + pass diff --git a/tests/models/multitask/MMOE_test.py b/tests/models/multitask/MMOE_test.py new file mode 100644 index 00000000..a37fe29c --- /dev/null +++ b/tests/models/multitask/MMOE_test.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +import pytest + +from deepctr_torch.models import MMOE +from ...utils_mtl import get_mtl_test_data, SAMPLE_SIZE, check_mtl_model, get_device + + +@pytest.mark.parametrize( + 'num_experts, expert_dnn_hidden_units, gate_dnn_hidden_units, tower_dnn_hidden_units, task_types, ' + 'sparse_feature_num, dense_feature_num', + [ + (3, (32, 16), (64,), (64,), ['binary', 'binary'], 3, 3), + (3, (32, 16), (), (64,), ['binary', 'binary'], 3, 3), + (3, (32, 16), (64,), (), ['binary', 'binary'], 3, 3), + (3, (32, 16), (), (), ['binary', 'binary'], 3, 3), + (3, (32, 16), (64,), (64,), ['binary', 'regression'], 3, 3), + ] +) +def test_MMOE(num_experts, expert_dnn_hidden_units, gate_dnn_hidden_units, tower_dnn_hidden_units, task_types, + sparse_feature_num, dense_feature_num): + model_name = "MMOE" + sample_size = SAMPLE_SIZE + x, y_list, feature_columns = get_mtl_test_data( + sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=dense_feature_num) + + model = MMOE(feature_columns, num_experts=num_experts, expert_dnn_hidden_units=expert_dnn_hidden_units, + gate_dnn_hidden_units=gate_dnn_hidden_units, tower_dnn_hidden_units=tower_dnn_hidden_units, + task_types=task_types, device=get_device()) + check_mtl_model(model, model_name, x, y_list, task_types) + + +if __name__ == "__main__": + pass diff --git a/tests/models/multitask/PLE_test.py b/tests/models/multitask/PLE_test.py new file mode 100644 index 00000000..ca8561f1 --- /dev/null +++ b/tests/models/multitask/PLE_test.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +import pytest + +from deepctr_torch.models import PLE +from ...utils_mtl import get_mtl_test_data, SAMPLE_SIZE, check_mtl_model, get_device + + +@pytest.mark.parametrize( + 'shared_expert_num, specific_expert_num, num_levels, expert_dnn_hidden_units, gate_dnn_hidden_units, ' + 'tower_dnn_hidden_units, task_types, sparse_feature_num ,dense_feature_num', + [ + (1, 1, 2, (32, 16), (64,), (64,), ['binary', 'binary'], 3, 3), + (3, 3, 3, (32, 16), (), (64,), ['binary', 'binary'], 3, 3), + (3, 3, 3, (32, 16), (64,), (), ['binary', 'binary'], 3, 3), + (3, 3, 3, (32, 16), (), (), ['binary', 'binary'], 3, 3), + (3, 3, 3, (32, 16), (64,), (64,), ['binary', 'regression'], 3, 3), + ] +) +def test_PLE(shared_expert_num, specific_expert_num, num_levels, expert_dnn_hidden_units, gate_dnn_hidden_units, + tower_dnn_hidden_units, task_types, sparse_feature_num, dense_feature_num): + model_name = "PLE" + sample_size = SAMPLE_SIZE + x, y_list, feature_columns = get_mtl_test_data( + sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=dense_feature_num) + + model = PLE(feature_columns, shared_expert_num=shared_expert_num, specific_expert_num=specific_expert_num, + num_levels=num_levels, expert_dnn_hidden_units=expert_dnn_hidden_units, + gate_dnn_hidden_units=gate_dnn_hidden_units, tower_dnn_hidden_units=tower_dnn_hidden_units, + task_types=task_types, device=get_device()) + check_mtl_model(model, model_name, x, y_list, task_types) + + +if __name__ == "__main__": + pass diff --git a/tests/models/multitask/SharedBottom_test.py b/tests/models/multitask/SharedBottom_test.py new file mode 100644 index 00000000..f3341f6c --- /dev/null +++ b/tests/models/multitask/SharedBottom_test.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +import pytest + +from deepctr_torch.models import SharedBottom +from ...utils_mtl import get_mtl_test_data, SAMPLE_SIZE, check_mtl_model, get_device + + +@pytest.mark.parametrize( + 'num_experts, bottom_dnn_hidden_units, tower_dnn_hidden_units, task_types, sparse_feature_num, dense_feature_num', + [ + (3, (32, 16), (64,), ['binary', 'binary'], 3, 3), + (3, (32, 16), (), ['binary', 'binary'], 3, 3), + (3, (32, 16), (64,), ['binary', 'regression'], 3, 3), + ] +) +def test_SharedBottom(num_experts, bottom_dnn_hidden_units, tower_dnn_hidden_units, task_types, + sparse_feature_num, dense_feature_num): + model_name = "SharedBottom" + sample_size = SAMPLE_SIZE + x, y_list, feature_columns = get_mtl_test_data( + sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=dense_feature_num) + + model = SharedBottom(feature_columns, bottom_dnn_hidden_units=bottom_dnn_hidden_units, + tower_dnn_hidden_units=tower_dnn_hidden_units, + task_types=task_types, device=get_device()) + check_mtl_model(model, model_name, x, y_list, task_types) + + +if __name__ == "__main__": + pass diff --git a/tests/models/multitask/__init__.py b/tests/models/multitask/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/utils_mtl.py b/tests/utils_mtl.py new file mode 100644 index 00000000..61020cf1 --- /dev/null +++ b/tests/utils_mtl.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +import os + +import numpy as np +import torch as torch + +from deepctr_torch.callbacks import EarlyStopping, ModelCheckpoint +from deepctr_torch.inputs import SparseFeat, DenseFeat, VarLenSparseFeat + +SAMPLE_SIZE = 64 + + +def gen_sequence(dim, max_len, sample_size): + return np.array([np.random.randint(0, dim, max_len) for _ in range(sample_size)]), np.random.randint(1, max_len + 1, + sample_size) + + +def get_mtl_test_data(sample_size=1000, embedding_size=4, sparse_feature_num=1, dense_feature_num=1, + sequence_feature=['sum', 'mean', 'max'], include_length=False, task_types=('binary', 'binary'), + hash_flag=False, prefix=''): + feature_columns = [] + model_input = {} + + if 'weight' in sequence_feature: + feature_columns.append( + VarLenSparseFeat(SparseFeat(prefix + "weighted_seq", vocabulary_size=2, embedding_dim=embedding_size), + maxlen=3, length_name=prefix + "weighted_seq" + "_seq_length", + weight_name=prefix + "weight")) + s_input, s_len_input = gen_sequence( + 2, 3, sample_size) + + model_input[prefix + "weighted_seq"] = s_input + model_input[prefix + 'weight'] = np.random.randn(sample_size, 3, 1) + model_input[prefix + "weighted_seq" + "_seq_length"] = s_len_input + sequence_feature.pop(sequence_feature.index('weight')) + + for i in range(sparse_feature_num): + dim = np.random.randint(1, 10) + feature_columns.append(SparseFeat(prefix + 'sparse_feature_' + str(i), dim, embedding_size, dtype=torch.int32)) + for i in range(dense_feature_num): + feature_columns.append(DenseFeat(prefix + 'dense_feature_' + str(i), 1, dtype=torch.float32)) + for i, mode in enumerate(sequence_feature): + dim = np.random.randint(1, 10) + maxlen = np.random.randint(1, 10) + feature_columns.append( + VarLenSparseFeat(SparseFeat(prefix + 'sequence_' + mode, vocabulary_size=dim, embedding_dim=embedding_size), + maxlen=maxlen, combiner=mode)) + + for fc in feature_columns: + if isinstance(fc, SparseFeat): + model_input[fc.name] = np.random.randint(0, fc.vocabulary_size, sample_size) + elif isinstance(fc, DenseFeat): + model_input[fc.name] = np.random.random(sample_size) + else: + s_input, s_len_input = gen_sequence( + fc.vocabulary_size, fc.maxlen, sample_size) + model_input[fc.name] = s_input + if include_length: + fc.length_name = prefix + "sequence_" + str(i) + '_seq_length' + model_input[prefix + "sequence_" + str(i) + '_seq_length'] = s_len_input + + y_list = [] # multi label + for task in task_types: + if task == 'binary': + y = np.random.randint(0, 2, sample_size) + y_list.append(y) + else: + y = np.random.random(sample_size) + y_list.append(y) + y_list = np.array(y_list).transpose() # (sample_size, num_tasks) + + return model_input, y_list, feature_columns + + +def check_mtl_model(model, model_name, x, y_list, task_types, check_model_io=True): + ''' + compile model,train and evaluate it,then save/load weight and model file. + :param model: + :param model_name: + :param x: + :param y_list: mutil label of y + :param task_types: + :param check_model_io: + :return: + ''' + loss_list = [] + for task_type in task_types: + if task_type == 'binary': + loss_list.append('binary_crossentropy') + elif task_type == 'regression': + loss_list.append('mae') + print('loss:', loss_list) + + early_stopping = EarlyStopping(monitor='val_acc', min_delta=0, verbose=1, patience=0, mode='max') + model_checkpoint = ModelCheckpoint(filepath='model.ckpt', monitor='val_acc', verbose=1, + save_best_only=True, + save_weights_only=False, mode='max', period=1) + + model.compile('adam', loss_list, metrics=['binary_crossentropy', 'acc']) + model.fit(x, y_list, batch_size=100, epochs=1, validation_split=0.5, callbacks=[early_stopping, model_checkpoint]) + + print(model_name + 'test, train valid pass!') + torch.save(model.state_dict(), model_name + '_weights.h5') + model.load_state_dict(torch.load(model_name + '_weights.h5')) + os.remove(model_name + '_weights.h5') + print(model_name + 'test save load weight pass!') + if check_model_io: + torch.save(model, model_name + '.h5') + model = torch.load(model_name + '.h5') + os.remove(model_name + '.h5') + print(model_name + 'test save load model pass!') + print(model_name + 'test pass!') + + +def get_device(use_cuda=True): + device = 'cpu' + if use_cuda and torch.cuda.is_available(): + print('cuda ready...') + device = 'cuda:0' + return device