r"""Train with gradient checkpointing
Allows for training with larger batch size than would normally be
permitted. Used to train PSPNet and ICNet to facilitate the required
batch size of 16. Also facilities measuring memory usage. Thanks to Yaroslav
Bulatov and Tim Salimans for their gradient checkpointing implementation.
Their implementation can be found here:
https://github.com/openai/gradient-checkpointing
For ICNet, the suggested checkpoint nodes are:
SharedFeatureExtractor/resnet_v1_50/block1/unit_3/bottleneck_v1/Relu
SharedFeatureExtractor/resnet_v1_50/block2/unit_4/bottleneck_v1/Relu
SharedFeatureExtractor/resnet_v1_50/block3/unit_6/bottleneck_v1/Relu
SharedFeatureExtractor/resnet_v1_50/block4/unit_3/bottleneck_v1/Relu
FastPSPModule/Conv/Relu
CascadeFeatureFusion/Relu
CascadeFeatureFusion_1/Relu
For PSPNet 50, the suggested checkpoint nodes are:
SharedFeatureExtractor/resnet_v1_50/block1/unit_3/bottleneck_v1/Relu
SharedFeatureExtractor/resnet_v1_50/block2/unit_4/bottleneck_v1/Relu
SharedFeatureExtractor/resnet_v1_50/block3/unit_6/bottleneck_v1/Relu
SharedFeatureExtractor/resnet_v1_50/block4/unit_3/bottleneck_v1/Relu
PSPModule/Conv/Relu
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import tensorflow as tf
from google.protobuf import text_format
from third_party.memory_saving_gradients_patch import gradients
from builders import model_builder
from builders import dataset_builder
from protos import pipeline_pb2
from libs.trainer import train_segmentation_model
# CHECKPOINT NODES FOR MEMORY-SAVING GRADIENTS
ICNET_GRADIENT_CHECKPOINTS = [
'SharedFeatureExtractor/resnet_v1_50/block1/unit_3/bottleneck_v1/Relu',
'SharedFeatureExtractor/resnet_v1_50/block2/unit_4/bottleneck_v1/Relu',
'SharedFeatureExtractor/resnet_v1_50/block3/unit_6/bottleneck_v1/Relu',
'SharedFeatureExtractor/resnet_v1_50/block4/unit_3/bottleneck_v1/Relu',
'FastPSPModule/Conv/Relu6',
'CascadeFeatureFusion/Relu',
'CascadeFeatureFusion_1/Relu'
]
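# During the backward pass, memory-saving gradients keep activations only at
# the nodes listed above and recompute the intermediate activations between
# them, trading extra compute time for a lower peak memory footprint.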
tf.logging.set_verbosity(tf.logging.INFO)
slim = tf.contrib.slim
prefetch_queue = slim.prefetch_queue
flags = tf.app.flags
FLAGS = flags.FLAGS
# Distributed training settings
flags.DEFINE_integer('num_clones', 1,
                     'Number of model clones to deploy to each worker '
                     'replica. This should be greater than one if you want '
                     'to use multiple GPUs located on a single machine.')
flags.DEFINE_boolean('clone_on_cpu', False, 'Use CPUs to deploy clones.')
flags.DEFINE_integer('num_replicas', 1,
                     'Number of worker replicas. This typically corresponds '
                     'to the number of machines you are training on. Note '
                     'that the training will be done asynchronously.')
flags.DEFINE_integer('startup_delay_steps', 15,
                     'Number of training steps between replica startups.')
flags.DEFINE_integer('num_ps_tasks', 0,
                     'The number of parameter servers. If the value is 0, '
                     'then the parameters are handled locally by the worker. '
                     'It is recommended to use num_ps_tasks=num_replicas/2.')
flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
flags.DEFINE_integer('task', 0, 'The task ID. Should be incremented for each '
                     'worker replica added to achieve between-graph '
                     'replication.')
# Training configuration settings
flags.DEFINE_string('config_path', '',
                    'Path to a pipeline_pb2.PipelineConfig config '
                    'file. If provided, other configs are ignored.')
flags.mark_flag_as_required('config_path')
flags.DEFINE_string('logdir', '',
                    'Directory to save the checkpoints and training '
                    'summaries.')
flags.mark_flag_as_required('logdir')
flags.DEFINE_integer('save_interval_secs', 600,  # defaults to 10 min
                     'Time between successive saves of a checkpoint in secs.')
flags.DEFINE_integer('max_checkpoints_to_keep', 50,  # might want to cut this down
                     'Number of checkpoints to keep in the `logdir`.')
flags.DEFINE_string('checkpoint_nodes', None,
                    'Comma-separated names of nodes to use as gradient '
                    'checkpoint nodes. Defaults to the ICNet checkpoint '
                    'nodes defined above.')
# Debug flags
flags.DEFINE_boolean('log_memory', False,
                     'Whether to log memory usage during training.')
flags.DEFINE_boolean('image_summaries', False,
                     'Whether to write image summaries during training.')
def main(_):
  tf.gfile.MakeDirs(FLAGS.logdir)
  pipeline_config = pipeline_pb2.PipelineConfig()
  with tf.gfile.GFile(FLAGS.config_path, "r") as f:
    proto_str = f.read()
    text_format.Merge(proto_str, pipeline_config)
  model_config = pipeline_config.model
  train_config = pipeline_config.train_config
  input_config = pipeline_config.train_input_reader

  create_model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=True)
  create_input_fn = functools.partial(
      dataset_builder.build,
      input_reader_config=input_config)

  is_chief = (FLAGS.task == 0)
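  # Fall back to the default ICNet checkpoint nodes unless --checkpoint_nodes
  # is set; the flag value is parsed as a comma-separated list of node names.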
  checkpoint_nodes = FLAGS.checkpoint_nodes
  if checkpoint_nodes is None:
    checkpoint_nodes = ICNET_GRADIENT_CHECKPOINTS
  else:
    checkpoint_nodes = checkpoint_nodes.replace(" ", "").split(",")

  train_segmentation_model(
      create_model_fn,
      create_input_fn,
      train_config,
      master=FLAGS.master,
      task=FLAGS.task,
      is_chief=is_chief,
      startup_delay_steps=FLAGS.startup_delay_steps,
      train_dir=FLAGS.logdir,
      num_clones=FLAGS.num_clones,
      num_worker_replicas=FLAGS.num_replicas,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks,
      max_checkpoints_to_keep=FLAGS.max_checkpoints_to_keep,
      save_interval_secs=FLAGS.save_interval_secs,
      image_summaries=FLAGS.image_summaries,
      log_memory=FLAGS.log_memory,
      gradient_checkpoints=checkpoint_nodes)
if __name__ == '__main__':
  tf.app.run()