Using keras for Distributed training raise RuntimeError("Graph is finalized and cannot be modified.") #3997
Comments
I have the same problem (Graph is finalized and cannot be modified.). Can anyone help me?
I use TensorFlow instead of Keras, and I run into the same problem.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.
Did anyone manage to make this code work? Or is it even possible to make it work this way?
The graph gets finalized by calling tf.train.Supervisor. We need to compile the model and call _make_train_function (and, if you need them, _make_test_function and _make_predict_function) before calling the Supervisor. I updated the code above, it should run now.

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Created by Enigma on 2016/9/26
import numpy as np
import tensorflow as tf

# Define Hyperparameters
FLAGS = tf.app.flags.FLAGS

# Cluster / job flags
tf.app.flags.DEFINE_string("ps_hosts", "",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")

from keras import backend as K
from keras.layers import Input, Dense
from keras.models import Model


def main(_):
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    server_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(allow_growth=True),
        log_device_placement=True)
    server = tf.train.Server(cluster, config=server_config,
                             job_name=FLAGS.job_name, task_index=FLAGS.task_index)
    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d/cpu:0" % FLAGS.task_index,
                cluster=cluster)):
            global_step = tf.Variable(0, name='global_step', trainable=False)
            inputs = Input(shape=[1, ])
            hidden = Dense(10, activation='relu')(inputs)
            output = Dense(1, activation='sigmoid')(hidden)
            model = Model(input=inputs, output=output)
            saver = tf.train.Saver()
            # Compile and build the train/test functions *before* creating the
            # Supervisor, which finalizes the graph.
            model.compile(optimizer='sgd', loss='mse')
            model._make_train_function()
            model._make_test_function()
        sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                                 logdir="./checkpoint/",
                                 saver=saver,
                                 global_step=global_step,
                                 save_model_secs=60)
        with sv.managed_session(server.target) as sess:
            step = 0
            K.set_session(sess)
            K.manual_variable_initialization(True)
            while step < 1000000:
                train_x = np.random.randn(1)
                train_y = 2 * train_x + np.random.randn(1) * 0.33 + 10
                model.fit(train_x, train_y)
            sv.stop()


if __name__ == "__main__":
    tf.app.run()
@PBehr This is awesome! Any ideas how to increment global step?
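One possible way to do that (a sketch, not from this thread, assuming TF 1.x): create an assign_add op on global_step before the Supervisor finalizes the graph, and run it once per training step. A minimal standalone illustration:

import tensorflow as tf

global_step = tf.Variable(0, name='global_step', trainable=False)
# The increment op must be created before the graph is finalized,
# i.e. before tf.train.Supervisor is constructed.
increment_global_step = tf.assign_add(global_step, 1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        print(sess.run(increment_global_step))  # prints 1, 2, 3

In the distributed code above, the same op would be run inside the training loop, e.g. step = sess.run(increment_global_step) after each model.fit call.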
I'm getting an error (below) from your code which I think has to do with the feed_dict. Any idea how to solve this? Did you run into this with your code?
I'm using Keras for distributed training with the following code:
Then I run it with:
It doesn't work and raises RuntimeError("Graph is finalized and cannot be modified.").
I'm wondering if this happens because the Keras model wasn't created as part of the graph used by tf.train.Supervisor, but I don't have a clue how to prove or fix it. Any ideas?
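For reference, here is a minimal sketch (assuming TF 1.x graph semantics; not code from this issue) of how a finalized graph produces exactly this error. tf.train.Supervisor finalizes the default graph, so any Keras layer or op created after it fails the same way:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    a = tf.constant(1)

graph.finalize()  # tf.train.Supervisor does this internally

with graph.as_default():
    b = tf.constant(2)  # RuntimeError: Graph is finalized and cannot be modified.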