-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
101 lines (79 loc) · 2.99 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# -*- coding: utf-8 -*-
"""
@author: kebo
@contact: [email protected]
@version: 1.0
@file: model.py
@time: 2020/6/3 11:45 PM
Module description and notes start from this line.
"""
from logging import getLogger
import tensorflow as tf
from utils import dot
logger = getLogger(__name__)
class GraphConvolution(tf.keras.layers.Layer):
    """A single graph-convolution layer.

    Computes ``activation(sum_i support_i @ (x @ W_i) [+ bias])`` where
    ``call`` receives ``inputs = (x, support_)`` and ``support_`` is a list of
    (sparse) support/adjacency matrices.

    Args:
        input_dim: Size of each input node-feature vector.
        output_dim: Size of each output node-feature vector.
        num_features_nonzero: Noise shape used by sparse dropout
            (number of non-zero entries in the sparse feature matrix —
            presumably; TODO(review) confirm against the caller).
        dropout: Drop probability (fraction of units dropped while training).
        is_sparse_inputs: Whether ``x`` is a ``tf.SparseTensor``.
        activation: Output non-linearity.
        use_bias: Whether to add a learned bias after the convolution.
        featureless: If True, ignore ``x`` and use the weight matrix directly.
    """

    def __init__(self, input_dim, output_dim, num_features_nonzero,
                 dropout=0,
                 is_sparse_inputs=False,
                 activation=tf.nn.relu,
                 use_bias=False,
                 featureless=False, **kwargs):
        # BUG FIX: the original `super(GraphConvolution).__init__(**kwargs)` is
        # missing the second super() argument, so Layer.__init__ never ran and
        # the layer was left uninitialized.
        super(GraphConvolution, self).__init__(**kwargs)
        self.dropout = dropout
        self.activation = activation
        self.is_sparse_inputs = is_sparse_inputs
        self.featureless = featureless
        self.use_bias = use_bias
        self.num_features_nonzero = num_features_nonzero

        # Trailing underscore avoids shadowing the read-only
        # `tf.keras.layers.Layer.weights` property.
        self.weights_ = []
        for i in range(1):  # one weight matrix per support; a single support is assumed here
            # `add_weight` replaces the deprecated `add_variable` alias.
            w = self.add_weight('weight' + str(i), [input_dim, output_dim])
            self.weights_.append(w)
        if self.use_bias:
            self.bias = self.add_weight('bias', [output_dim])
        # for p in self.trainable_variables:
        #     print(p.name, p.shape)

    def call(self, inputs, training=False, **kwargs):
        """Apply the graph convolution to ``inputs = (x, support_)``."""
        x, support_ = inputs

        # Dropout only while training. BUG FIXES vs. the original:
        #  - the dense branch applied dropout even at inference time;
        #  - both branches passed `1 - self.dropout` (the keep probability)
        #    where a drop *rate* is expected (tf.nn.dropout's second argument
        #    is `rate`; with dropout=0 the original passed the invalid rate 1).
        if training:
            if self.is_sparse_inputs:
                x = self.sparse_dropout(x, self.dropout, self.num_features_nonzero)
            else:
                x = tf.nn.dropout(x, self.dropout)

        # convolve: pre_sup = x @ W_i (or W_i alone when featureless),
        # then support_i @ pre_sup, summed over all supports.
        supports = list()
        for i in range(len(support_)):
            if not self.featureless:  # if it has features x
                pre_sup = dot(x, self.weights_[i], spares=self.is_sparse_inputs)
            else:
                pre_sup = self.weights_[i]
            support = dot(support_[i], pre_sup, spares=True)
            supports.append(support)
        output = tf.add_n(supports)

        # BUG FIX: the original tested `if self.bias:`, which raises
        # AttributeError whenever use_bias=False (self.bias is never created).
        if self.use_bias:
            output += self.bias
        return self.activation(output)

    @classmethod
    def sparse_dropout(cls, x, rate, noise_shape):
        """Inverted dropout for a tf.SparseTensor.

        Drops entries of ``x`` with probability ``rate`` and scales the
        survivors by ``1 / (1 - rate)`` so the expected value is preserved.
        """
        random_tensor = 1 - rate
        random_tensor += tf.random.uniform(noise_shape)
        # floor(1 - rate + U) >= 1 with probability (1 - rate) -> keep mask
        dropout_mask = tf.cast(tf.floor(random_tensor), dtype=tf.bool)
        pre_out = tf.sparse.retain(x, dropout_mask)
        return pre_out * (1. / (1 - rate))
class GCN(tf.keras.models.Model):
def __init__(self, input_dim, output_dim, num_features_nonzero, **kwargs):
super(GCN).__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.num_features_nonzero = num_features_nonzero
logger.info('input dim: ', input_dim)
logger.info('output dim:', output_dim)
logger.info('num_features_nonzero: ', num_features_nonzero)
self.layers_ = []
self.layers_.append()
def call(self, inputs, training=None, mask=None):
x, label, mask, support = inputs