import click
import json
import os
import sys
import time

import tensorflow as tf
# tf.enable_eager_execution()
from tensorflow.python import debug as tf_debug

from luminoth.datasets import get_dataset
from luminoth.datasets.exceptions import InvalidDataDirectory
from luminoth.models import get_model
from luminoth.utils.config import get_config
from luminoth.utils.hooks import ImageVisHook, VarVisHook
from luminoth.utils.training import get_optimizer, clip_gradients_by_norm
from luminoth.utils.experiments import save_run

import pause


def run(config, target='', cluster_spec=None, is_chief=True, job_name=None,
        task_index=None, get_model_fn=get_model, get_dataset_fn=get_dataset,
        environment=None):
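    """Build the training graph and run the training loop.

    When called from `train()` below, `target`, `cluster_spec`, `is_chief`,
    `job_name` and `task_index` describe the (optional) distributed setup,
    while `get_model_fn` and `get_dataset_fn` allow injecting alternative
    model and dataset factories. Returns the last global step that was run.
    """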
    model_class = get_model_fn(config.model.type)

    image_vis = config.train.get('image_vis')
    var_vis = config.train.get('var_vis')

    if config.train.get('seed') is not None:
        tf.set_random_seed(config.train.seed)

    log_prefix = '[{}-{}] - '.format(job_name, task_index) \
        if job_name is not None and task_index is not None else ''

    if config.train.debug or config.train.tf_debug:
        tf.logging.set_verbosity(tf.logging.DEBUG)
    else:
        tf.logging.set_verbosity(tf.logging.INFO)

    model = model_class(config)
    # print("model construct end !!!!")
    # pause.seconds(100000)

    # Placement of ops on devices using replica device setter
    # which automatically places the parameters on the `ps` server
    # and the `ops` on the workers
    #
    # See:
    # https://www.tensorflow.org/api_docs/python/tf/train/replica_device_setter
    with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)):
        try:
            config['dataset']['type']
        except KeyError:
            raise KeyError('dataset.type should be set on the custom config.')

        try:
            dataset_class = get_dataset_fn(config.dataset.type)
            dataset = dataset_class(config)
            train_dataset = dataset()
        except InvalidDataDirectory as exc:
            tf.logging.error(
                "Error while reading dataset, {}".format(exc)
            )
            sys.exit(1)

        train_image = train_dataset['image']
        train_filename = train_dataset['filename']
        train_bboxes = train_dataset['bboxes']

        prediction_dict = model(train_image, train_bboxes, is_training=True)
        total_loss = model.loss(prediction_dict)
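
        # If the model exposes `partial_reduce_pred_list`, let it reduce its
        # raw prediction output here before it is used further below.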
        if hasattr(model, "partial_reduce_pred_list"):
            print("perform partial reduce !!!!!")
            prediction_dict = model.partial_reduce_pred_list(prediction_dict)

        global_step = tf.train.get_or_create_global_step()

        optimizer = get_optimizer(config.train, global_step)

        # TODO: Is this necessary? Couldn't we just get them from the
        # trainable vars collection? We should probably improve our
        # usage of collections.
        trainable_vars = model.get_trainable_vars()

        # Compute, clip and apply gradients
        with tf.name_scope('gradients'):
            grads_and_vars = optimizer.compute_gradients(
                total_loss, trainable_vars
            )

            if config.train.clip_by_norm:
                grads_and_vars = clip_gradients_by_norm(grads_and_vars)
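
            # Run any ops registered in UPDATE_OPS (e.g. batch norm moving
            # average updates) before applying the gradients.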
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer.apply_gradients(
                    grads_and_vars, global_step=global_step
                )

        # Create custom init for slots in optimizer, as we don't save them to
        # our checkpoints. An example of slots in an optimizer are the Momentum
        # variables in MomentumOptimizer. We do this because slot variables can
        # effectively duplicate the size of your checkpoint!
        slot_variables = [
            optimizer.get_slot(var, name)
            for name in optimizer.get_slot_names()
            for var in trainable_vars
        ]
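        # Variables without a slot come back as None; drop them before
        # building the initializer.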
        slot_variables = list(filter(lambda var: var, slot_variables))
        slot_init = tf.variables_initializer(
            slot_variables,
            name='optimizer_slots_initializer'
        )

        # Create saver for saving/restoring model
        model_saver = tf.train.Saver(
            set(tf.global_variables()) - set(slot_variables),
            name='model_saver',
            max_to_keep=config.train.get('checkpoints_max_keep', 1),
        )

        # Create saver for loading pretrained checkpoint into base network
        base_checkpoint_vars = model.get_base_network_checkpoint_vars()
        checkpoint_file = model.get_checkpoint_file()
        if base_checkpoint_vars and checkpoint_file:
            base_net_checkpoint_saver = tf.train.Saver(
                base_checkpoint_vars,
                name='base_net_checkpoint_saver'
            )

            # We'll send this fn to Scaffold init_fn
            def load_base_net_checkpoint(_, session):
                base_net_checkpoint_saver.restore(
                    session, checkpoint_file
                )
        else:
            load_base_net_checkpoint = None

    tf.logging.info('{}Starting training for {}'.format(log_prefix, model))

    run_options = None
    if config.train.full_trace:
        run_options = tf.RunOptions(
            trace_level=tf.RunOptions.FULL_TRACE
        )

    # Create custom Scaffold to make sure we run our own init_op when model
    # is not restored from checkpoint.
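
    # Merge the model's own summary with any other summaries registered in
    # the graph so that the Scaffold writes them all together.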
    summary_op = [model.summary]
    summaries = tf.summary.merge_all()
    if summaries is not None:
        summary_op.append(summaries)
    summary_op = tf.summary.merge(summary_op)

    # `ready_for_local_init_op` is hardcoded to 'ready' as local init doesn't
    # depend on global init and `local_init_op` only runs when it is set as
    # 'ready' (an empty string tensor sets it as ready).
    scaffold = tf.train.Scaffold(
        saver=model_saver,
        init_op=tf.global_variables_initializer() if is_chief else tf.no_op(),
        local_init_op=tf.group(tf.initialize_local_variables(), slot_init),
        ready_for_local_init_op=tf.constant([], dtype=tf.string),
        summary_op=summary_op,
        init_fn=load_base_net_checkpoint,
    )

    # Custom hooks for our session
    hooks = []
    chief_only_hooks = []

    if config.train.tf_debug:
        debug_hook = tf_debug.LocalCLIDebugHook()
        debug_hook.add_tensor_filter(
            'has_inf_or_nan', tf_debug.has_inf_or_nan
        )
        hooks.extend([debug_hook])

    if not config.train.job_dir:
        tf.logging.warning(
            '`job_dir` is not defined. Checkpoints and logs will not be saved.'
        )
        checkpoint_dir = None
    elif config.train.run_name:
        # Use run_name when available
        checkpoint_dir = os.path.join(
            config.train.job_dir, config.train.run_name
        )
    else:
        checkpoint_dir = config.train.job_dir
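
    # Visualization hooks are only added when a display interval is
    # configured; they write their output under `checkpoint_dir`.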
    should_add_hooks = (
        (config.train.display_every_steps
         or config.train.display_every_secs)
        and checkpoint_dir is not None
    )
    if should_add_hooks:
        if not config.train.debug and image_vis == 'debug':
            tf.logging.warning('ImageVisHook will not run without debug mode.')
        elif image_vis is not None:
            # ImageVis only runs on the chief.
            # if "prediction_1_dict" in prediction_dict:
            if isinstance(prediction_dict, list):
                if hasattr(model, "partial_reduce_pred_list"):
                    prediction_dict = prediction_dict[0]
                else:
                    prediction_dict = prediction_dict[1]
                chief_only_hooks.append(
                    ImageVisHook(
                        prediction_dict,
                        image=prediction_dict["image"],
                        gt_bboxes=prediction_dict["gt_boxes"],
                        config=config.model,
                        output_dir=checkpoint_dir,
                        every_n_steps=config.train.display_every_steps,
                        every_n_secs=config.train.display_every_secs,
                        image_visualization_mode=image_vis
                    )
                )
            else:
                chief_only_hooks.append(
                    ImageVisHook(
                        prediction_dict,
                        image=train_dataset['image'],
                        gt_bboxes=train_dataset['bboxes'],
                        config=config.model,
                        output_dir=checkpoint_dir,
                        every_n_steps=config.train.display_every_steps,
                        every_n_secs=config.train.display_every_secs,
                        image_visualization_mode=image_vis
                    )
                )

        if var_vis is not None:
            # VarVis only runs on the chief.
            chief_only_hooks.append(
                VarVisHook(
                    every_n_steps=config.train.display_every_steps,
                    every_n_secs=config.train.display_every_secs,
                    mode=var_vis,
                    output_dir=checkpoint_dir,
                    vars_summary=model.vars_summary,
                )
            )
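
    # MonitoredTrainingSession handles session creation and recovery, periodic
    # checkpoints and summaries, and runs the hooks registered above.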
    step = -1
    with tf.train.MonitoredTrainingSession(
        master=target,
        is_chief=is_chief,
        checkpoint_dir=checkpoint_dir,
        scaffold=scaffold,
        hooks=hooks,
        chief_only_hooks=chief_only_hooks,
        save_checkpoint_secs=config.train.save_checkpoint_secs,
        save_summaries_steps=config.train.save_summaries_steps,
        save_summaries_secs=config.train.save_summaries_secs,
    ) as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            while not coord.should_stop():
                before = time.time()
                _, train_loss, step, filename = sess.run([
                    train_op, total_loss, global_step, train_filename
                ], options=run_options)

                # TODO: Add image summary every once in a while.
                tf.logging.info(
                    '{}step: {}, file: {}, train_loss: {}, in {:.2f}s'.format(
                        log_prefix, step, filename, train_loss,
                        time.time() - before
                    ))

                if is_chief and step == 1:
                    # We save the run after the first batch to make sure
                    # everything works properly.
                    save_run(config, environment=environment)
        except tf.errors.OutOfRangeError:
            tf.logging.info(
                '{}finished training after {} epoch limit'.format(
                    log_prefix, config.train.num_epochs
                )
            )
            # TODO: Print summary
        finally:
            coord.request_stop()

        # Wait for all threads to stop.
        coord.join(threads)

        return step


@click.command(help='Train models')
@click.option('config_files', '--config', '-c', required=True, multiple=True, help='Config to use.') # noqa
@click.option('--job-dir', help='Job directory.')
@click.option('override_params', '--override', '-o', multiple=True, help='Override model config params.') # noqa
def train(config_files, job_dir, override_params):
"""
Parse TF_CONFIG to cluster_spec and call run() function
"""
    # TF_CONFIG environment variable is available when running using gcloud
    # either locally or on cloud. It has all the information required to
    # create a ClusterSpec, which is important for running distributed code.
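    # An illustrative TF_CONFIG value (structure only):
    #   {"cluster": {"ps": ["host1:2222"], "worker": ["host2:2222"]},
    #    "task": {"type": "worker", "index": 0},
    #    "environment": "cloud"}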
    tf_config_val = os.environ.get('TF_CONFIG')

    if tf_config_val:
        tf_config = json.loads(tf_config_val)
    else:
        tf_config = {}

    cluster = tf_config.get('cluster')
    job_name = tf_config.get('task', {}).get('type')
    task_index = tf_config.get('task', {}).get('index')
    environment = tf_config.get('environment', 'local')
    if job_dir:
        override_params += ('train.job_dir={}'.format(job_dir), )

    # Get the user config and the model type from it.
    try:
        config = get_config(config_files, override_params=override_params)
    except KeyError:
        # Without model type defined we can't use the default config settings.
        raise KeyError('model.type should be set on the custom config.')
    # If cluster information is empty or TF_CONFIG is not available, run local
    if job_name is None or task_index is None:
        return run(
            config, environment=environment
        )

    cluster_spec = tf.train.ClusterSpec(cluster)
    server = tf.train.Server(
        cluster_spec, job_name=job_name, task_index=task_index)

    # Wait for incoming connections forever
    # Worker ships the graph to the ps server
    # The ps server manages the parameters of the model.
    if job_name == 'ps':
        server.join()
        return
    elif job_name in ['master', 'worker']:
        is_chief = job_name == 'master'
        return run(
            config, target=server.target, cluster_spec=cluster_spec,
            is_chief=is_chief, job_name=job_name, task_index=task_index,
            environment=environment
        )


if __name__ == '__main__':
    train()