feat: add CCPM&KMaxPooling
浅梦 authored Apr 21, 2019
1 parent 715c25e commit 33faa4b
Showing 35 changed files with 285 additions and 75 deletions.
6 changes: 3 additions & 3 deletions .github/ISSUE_TEMPLATE/question.md
@@ -14,6 +14,6 @@ A clear and concise description of what the question is.
Add any other context about the problem here.

**Operating environment(运行环境):**
- python version [e.g. 3.4, 3.6]
- tensorflow version [e.g. 1.4.0, 1.12.0]
- deepctr version [e.g. 0.2.3,]
- python version [e.g. 3.6]
- tensorflow version [e.g. 1.4.0,]
- deepctr version [e.g. 0.3.2,]
2 changes: 1 addition & 1 deletion .travis.yml
@@ -10,7 +10,7 @@ python:

env:
# - TF_VERSION=1.13.1
- TF_VERSION=1.12.0
# - TF_VERSION=1.12.2
- TF_VERSION=1.4.0
#Not Support- TF_VERSION=1.7.0
#Not Support- TF_VERSION=1.7.1
5 changes: 3 additions & 2 deletions README.md
@@ -23,6 +23,7 @@ Let's [**Get Started!**](https://deepctr-doc.readthedocs.io/en/latest/Quick-Star

| Model | Paper |
| :------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Convolutional Click Prediction Model | [CIKM 2015][A Convolutional Click Prediction Model](http://ir.ia.ac.cn/bitstream/173211/12337/1/A%20Convolutional%20Click%20Prediction%20Model.pdf) |
| Factorization-supported Neural Network | [ECIR 2016][Deep Learning over Multi-field Categorical Data: A Case Study on User Response Prediction](https://arxiv.org/pdf/1601.02376.pdf) |
| Product-based Neural Network | [ICDM 2016][Product-based neural networks for user response prediction](https://arxiv.org/pdf/1611.00144.pdf) |
| Wide & Deep | [DLRS 2016][Wide & Deep Learning for Recommender Systems](https://arxiv.org/pdf/1606.07792.pdf) |
@@ -31,9 +32,9 @@ Let's [**Get Started!**](https://deepctr-doc.readthedocs.io/en/latest/Quick-Star
| Deep & Cross Network | [ADKDD 2017][Deep & Cross Network for Ad Click Predictions](https://arxiv.org/abs/1708.05123) |
| Attentional Factorization Machine | [IJCAI 2017][Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks](http://www.ijcai.org/proceedings/2017/435) |
| Neural Factorization Machine | [SIGIR 2017][Neural Factorization Machines for Sparse Predictive Analytics](https://arxiv.org/pdf/1708.05027.pdf) |
| Deep Interest Network | [KDD 2018][Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf) |
| Deep Interest Evolution Network | [arxiv 2018][Deep Interest Evolution Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1809.03672.pdf) |
| xDeepFM | [KDD 2018][xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems](https://arxiv.org/pdf/1803.05170.pdf) |
| Deep Interest Network | [KDD 2018][Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf) |
| Deep Interest Evolution Network | [AAAI 2019][Deep Interest Evolution Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1809.03672.pdf) |
| AutoInt | [arxiv 2018][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/abs/1810.11921) |
| NFFM | [arxiv 2019][Field-aware Neural Factorization Machine for Click-Through Rate Prediction ](https://arxiv.org/pdf/1902.09096.pdf)(The original NFFM was first used by Yi Yang([email protected]) in TSA competition in 2017.) |

2 changes: 1 addition & 1 deletion deepctr/__init__.py
@@ -2,5 +2,5 @@
from . import models
from .utils import check_version, SingleFeat, VarLenFeat

__version__ = '0.3.2'
__version__ = '0.3.3'
check_version(__version__)
3 changes: 2 additions & 1 deletion deepctr/input_embedding.py
@@ -52,7 +52,7 @@ def create_varlenfeat_inputdict(feature_dim_dict, mask_zero=True):

def create_embedding_dict(feature_dim_dict, embedding_size, init_std, seed, l2_reg, prefix='sparse', seq_mask_zero=True):
if embedding_size == 'auto':

print("Notice:Do not use auto embedding in models other than DCN")
sparse_embedding = {feat.name: Embedding(feat.dimension, 6 * int(pow(feat.dimension, 0.25)),
embeddings_initializer=RandomNormal(
mean=0.0, stddev=init_std, seed=seed),
@@ -99,6 +99,7 @@ def merge_dense_input(dense_input_, embed_list, embedding_size, l2_reg):
dense_input = list(dense_input_.values())
if len(dense_input) > 0:
if embedding_size == "auto":
print("Notice:Do not use auto embedding in models other than DCN")
if len(dense_input) == 1:
continuous_embedding_list = dense_input[0]
else:
22 changes: 15 additions & 7 deletions deepctr/layers/__init__.py
@@ -1,11 +1,18 @@
from .core import LocalActivationUnit,MLP,PredictionLayer
from .interaction import AFMLayer,BiInteractionPooling,CIN,CrossNet,FM,InnerProductLayer,InteractingLayer,OutterProductLayer
from .normalization import LayerNormalization
import tensorflow as tf

from .activation import Dice
from .sequence import SequencePoolingLayer,AttentionSequencePoolingLayer,BiLSTM,Transformer,Position_Embedding,BiasEncoding
from .core import MLP, LocalActivationUnit, PredictionLayer
from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet,
InnerProductLayer, InteractingLayer,
OutterProductLayer)
from .normalization import LayerNormalization
from .sequence import (AttentionSequencePoolingLayer, BiasEncoding, BiLSTM,
KMaxPooling, Position_Embedding, SequencePoolingLayer,
Transformer)
from .utils import NoMask

custom_objects = {'InnerProductLayer': InnerProductLayer,
custom_objects = {'tf': tf,
'InnerProductLayer': InnerProductLayer,
'OutterProductLayer': OutterProductLayer,
'MLP': MLP,
'PredictionLayer': PredictionLayer,
@@ -22,5 +29,6 @@
'LayerNormalization': LayerNormalization,
'BiLSTM': BiLSTM,
'Transformer': Transformer,
'NoMask':NoMask,
'BiasEncoding':BiasEncoding}
'NoMask': NoMask,
'BiasEncoding': BiasEncoding,
'KMaxPooling': KMaxPooling}
60 changes: 59 additions & 1 deletion deepctr/layers/sequence.py
@@ -14,7 +14,6 @@

from ..contrib.rnn import dynamic_rnn
from ..contrib.utils import QAAttGRUCell, VecAttGRUCell

from .core import LocalActivationUnit
from .normalization import LayerNormalization

@@ -690,3 +689,62 @@ def compute_output_shape(self, input_shape):
return rnn_input_shape
else:
return (None, 1, rnn_input_shape[2])


class KMaxPooling(Layer):
    """K-max pooling that selects the k largest values along the specified axis.

      Input shape
        - nD tensor with shape: ``(batch_size, ..., input_dim)``.

      Output shape
        - nD tensor with shape: ``(batch_size, ..., output_dim)``.

      Arguments
        - **k**: positive integer, number of top elements to look for along the ``axis`` dimension.

        - **axis**: positive integer, the dimension to look for elements. Must be in ``[1, ndim - 1]``; negative values, including the default ``-1``, are rejected in ``build``.
    """

    def __init__(self, k=1, axis=-1, **kwargs):

        self.k = k
        self.axis = axis
        super(KMaxPooling, self).__init__(**kwargs)

    def build(self, input_shape):

        # axis 0 is the batch dimension, so valid axes are 1 .. ndim - 1
        if self.axis < 1 or self.axis >= len(input_shape):
            raise ValueError("axis must be in 1 ~ %d, now is %d" %
                             (len(input_shape) - 1, self.axis))

        if self.k < 1 or self.k > input_shape[self.axis]:
            raise ValueError("k must be in 1 ~ %d, now k is %d" %
                             (input_shape[self.axis], self.k))
        self.dims = len(input_shape)
        # Be sure to call this somewhere!
        super(KMaxPooling, self).build(input_shape)

    def call(self, inputs):

        # swap the last and the axis dimensions since top_k will be applied along the last dimension
        perm = list(range(self.dims))
        perm[-1], perm[self.axis] = perm[self.axis], perm[-1]
        shifted_input = tf.transpose(inputs, perm)

        # extract top_k, returns two tensors [values, indices]
        top_k = tf.nn.top_k(shifted_input, k=self.k, sorted=True, name=None)[0]
        # the permutation is a single swap, so applying it again restores the original axis order
        output = tf.transpose(top_k, perm)

        return output

    def compute_output_shape(self, input_shape):
        output_shape = list(input_shape)
        output_shape[self.axis] = self.k
        return tuple(output_shape)

    def get_config(self,):
        config = {'k': self.k, 'axis': self.axis}
        base_config = super(KMaxPooling, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
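
Editor's note: the pooling performed by `KMaxPooling.call` can be illustrated without TensorFlow. The following is a minimal sketch in plain Python, assuming a `(batch, length, channels)` nested list pooled along axis 1. The helper name `k_max_pooling_axis1` is hypothetical and not part of this commit. As with `tf.nn.top_k(..., sorted=True)`, the kept values come out in descending order rather than in their original positions.

```python
def k_max_pooling_axis1(batch, k):
    """Keep the k largest values along axis 1 of a (batch, length, channels) nested list.

    Illustrative stand-in for the layer above; not the library implementation.
    """
    pooled = []
    for sample in batch:
        # Regroup the sample so that each tuple holds all values of one channel.
        channels = list(zip(*sample))
        # Take the k largest values per channel, largest first.
        top = [sorted(channel, reverse=True)[:k] for channel in channels]
        # Transpose back to a (k, channels) layout.
        pooled.append([list(row) for row in zip(*top)])
    return pooled


example = [[[1.0, 9.0], [5.0, 2.0], [3.0, 7.0], [4.0, 8.0]]]  # shape (1, 4, 2)
print(k_max_pooling_axis1(example, 2))  # [[[5.0, 9.0], [4.0, 8.0]]]
```

The output has shape `(1, 2, 2)`, which matches what `compute_output_shape` reports for `k=2, axis=1`.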
14 changes: 7 additions & 7 deletions deepctr/models/__init__.py
@@ -1,17 +1,17 @@
from .afm import AFM
from .autoint import AutoInt
from .ccpm import CCPM
from .dcn import DCN
from .mlr import MLR
from .deepfm import DeepFM
from .nfm import NFM
from .din import DIN
from .dien import DIEN
from .din import DIN
from .fnn import FNN
from .mlr import MLR
from .nffm import NFFM
from .nfm import NFM
from .pnn import PNN
from .wdl import WDL
from .xdeepfm import xDeepFM
from .autoint import AutoInt
from .nffm import NFFM


__all__ = ["AFM", "DCN", "MLR", "DeepFM",
__all__ = ["AFM", "CCPM", "DCN", "MLR", "DeepFM",
           "NFM", "DIN", "DIEN", "FNN", "PNN", "WDL", "xDeepFM", "AutoInt", "NFFM"]
71 changes: 71 additions & 0 deletions deepctr/models/ccpm.py
@@ -0,0 +1,71 @@
# -*- coding:utf-8 -*-
"""
Author:
    Weichen Shen,[email protected]

Reference:
    [1] Liu Q, Yu F, Wu S, et al. A convolutional click prediction model[C]//Proceedings of the 24th ACM International on Conference on Information and Knowledge Management. ACM, 2015: 1743-1746.
    (http://ir.ia.ac.cn/bitstream/173211/12337/1/A%20Convolutional%20Click%20Prediction%20Model.pdf)
"""
import tensorflow as tf

from ..input_embedding import preprocess_input_embedding
from ..layers.core import MLP, PredictionLayer
from ..layers.sequence import KMaxPooling
from ..layers.utils import concat_fun
from ..utils import check_feature_config_dict


def CCPM(feature_dim_dict, embedding_size=8, conv_kernel_width=(6, 5), conv_filters=(4, 4), hidden_size=(256,),
         l2_reg_linear=1e-5, l2_reg_embedding=1e-5, l2_reg_deep=0, keep_prob=1.0, init_std=0.0001, seed=1024,
         final_activation='sigmoid', ):
    """Instantiates the Convolutional Click Prediction Model architecture.

    :param feature_dim_dict: dict, to indicate sparse field and dense field like {'sparse':{'field_1':4,'field_2':3,'field_3':2},'dense':['field_4','field_5']}
    :param embedding_size: positive integer, sparse feature embedding_size
    :param conv_kernel_width: list of positive integers or empty list, the width of the filter in each conv layer.
    :param conv_filters: list of positive integers or empty list, the number of filters in each conv layer.
    :param hidden_size: list of positive integers or empty list, the layer number and units in each layer of deep net.
    :param l2_reg_linear: float. L2 regularizer strength applied to linear part
    :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
    :param l2_reg_deep: float. L2 regularizer strength applied to deep net
    :param keep_prob: float in (0,1]. keep probability passed to the ``MLP`` deep net
    :param init_std: float, to use as the initialize std of embedding vector
    :param seed: integer, to use as random seed.
    :param final_activation: str, output activation, usually ``'sigmoid'`` or ``'linear'``
    :return: A Keras model instance.
    """

    check_feature_config_dict(feature_dim_dict)

    deep_emb_list, linear_logit, inputs_list = preprocess_input_embedding(feature_dim_dict, embedding_size,
                                                                          l2_reg_embedding, l2_reg_linear, init_std,
                                                                          seed, True)
    n = len(deep_emb_list)  # number of embedded fields
    l = len(conv_filters)  # number of conv layers

    conv_input = concat_fun(deep_emb_list, axis=1)
    # add a trailing channel dimension so the embeddings can be fed to Conv2D
    pooling_result = tf.keras.layers.Lambda(
        lambda x: tf.expand_dims(x, axis=3))(conv_input)

    for i in range(1, l + 1):
        filters = conv_filters[i - 1]
        width = conv_kernel_width[i - 1]
        # pooling size for layer i; the last layer always keeps 3 values (i / l is meant as true division)
        k = max(1, int((1 - pow(i / l, l - i)) * n)) if i < l else 3

        conv_result = tf.keras.layers.Conv2D(filters=filters, kernel_size=(width, 1), strides=(1, 1), padding='same',
                                             activation='tanh', use_bias=True, )(pooling_result)
        # never ask for more values than the current length along axis 1
        pooling_result = KMaxPooling(
            k=min(k, conv_result.shape[1].value), axis=1)(conv_result)

    flatten_result = tf.keras.layers.Flatten()(pooling_result)
    final_logit = MLP(hidden_size, l2_reg=l2_reg_deep,
                      keep_prob=keep_prob)(flatten_result)
    final_logit = tf.keras.layers.Dense(1, use_bias=False)(final_logit)

    final_logit = tf.keras.layers.add([final_logit, linear_logit])
    output = PredictionLayer(final_activation)(final_logit)
    model = tf.keras.models.Model(inputs=inputs_list, outputs=output)
    return model
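
Editor's note: the per-layer pooling size computed inside the loop of `CCPM` is easier to follow with concrete numbers. The sketch below reproduces only that schedule in plain Python, including the `min(k, current length)` clamp. It assumes the length along axis 1 is unchanged by each convolution (which follows from `padding='same'` with stride 1). The function name `ccpm_pooling_sizes` is hypothetical and not part of this commit.

```python
def ccpm_pooling_sizes(n, l):
    """Return the k used by each of the l pooling layers for n input fields.

    Illustrative only; mirrors the formula in the CCPM loop under true division.
    """
    sizes = []
    length = n  # current length along the pooled axis
    for i in range(1, l + 1):
        if i < l:
            # float() forces true division so the result does not depend on the Python version
            k = max(1, int((1 - pow(float(i) / l, l - i)) * n))
        else:
            k = 3  # the last layer always asks for 3 values
        k = min(k, length)  # cannot keep more values than are available
        sizes.append(k)
        length = k
    return sizes


print(ccpm_pooling_sizes(10, 2))  # [5, 3]
print(ccpm_pooling_sizes(10, 3))  # [8, 3, 3]
```

Note that the expression `i / l` behaves differently under Python 2 integer division: it evaluates to 0 for every `i < l`, which would make the formula return `n` for all but the last layer. The sketch uses `float(i)` to avoid that ambiguity.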
8 changes: 4 additions & 4 deletions deepctr/models/dcn.py
@@ -11,12 +11,12 @@
from ..input_embedding import preprocess_input_embedding
from ..layers.core import PredictionLayer, MLP
from ..layers.interaction import CrossNet
from ..utils import check_feature_config_dict
from ..layers.utils import concat_fun
from ..utils import check_feature_config_dict


def DCN(feature_dim_dict, embedding_size='auto',
cross_num=2, hidden_size=(128, 128, ), l2_reg_embedding=1e-5, l2_reg_cross=1e-5, l2_reg_deep=0,
cross_num=2, hidden_size=(128, 128,), l2_reg_embedding=1e-5, l2_reg_cross=1e-5, l2_reg_deep=0,
init_std=0.0001, seed=1024, keep_prob=1, use_bn=False, activation='relu', final_activation='sigmoid',
):
"""Instantiates the Deep&Cross Network architecture.
@@ -43,8 +43,8 @@ def DCN(feature_dim_dict, embedding_size='auto',
check_feature_config_dict(feature_dim_dict)

deep_emb_list, _, inputs_list = preprocess_input_embedding(feature_dim_dict, embedding_size,
l2_reg_embedding, 0, init_std,
seed, False)
l2_reg_embedding, 0, init_std,
seed, False)

deep_input = tf.keras.layers.Flatten()(concat_fun(deep_emb_list))

1 change: 1 addition & 0 deletions deepctr/models/nffm.py
@@ -5,6 +5,7 @@
Reference:
[1] Zhang L, Shen W, Li S, et al. Field-aware Neural Factorization Machine for Click-Through Rate Prediction[J]. arXiv preprint arXiv:1902.09096, 2019.(https://arxiv.org/abs/1902.09096)
(The original NFFM was first used by Yi Yang([email protected]) in TSA competition in 2017.)
"""

import itertools
Binary file added docs/pics/CCPM.png
15 changes: 15 additions & 0 deletions docs/source/Features.rst
@@ -27,6 +27,21 @@ Feature Extractor
Models
--------

CCPM (Convolutional Click Prediction Model)
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

CCPM extracts local-global key features from an input instance with a varying number of elements. It can be applied to a single ad impression as well as to a sequence of ad impressions.

**CCPM api** `link <./deepctr.models.ccpm.html>`_

.. image:: ../pics/CCPM.png
:align: center
:scale: 50 %


`Liu Q, Yu F, Wu S, et al. A convolutional click prediction model[C]//Proceedings of the 24th ACM International on Conference on Information and Knowledge Management. ACM, 2015: 1743-1746. <http://ir.ia.ac.cn/bitstream/173211/12337/1/A%20Convolutional%20Click%20Prediction%20Model.pdf>`_


FNN (Factorization-supported Neural Network)
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

1 change: 1 addition & 0 deletions docs/source/History.md
@@ -1,4 +1,5 @@
# History
- 04/21/2019 : [v0.3.3](https://github.com/shenweichen/DeepCTR/releases/tag/v0.3.3) released. Add [CCPM](./Features.html#ccpm-convolutional-click-prediction-model).
- 03/30/2019 : [v0.3.2](https://github.com/shenweichen/DeepCTR/releases/tag/v0.3.2) released. Add [DIEN](./Features.html#dien-deep-interest-evolution-network) and [NFFM](./Features.html#nffm-field-aware-neural-factorization-machine) models.
- 02/17/2019 : [v0.3.1](https://github.com/shenweichen/DeepCTR/releases/tag/v0.3.1) released. Refactor layers, add `BiLSTM` and `Transformer`.
- 01/24/2019 : [v0.2.3](https://github.com/shenweichen/DeepCTR/releases/tag/v0.2.3) released. Use a new feature config generation method and fix bugs.