Add multi-task models (Dev zsx mtl1) #255
Merged
Commits (changes from 38 of 47 commits)
0671b59
mtl
zanshuxun 5267cf8
inplace operation
zanshuxun 25cf516
1
zanshuxun 567576a
1
zanshuxun 30d0fc7
1
zanshuxun 71cf99a
11
zanshuxun 29f177f
111
zanshuxun cbd9eea
if self.num_tasks <= 1:
zanshuxun c00ff9f
add_regularization_weight
zanshuxun 0c70377
1
zanshuxun 2340721
add_regularization_weight
zanshuxun 1663dbc
1
zanshuxun 64ed8fa
byterec_sample.txt
zanshuxun f370701
format
zanshuxun 9f3a51c
format
zanshuxun a1de4a9
format
zanshuxun 9eee512
improve hyperparameters and comments
zanshuxun 4d38beb
cgc pole
zanshuxun 113d90b
ple
zanshuxun ae8b626
ple
zanshuxun 269a305
mtl
zanshuxun 0d0ec88
dim
zanshuxun ee23764
1
zanshuxun 14ec373
1
zanshuxun 310c043
test
zanshuxun 0f56a63
dien lengths .cpu()
zanshuxun 5660a4b
docs
zanshuxun 6e4611d
docs
zanshuxun 4c73e7e
eg
zanshuxun 80552b6
comments
zanshuxun d9d45cb
comments
zanshuxun 924dcf7
format
zanshuxun 1efafb3
format
zanshuxun 18f04ed
byterec_sample.txt 200
zanshuxun 40b7d4c
byterec_sample.txt 200
zanshuxun 5e59953
multi_module_list
zanshuxun 7dff848
format
zanshuxun 1c71508
1
f11a00f
dcnmix
3e20cb8
reduce test dimensions
2232321
add_regularization_weight
4b054bd
data url; loss_func
a35ee3f
final layer
13de334
format
cdbfb16
format
05f134c
add_regularization_weight
90cae06
add_regularization_weight dcnmix
@@ -0,0 +1,4 @@
from .sharedbottom import SharedBottom
from .esmm import ESMM
from .mmoe import MMOE
from .ple import PLE
@@ -0,0 +1,94 @@
# -*- coding:utf-8 -*-
"""
Author:
    zanshuxun, [email protected]

Reference:
    [1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach for estimating post-click conversion rate[C]//The 41st International ACM SIGIR Conference on Research & Development in Information Retrieval. 2018. (https://dl.acm.org/doi/10.1145/3209978.3210104)
"""
import torch
import torch.nn as nn

from ..basemodel import BaseModel
from ...inputs import combined_dnn_input
from ...layers import DNN


class ESMM(BaseModel):
    """Instantiates the Entire Space Multi-Task Model architecture.

    :param dnn_feature_columns: An iterable containing all the features used by the deep part of the model.
    :param tower_dnn_hidden_units: list of positive integers or empty list, the layer number and units in each layer of the task-specific DNN.
    :param l2_reg_linear: float, L2 regularizer strength applied to the linear part.
    :param l2_reg_embedding: float, L2 regularizer strength applied to the embedding vectors.
    :param l2_reg_dnn: float, L2 regularizer strength applied to the DNN.
    :param init_std: float, standard deviation used to initialize the embedding vectors.
    :param seed: integer, to use as random seed.
    :param dnn_dropout: float in [0,1), the probability of dropping a given DNN coordinate.
    :param dnn_activation: activation function to use in the DNN.
    :param dnn_use_bn: bool, whether to use BatchNormalization before activation in the DNN.
    :param task_types: list of str, indicating the loss of each task: ``"binary"`` for binary logloss or ``"regression"`` for regression loss, e.g. ['binary', 'regression'].
    :param task_names: list of str, indicating the prediction target of each task.
    :param device: str, ``"cpu"`` or ``"cuda:0"``.
    :param gpus: list of int or torch.device for multiple GPUs. If None, run on `device`. `gpus[0]` should be the same GPU as `device`.

    :return: A PyTorch model instance.
    """

    def __init__(self, dnn_feature_columns, tower_dnn_hidden_units=(256, 128),
                 l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024,
                 dnn_dropout=0, dnn_activation='relu', dnn_use_bn=False, task_types=('binary', 'binary'),
                 task_names=('ctr', 'ctcvr'), device='cpu', gpus=None):
        super(ESMM, self).__init__(linear_feature_columns=[], dnn_feature_columns=dnn_feature_columns,
                                   l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding, init_std=init_std,
                                   seed=seed, task='binary', device=device, gpus=gpus)
        self.num_tasks = len(task_names)
        if self.num_tasks != 2:
            raise ValueError("the length of task_names must be equal to 2")
        if len(dnn_feature_columns) == 0:
            raise ValueError("dnn_feature_columns is null!")
        if len(task_types) != self.num_tasks:
            raise ValueError("num_tasks must be equal to the length of task_types")

        for task_type in task_types:
            if task_type != 'binary':
                raise ValueError("task must be binary in ESMM, {} is illegal".format(task_type))

        input_dim = self.compute_input_dim(dnn_feature_columns)

        self.ctr_dnn = DNN(input_dim, tower_dnn_hidden_units, activation=dnn_activation,
                           dropout_rate=dnn_dropout, use_bn=dnn_use_bn,
                           init_std=init_std, device=device)
        self.cvr_dnn = DNN(input_dim, tower_dnn_hidden_units, activation=dnn_activation,
                           dropout_rate=dnn_dropout, use_bn=dnn_use_bn,
                           init_std=init_std, device=device)

        self.ctr_dnn_final_layer = nn.Linear(tower_dnn_hidden_units[-1], 1, bias=False)
        self.cvr_dnn_final_layer = nn.Linear(tower_dnn_hidden_units[-1], 1, bias=False)

        self.add_regularization_weight(
            filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.ctr_dnn.named_parameters()), l2=l2_reg_dnn)
        self.add_regularization_weight(
            filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.cvr_dnn.named_parameters()), l2=l2_reg_dnn)
        self.add_regularization_weight(self.ctr_dnn_final_layer.weight, l2=l2_reg_dnn)
        self.add_regularization_weight(self.cvr_dnn_final_layer.weight, l2=l2_reg_dnn)
        self.to(device)

    def forward(self, X):
        sparse_embedding_list, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns,
                                                                                  self.embedding_dict)
        dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list)

        ctr_output = self.ctr_dnn(dnn_input)
        cvr_output = self.cvr_dnn(dnn_input)

        ctr_logit = self.ctr_dnn_final_layer(ctr_output)
        cvr_logit = self.cvr_dnn_final_layer(cvr_output)

        ctr_pred = self.out(ctr_logit)
        cvr_pred = self.out(cvr_logit)

        ctcvr_pred = ctr_pred * cvr_pred  # CTCVR = CTR * CVR

        task_outs = torch.cat([ctr_pred, ctcvr_pred], -1)
        return task_outs
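The forward pass above returns CTR and CTCVR predictions rather than CVR directly, relying on the factorization p(click & conversion | x) = p(click | x) · p(conversion | click, x). A minimal plain-Python sketch of that combination step (no deepctr-torch dependency; the tower outputs below are hypothetical sigmoid values, not taken from the PR):

```python
def ctcvr_from_towers(ctr_pred, cvr_pred):
    """Combine per-sample tower outputs the way ESMM's forward pass does:
    CTCVR = CTR * CVR, elementwise over the batch."""
    return [p_ctr * p_cvr for p_ctr, p_cvr in zip(ctr_pred, cvr_pred)]

# Hypothetical post-sigmoid tower outputs for a batch of two impressions.
ctr = [0.10, 0.50]
cvr = [0.20, 0.40]
print(ctcvr_from_towers(ctr, cvr))
```

Because both CTR and CTCVR labels are observable over the entire impression space, supervising these two outputs trains the CVR tower without restricting it to the biased clicked-only sample.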
Review comment: Identifying parameters by the '_list' substring in their names is fragile and is likely to cause problems in future maintenance and iteration.
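The code this comment refers to is not shown in the diff above. As a purely hypothetical illustration of the concern, compare detecting module collections by a name suffix against registering them explicitly (all class and attribute names below are invented for the example):

```python
class SuffixBased:
    """Finds module collections by scanning attribute names for '_list'."""

    def __init__(self):
        self.tower_list = ["dnn_a", "dnn_b"]  # picked up: name ends with '_list'
        self.towers = ["dnn_c"]               # silently skipped: suffix missing

    def collect(self):
        # Any rename that drops the suffix silently changes behavior.
        return [v for k, v in vars(self).items() if k.endswith("_list")]


class ExplicitRegistry:
    """Registers collections explicitly, so renames cannot break discovery."""

    def __init__(self):
        self._multi_modules = []

    def register(self, module):
        self._multi_modules.append(module)
        return module

    def collect(self):
        return list(self._multi_modules)
```

With the suffix-based approach, `towers` is quietly excluded; the explicit registry makes membership a deliberate act rather than a naming convention.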