forked from tensorflow/addons
-
Notifications
You must be signed in to change notification settings - Fork 0
/
source_code_test.py
228 lines (197 loc) · 9.41 KB
/
source_code_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
import glob
import os
from typedapi import ensure_api_is_typed
import tensorflow_addons as tfa
# Repository root: two directory levels above the directory holding this file.
BASE_DIR = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
def test_api_typed():
    """Check that the public API of the main ``tfa`` modules has type hints.

    The verification itself is delegated to ``typedapi.ensure_api_is_typed``
    (called with ``init_only=True``); the objects in ``exempted_objects`` are
    handed to it as exceptions to the check.
    """
    checked_modules = [
        tfa,
        tfa.activations,
        tfa.callbacks,
        tfa.image,
        tfa.losses,
        tfa.metrics,
        tfa.optimizers,
        tfa.rnn,
        tfa.seq2seq,
        tfa.text,
    ]
    # Objects (classes, not files) exempted from the type-hint verification.
    exempted_objects = [tfa.rnn.PeepholeLSTMCell]

    guide_url = (
        "https://github.com/tensorflow/addons/blob/master/CONTRIBUTING.md"
        "#about-type-hints"
    )
    # Extra text appended to failures, pointing contributors at the guide.
    hint = (
        "You can also take a look at the section about it in the "
        "CONTRIBUTING.md:\n" + guide_url
    )
    ensure_api_is_typed(
        checked_modules, exempted_objects, init_only=True, additional_message=hint
    )
def test_case_insensitive_filesystems():
    """Fail if a directory holds two entries whose names differ only by case.

    Such entries would collide on a case-insensitive filesystem. The whole
    tree under ``BASE_DIR`` is walked, following symlinks, and files and
    subdirectories of each directory are compared together.

    Raises:
        AssertionError: if ``BASE_DIR`` has no ``tensorflow_addons``
            directory (i.e. the root was miscomputed), or if a case-only name
            clash is found.
    """
    # Guard against a wrong BASE_DIR: the project root must contain the
    # tensorflow_addons package directory.
    package_dir = os.path.join(BASE_DIR, "tensorflow_addons")
    if not os.path.isdir(package_dir):
        raise AssertionError("BASE_DIR = {} is not project root".format(BASE_DIR))

    for current_dir, subdir_names, file_names in os.walk(BASE_DIR, followlinks=True):
        folded_names = [name.lower() for name in subdir_names + file_names]
        # The set is smaller only when two names fold to the same lowercase form.
        if len(set(folded_names)) != len(folded_names):
            raise AssertionError(
                "Files with same name but different case detected "
                "in directory: {}".format(current_dir)
            )
def get_lines_of_source_code(allowlist=None):
    """Yield every line of every ``.py`` file under ``BASE_DIR/tensorflow_addons``.

    Args:
        allowlist: Optional list of ``/``-separated path suffixes (e.g.
            ``"tensorflow_addons/metrics/r_square.py"``). Files for which
            ``in_allowlist`` returns True are skipped.

    Yields:
        ``(path, line_idx, line)`` tuples, where ``line_idx`` is the 0-based
        index of ``line`` within ``path`` and ``line`` is the text as read
        from the file (normally including its trailing newline).
    """
    allowlist = allowlist or []
    source_dir = os.path.join(BASE_DIR, "tensorflow_addons")
    pattern = os.path.join(source_dir, "**", "*.py")
    # glob's result order depends on the filesystem; sort so files are scanned
    # (and the first violation is reported) in a reproducible order.
    for path in sorted(glob.glob(pattern, recursive=True)):
        if in_allowlist(path, allowlist):
            continue
        # Python source files are UTF-8 by default (PEP 3120); don't depend on
        # the platform's locale encoding when reading them.
        with open(path, encoding="utf-8") as f:
            for line_idx, line in enumerate(f):
                yield path, line_idx, line
def in_allowlist(file_path, allowlist):
    """Return True if ``file_path`` ends with one of the ``allowlist`` entries.

    Args:
        file_path: Path of the file being checked. It may use the platform's
            native separator (paths built with ``os.path.join``/``glob`` use
            ``\\`` on Windows).
        allowlist: Iterable of ``/``-separated path suffixes, e.g.
            ``"tensorflow_addons/seq2seq/decoder.py"``.

    Returns:
        True if the file matches an entry. False otherwise, including when
        ``allowlist`` is empty.
    """
    # Entries are written with "/", so normalise the native separator first;
    # otherwise nothing would ever match on Windows.
    normalized = file_path.replace(os.sep, "/")
    return normalized.endswith(tuple(allowlist))
def test_no_private_tf_api():
    """Fail if Addons source code imports from ``tensorflow.python``.

    Every line of the scanned sources is searched for the two private-import
    forms. Files in the legacy allowlist below are not scanned.

    Raises:
        ImportError: on the first offending line, with its 1-based line number.
    """
    # TODO: remove all elements of the list and remove the allowlist
    # This allowlist should not grow. Do not add elements to this list.
    legacy_files = [
        "tensorflow_addons/metrics/r_square.py",
        "tensorflow_addons/utils/test_utils.py",
        "tensorflow_addons/seq2seq/decoder.py",
        "tensorflow_addons/seq2seq/attention_wrapper.py",
    ]
    forbidden_snippets = ("import tensorflow.python", "from tensorflow.python")
    template = (
        "A private tensorflow API import was found in {} at line {}.\n"
        "tensorflow.python refers to TensorFlow's internal source "
        "code and private functions/classes.\n"
        "The use of those is forbidden in Addons for stability reasons."
        "\nYou should find a public alternative or ask the "
        "TensorFlow team to expose publicly the function/class "
        "that you are using.\n"
        "If you're trying to do `import tensorflow.python.keras` "
        "it can be replaced with `import tensorflow.keras`."
        ""
    )
    for path, zero_based_idx, text in get_lines_of_source_code(legacy_files):
        if not any(snippet in text for snippet in forbidden_snippets):
            continue
        # Indices from get_lines_of_source_code are 0-based; report 1-based.
        raise ImportError(template.format(path, zero_based_idx + 1))
def test_no_tf_cond():
    """Fail if Addons source code calls ``tf.cond(``.

    Any scanned line containing ``tf.cond(`` is rejected. Files in the legacy
    allowlist below are not scanned.

    Raises:
        NameError: on the first offending line, reporting its 1-based line
            number and its text.
    """
    # TODO: remove all elements of the list and remove the allowlist
    # This allowlist should not grow. Do not add elements to this list.
    allowlist = [
        "tensorflow_addons/text/crf.py",
        "tensorflow_addons/layers/wrappers.py",
        "tensorflow_addons/image/connected_components.py",
        "tensorflow_addons/optimizers/novograd.py",
        "tensorflow_addons/metrics/cohens_kappa.py",
        "tensorflow_addons/seq2seq/sampler.py",
        "tensorflow_addons/seq2seq/beam_search_decoder.py",
    ]
    for file_path, line_idx, line in get_lines_of_source_code(allowlist):
        if "tf.cond(" in line:
            # line_idx comes from enumerate() and is 0-based; report the
            # 1-based number editors show (as test_no_private_tf_api does).
            raise NameError(
                "The usage of a tf.cond() function call was found in "
                "file {} at line {}:\n\n"
                " {}\n"
                "In TensorFlow 2.x, using a simple `if` in a function decorated "
                "with `@tf.function` is equivalent to a tf.cond() thanks to Autograph. \n"
                "TensorFlow Addons aims to be written with idiomatic TF 2.x code. \n"
                "As such, using tf.cond() is not allowed in the codebase. \n"
                "Use a `if` and decorate your function with @tf.function instead. \n"
                "You can take a look at "
                "https://www.tensorflow.org/guide/function#use_python_control_flow"
                "".format(file_path, line_idx + 1, line)
            )
def test_no_experimental_api():
    """Fail if non-test Addons source code contains the word ``experimental``.

    Test files (``*_test.py``, ``conftest.py``) and
    ``tensorflow_addons/utils/test_utils.py`` are skipped, as are the files in
    the legacy allowlist below. Any other line containing the substring
    ``experimental`` is rejected.

    Raises:
        NameError: on the first offending line, reporting its 1-based line
            number and its text.
    """
    # TODO: remove all elements of the list and remove the allowlist
    # This allowlist should not grow. Do not add elements to this list.
    allowlist = [
        "tensorflow_addons/optimizers/weight_decay_optimizers.py",
    ]
    for file_path, line_idx, line in get_lines_of_source_code(allowlist):
        if file_path.endswith("_test.py") or file_path.endswith("conftest.py"):
            continue
        # The suffix below is "/"-separated; normalise native separators
        # (backslashes on Windows) so the skip also works there.
        if file_path.replace(os.sep, "/").endswith(
            "tensorflow_addons/utils/test_utils.py"
        ):
            continue
        if "experimental" in line:
            # line_idx is 0-based; report the 1-based line number.
            raise NameError(
                "The usage of a TensorFlow experimental API was found in file {} "
                "at line {}:\n\n"
                " {}\n"
                "Experimental APIs are ok in tests but not in user-facing code. "
                "This is because Experimental APIs might have bugs and are not "
                "widely used yet.\n"
                "Addons should show how to write TensorFlow "
                "code in a stable and forward-compatible way."
                "".format(file_path, line_idx + 1, line)
            )
def test_no_tf_control_dependencies():
    """Fail if Addons source code calls ``tf.control_dependencies(``.

    Any scanned line containing ``tf.control_dependencies(`` is rejected.
    Files in the legacy allowlist below are not scanned.

    Raises:
        NameError: on the first offending line, reporting its 1-based line
            number and its text.
    """
    # TODO: remove all elements of the list and remove the allowlist
    # This allowlist should not grow. Do not add elements to this list.
    allowlist = [
        "tensorflow_addons/layers/wrappers.py",
        "tensorflow_addons/image/utils.py",
        "tensorflow_addons/image/dense_image_warp.py",
        "tensorflow_addons/optimizers/average_wrapper.py",
        "tensorflow_addons/optimizers/yogi.py",
        "tensorflow_addons/optimizers/lookahead.py",
        "tensorflow_addons/optimizers/weight_decay_optimizers.py",
        "tensorflow_addons/optimizers/rectified_adam.py",
        "tensorflow_addons/optimizers/lamb.py",
        "tensorflow_addons/seq2seq/sampler.py",
        "tensorflow_addons/seq2seq/beam_search_decoder.py",
        "tensorflow_addons/seq2seq/attention_wrapper.py",
    ]
    for file_path, line_idx, line in get_lines_of_source_code(allowlist):
        if "tf.control_dependencies(" in line:
            # line_idx is 0-based; report the 1-based line number.
            raise NameError(
                "The usage of a tf.control_dependencies() function call was found in "
                "file {} at line {}:\n\n"
                " {}\n"
                "In TensorFlow 2.x, in a function decorated "
                "with `@tf.function` the dependencies are controlled automatically"
                " thanks to Autograph. \n"
                "TensorFlow Addons aims to be written with idiomatic TF 2.x code. \n"
                "As such, using tf.control_dependencies() is not allowed in the codebase. \n"
                "Decorate your function with @tf.function instead. \n"
                "You can take a look at \n"
                "https://github.com/tensorflow/community/blob/master/rfcs/20180918-functions-not-sessions-20.md#program-order-semantics--control-dependencies"
                "".format(file_path, line_idx + 1, line)
            )
def test_no_deprecated_v1():
    """Fail if Addons source code references ``tf.compat.v1``.

    Any scanned line containing ``tf.compat.v1`` is rejected. Files in the
    legacy allowlist below are not scanned.

    Raises:
        NameError: on the first offending line, reporting its 1-based line
            number and its text.
    """
    # TODO: remove all elements of the list and remove the allowlist
    # This allowlist should not grow. Do not add elements to this list.
    allowlist = [
        "tensorflow_addons/text/skip_gram_ops.py",
        "tensorflow_addons/seq2seq/decoder.py",
        "tensorflow_addons/seq2seq/tests/attention_wrapper_test.py",
    ]
    for file_path, line_idx, line in get_lines_of_source_code(allowlist):
        if "tf.compat.v1" in line:
            # line_idx is 0-based; report the 1-based line number.
            raise NameError(
                "The usage of a tf.compat.v1 API was found in file {} at line {}:\n\n"
                " {}\n"
                "TensorFlow Addons doesn't support running programs with "
                "`tf.compat.v1.disable_v2_behavior()`.\n"
                "As such, there should be no need for the compatibility module "
                "tf.compat. Please find an alternative using only the TF2.x API."
                "".format(file_path, line_idx + 1, line)
            )