I am using code built on top of `train.py` and `infer.py` (both files unchanged) from the seq2seq tutorial. I want to do transfer learning, i.e. load model weights from an existing checkpoint, but I am unfamiliar with the `tf.contrib.learn.Estimator` and `seq2seq.contrib.experiment` environment.

I basically want to incorporate the checkpoint-loading step from `infer.py` into training:
```python
saver = tf.train.Saver()
checkpoint_path = FLAGS.checkpoint_path
if not checkpoint_path:
  checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)

def session_init_op(_scaffold, sess):
  saver.restore(sess, checkpoint_path)
  tf.logging.info("Restored model from %s", checkpoint_path)
```
How/where along the pipeline should I insert this snippet? Below is `create_experiment` from `train.py`, with the restore code pasted where I guessed it might belong:
```python
def create_experiment(output_dir):
  """
  Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.
  """
  config = run_config.RunConfig(
      tf_random_seed=FLAGS.tf_random_seed,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
      gpu_memory_fraction=FLAGS.gpu_memory_fraction)
  config.tf_config.gpu_options.allow_growth = FLAGS.gpu_allow_growth
  config.tf_config.log_device_placement = FLAGS.log_device_placement

  train_options = training_utils.TrainOptions(
      model_class=FLAGS.model,
      model_params=FLAGS.model_params)

  # On the main worker, save training options
  if config.is_chief:
    gfile.MakeDirs(output_dir)
    train_options.dump(output_dir)

  bucket_boundaries = None
  if FLAGS.buckets:
    bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

  # Training data input pipeline
  train_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_train,
      mode=tf.contrib.learn.ModeKeys.TRAIN)

  # Create training input function
  train_input_fn = training_utils.create_input_fn(
      pipeline=train_input_pipeline,
      batch_size=FLAGS.batch_size,
      bucket_boundaries=bucket_boundaries,
      scope="train_input_fn")

  # Development data input pipeline
  dev_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_dev,
      mode=tf.contrib.learn.ModeKeys.EVAL,
      shuffle=False, num_epochs=1)

  # Create eval input function
  eval_input_fn = training_utils.create_input_fn(
      pipeline=dev_input_pipeline,
      batch_size=FLAGS.batch_size,
      allow_smaller_final_batch=True,
      scope="dev_input_fn")

  def model_fn(features, labels, params, mode):
    """Builds the model graph"""
    model = _create_from_dict({
        "class": train_options.model_class,
        "params": train_options.model_params
    }, models, mode=mode)
    return model(features, labels, params)

  estimator = tf.contrib.learn.Estimator(
      model_fn=model_fn,
      model_dir=output_dir,
      config=config,
      params=FLAGS.model_params)

  # Create hooks
  train_hooks = []
  for dict_ in FLAGS.hooks:
    hook = _create_from_dict(
        dict_, hooks,
        model_dir=estimator.model_dir,
        run_config=config)
    train_hooks.append(hook)

  # Create metrics
  eval_metrics = {}
  for dict_ in FLAGS.metrics:
    metric = _create_from_dict(dict_, metric_specs)
    eval_metrics[metric.name] = metric
  # The checkpoint-load step copied from infer.py, pasted here as a guess.
  # Note: `sess` does not exist in this scope; the restore would have to
  # happen inside whatever session the Estimator later creates.
  saver = tf.train.Saver()
  checkpoint_path = FLAGS.checkpoint_path
  if not checkpoint_path:
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
  saver.restore(sess, checkpoint_path)

  # what is PatchedExperiment?
  experiment = PatchedExperiment(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      min_eval_frequency=FLAGS.eval_every_n_steps,
      train_steps=FLAGS.train_steps,
      eval_steps=None,
      eval_metrics=eval_metrics,
      train_monitors=train_hooks)

  return experiment
```
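One way to get the same effect inside this pipeline (a minimal sketch, not from the seq2seq codebase; the `RestoreFromCheckpointHook` class and its wiring are hypothetical) is to wrap the restore in a `tf.train.SessionRunHook` and append it to `train_hooks`, which already reaches the training loop via `train_monitors`:

```python
import tensorflow as tf

class RestoreFromCheckpointHook(tf.train.SessionRunHook):
  """Hypothetical hook that replays infer.py's restore step during training."""

  def __init__(self, checkpoint_path):
    self._checkpoint_path = checkpoint_path
    self._saver = None

  def begin(self):
    # begin() runs after the model graph has been built but before it is
    # finalized, so all model variables exist and a Saver can be created.
    self._saver = tf.train.Saver()

  def after_create_session(self, session, coord):
    # Counterpart of session_init_op in infer.py, applied to the session
    # the Estimator creates for training.
    self._saver.restore(session, self._checkpoint_path)
    tf.logging.info("Restored model from %s", self._checkpoint_path)

# Inside create_experiment, before constructing PatchedExperiment.
# Flags as in the infer.py snippet above; train.py would need to define them.
checkpoint_path = FLAGS.checkpoint_path
if not checkpoint_path:
  checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
if checkpoint_path:
  train_hooks.append(RestoreFromCheckpointHook(checkpoint_path))
```

Two caveats worth hedging: the `tf.contrib.learn` training loop already restores automatically from the latest checkpoint in `output_dir`, so a hook like this mainly matters when warm-starting a fresh `output_dir` from weights trained elsewhere; and for partial or renamed-variable transfer, `tf.contrib.framework.init_from_checkpoint` with an assignment map is an alternative to a full `Saver.restore`.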