forked from shenweichen/DeepCTR
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
浅梦
authored
Apr 21, 2019
1 parent
715c25e
commit 33faa4b
Showing
35 changed files
with
285 additions
and
75 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ Let's [**Get Started!**](https://deepctr-doc.readthedocs.io/en/latest/Quick-Star | |
|
||
| Model | Paper | | ||
| :------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ||
| Convolutional Click Prediction Model | [CIKM 2015][A Convolutional Click Prediction Model](http://ir.ia.ac.cn/bitstream/173211/12337/1/A%20Convolutional%20Click%20Prediction%20Model.pdf) | | ||
| Factorization-supported Neural Network | [ECIR 2016][Deep Learning over Multi-field Categorical Data: A Case Study on User Response Prediction](https://arxiv.org/pdf/1601.02376.pdf) | | ||
| Product-based Neural Network | [ICDM 2016][Product-based neural networks for user response prediction](https://arxiv.org/pdf/1611.00144.pdf) | | ||
| Wide & Deep | [DLRS 2016][Wide & Deep Learning for Recommender Systems](https://arxiv.org/pdf/1606.07792.pdf) | | ||
|
@@ -31,9 +32,9 @@ Let's [**Get Started!**](https://deepctr-doc.readthedocs.io/en/latest/Quick-Star | |
| Deep & Cross Network | [ADKDD 2017][Deep & Cross Network for Ad Click Predictions](https://arxiv.org/abs/1708.05123) | | ||
| Attentional Factorization Machine | [IJCAI 2017][Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks](http://www.ijcai.org/proceedings/2017/435) | | ||
| Neural Factorization Machine | [SIGIR 2017][Neural Factorization Machines for Sparse Predictive Analytics](https://arxiv.org/pdf/1708.05027.pdf) | | ||
| Deep Interest Network | [KDD 2018][Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf) | | ||
| Deep Interest Evolution Network | [arxiv 2018][Deep Interest Evolution Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1809.03672.pdf) | | ||
| xDeepFM | [KDD 2018][xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems](https://arxiv.org/pdf/1803.05170.pdf) | | ||
| Deep Interest Network | [KDD 2018][Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf) | | ||
| Deep Interest Evolution Network | [AAAI 2019][Deep Interest Evolution Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1809.03672.pdf) | | ||
| AutoInt | [arxiv 2018][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/abs/1810.11921) | | ||
| NFFM | [arxiv 2019][Field-aware Neural Factorization Machine for Click-Through Rate Prediction ](https://arxiv.org/pdf/1902.09096.pdf)(The original NFFM was first used by Yi Yang([email protected]) in TSA competition in 2017.) | | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,17 @@ | ||
from .afm import AFM | ||
from .autoint import AutoInt | ||
from .ccpm import CCPM | ||
from .dcn import DCN | ||
from .mlr import MLR | ||
from .deepfm import DeepFM | ||
from .nfm import NFM | ||
from .din import DIN | ||
from .dien import DIEN | ||
from .din import DIN | ||
from .fnn import FNN | ||
from .mlr import MLR | ||
from .nffm import NFFM | ||
from .nfm import NFM | ||
from .pnn import PNN | ||
from .wdl import WDL | ||
from .xdeepfm import xDeepFM | ||
from .autoint import AutoInt | ||
from .nffm import NFFM | ||
|
||
|
||
__all__ = ["AFM", "DCN", "MLR", "DeepFM", | ||
__all__ = ["AFM", "CCPM","DCN", "MLR", "DeepFM", | ||
"MLR", "NFM", "DIN", "DIEN", "FNN", "PNN", "WDL", "xDeepFM", "AutoInt", "NFFM"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# -*- coding:utf-8 -*- | ||
""" | ||
Author: | ||
Weichen Shen,[email protected] | ||
Reference: | ||
[1] Liu Q, Yu F, Wu S, et al. A convolutional click prediction model[C]//Proceedings of the 24th ACM International on Conference on Information and Knowledge Management. ACM, 2015: 1743-1746. | ||
(http://ir.ia.ac.cn/bitstream/173211/12337/1/A%20Convolutional%20Click%20Prediction%20Model.pdf) | ||
""" | ||
import tensorflow as tf | ||
|
||
from ..input_embedding import preprocess_input_embedding | ||
from ..layers.core import MLP, PredictionLayer | ||
from ..layers.sequence import KMaxPooling | ||
from ..layers.utils import concat_fun | ||
from ..utils import check_feature_config_dict | ||
|
||
|
||
def CCPM(feature_dim_dict, embedding_size=8, conv_kernel_width=(6, 5), conv_filters=(4, 4), hidden_size=(256,), | ||
l2_reg_linear=1e-5, l2_reg_embedding=1e-5, l2_reg_deep=0, keep_prob=1.0, init_std=0.0001, seed=1024, | ||
final_activation='sigmoid', ): | ||
"""Instantiates the Convolutional Click Prediction Model architecture. | ||
:param feature_dim_dict: dict,to indicate sparse field and dense field like {'sparse':{'field_1':4,'field_2':3,'field_3':2},'dense':['field_4','field_5']} | ||
:param embedding_size: positive integer,sparse feature embedding_size | ||
:param conv_kernel_width: list,list of positive integer or empty list,the width of filter in each conv layer. | ||
:param conv_filters: list,list of positive integer or empty list,the number of filters in each conv layer. | ||
:param hidden_size: list,list of positive integer or empty list, the layer number and units in each layer of deep net. | ||
:param l2_reg_linear: float. L2 regularizer strength applied to linear part | ||
:param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector | ||
:param l2_reg_deep: float. L2 regularizer strength applied to deep net | ||
:param keep_prob: float in (0,1]. keep_prob after attention net | ||
:param init_std: float,to use as the initialize std of embedding vector | ||
:param seed: integer ,to use as random seed. | ||
:param final_activation: str,output activation,usually ``'sigmoid'`` or ``'linear'`` | ||
:return: A Keras model instance. | ||
""" | ||
|
||
check_feature_config_dict(feature_dim_dict) | ||
|
||
deep_emb_list, linear_logit, inputs_list = preprocess_input_embedding(feature_dim_dict, embedding_size, | ||
l2_reg_embedding, l2_reg_linear, init_std, | ||
seed, True) | ||
n = len(deep_emb_list) | ||
l = len(conv_filters) | ||
|
||
conv_input = concat_fun(deep_emb_list, axis=1) | ||
pooling_result = tf.keras.layers.Lambda( | ||
lambda x: tf.expand_dims(x, axis=3))(conv_input) | ||
|
||
for i in range(1, l + 1): | ||
filters = conv_filters[i - 1] | ||
width = conv_kernel_width[i - 1] | ||
k = max(1, int((1 - pow(i / l, l - i)) * n)) if i < l else 3 | ||
|
||
conv_result = tf.keras.layers.Conv2D(filters=filters, kernel_size=(width, 1), strides=(1, 1), padding='same', | ||
activation='tanh', use_bias=True, )(pooling_result) | ||
pooling_result = KMaxPooling( | ||
k=min(k, conv_result.shape[1].value), axis=1)(conv_result) | ||
|
||
flatten_result = tf.keras.layers.Flatten()(pooling_result) | ||
final_logit = MLP(hidden_size, l2_reg=l2_reg_deep, | ||
keep_prob=keep_prob)(flatten_result) | ||
final_logit = tf.keras.layers.Dense(1, use_bias=False)(final_logit) | ||
|
||
final_logit = tf.keras.layers.add([final_logit, linear_logit]) | ||
output = PredictionLayer(final_activation)(final_logit) | ||
model = tf.keras.models.Model(inputs=inputs_list, outputs=output) | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
Reference: | ||
[1] Zhang L, Shen W, Li S, et al. Field-aware Neural Factorization Machine for Click-Through Rate Prediction[J]. arXiv preprint arXiv:1902.09096, 2019.(https://arxiv.org/abs/1902.09096) | ||
(The original NFFM was first used by Yi Yang([email protected]) in TSA competition in 2017.) | ||
""" | ||
|
||
import itertools | ||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.