Add big TF model support #2974

Merged: 32 commits, Oct 23, 2020
@@ -27,7 +27,7 @@
from zoo import init_nncontext
from zoo.orca.data.tf.data import Dataset
from zoo.orca.learn.tf.estimator import Estimator
from zoo.orca.learn.tf.utils import save_tf_checkpoint, load_tf_checkpoint, get_checkpoint_state
from zoo.util.tf import save_tf_checkpoint, load_tf_checkpoint, get_checkpoint_state

resource_path = os.path.join(os.path.split(__file__)[0], "../../../resources")

1 change: 1 addition & 0 deletions pyzoo/zoo/orca/learn/tf/estimator.py
@@ -25,6 +25,7 @@
from zoo.tfpark.tf_optimizer import StatelessMetric
from zoo.tfpark.utils import evaluate_metrics
from zoo.util import nest
from zoo.util.tf import save_tf_checkpoint


class Estimator(object):
151 changes: 29 additions & 122 deletions pyzoo/zoo/orca/learn/tf/utils.py
@@ -105,125 +105,32 @@ def transform_predict(iter):
    return SparkXShards(prediction_rdd.mapPartitions(transform_predict))


def save_tf_checkpoint(sess, checkpoint_path, saver=None):
    """
    Save tf checkpoint without using native tensorflow remote access method.
    :param sess: tf session to be saved.
    :param checkpoint_path: checkpoint path. Could be local, hdfs, s3 filesystems.
    :param saver: tf saver to save checkpoint
    """
    if is_local_path(checkpoint_path):
        if saver is None:
            saver = tf.train.Saver()
        saver.save(sess, checkpoint_path)
    else:
        ckpt_name = basename(checkpoint_path)
        remote_dir = dirname(checkpoint_path)
        # save to local checkpoint
        temp = tempfile.mkdtemp()
        if saver is None:
            saver = tf.train.Saver()
        saver.save(sess, join(temp, ckpt_name))
        # change checkpoint file
        with open(join(temp, "checkpoint")) as f:
            new_lines = []
            lines = f.readlines()
            # replace model_checkpoint_path and all_model_checkpoint_paths to checkpoint name
            # instead of the absolute checkpoint path
            for line in lines:
                if re.compile("^model_checkpoint_path: \"(.*)\"$").match(line):
                    new_lines.append("model_checkpoint_path: \"{}\"\n".format(ckpt_name))
                elif re.compile("^all_model_checkpoint_paths: \"(.*)\"$").match(line):
                    new_lines.append("all_model_checkpoint_paths: \"{}\"\n".format(ckpt_name))
                else:
                    new_lines.append(line)
        with open(join(temp, "checkpoint"), 'w') as f:
            f.writelines(new_lines)
        # move to remote
        [put_local_file_to_remote(join(temp, file), join(remote_dir, file), over_write=True)
         for file in os.listdir(temp)]
        shutil.rmtree(temp)


def get_checkpoint_state(checkpoint_dir):
    """
    Get tf checkpoint state from checkpoint directory without using native tensorflow accessing
    remote method.
    :param checkpoint_dir: tensorflow checkpoint directory. Could be local, hdfs, s3 filesystems.
    :return: tf checkpoint protobuf
    """
    if is_local_path(checkpoint_dir):
        return tf.train.get_checkpoint_state(checkpoint_dir)
    else:
        # check if checkpoint file exists
        file_list = get_file_list(checkpoint_dir)
        has_checkpoint = False
        for file in file_list:
            if basename(file) == 'checkpoint':
                has_checkpoint = True
                break
        if not has_checkpoint:
            return None
        # get checkpoint file
        temp = tempfile.mkdtemp()
        get_remote_file_to_local(join(checkpoint_dir, "checkpoint"), join(temp, "checkpoint"))
        ckpt_name = None
        with open(join(temp, "checkpoint")) as f:
            lines = f.readlines()
            # get checkpoint name from 'checkpoint' file
            for line in lines:
                m = re.compile("^model_checkpoint_path: \"(.*)\"$").match(line)
                if m:
                    ckpt_name = m.group(1)
                    break
        if ckpt_name is None:
            shutil.rmtree(temp)
            return None
        # filter checkpoint files
        checkpoint_files = [file for file in file_list if basename(file).startswith(ckpt_name)]
        if not checkpoint_files:
            shutil.rmtree(temp)
            return None
        # get checkpoint files to local
        [get_remote_file_to_local(file, join(temp, basename(file))) for file in checkpoint_files]
        # get checkpoint state
        ckpt = tf.train.get_checkpoint_state(temp)
        if not ckpt:
            shutil.rmtree(temp)
            return None
        ckpt.model_checkpoint_path = join(checkpoint_dir, ckpt_name)
        ckpt.all_model_checkpoint_paths[:] = [join(checkpoint_dir, ckpt_name)]
        shutil.rmtree(temp)
        return ckpt


def load_tf_checkpoint(sess, checkpoint_path, saver=None):
    """
    Load tensorflow checkpoint from checkpoint path without using native tensorflow accessing
    remote method.
    :param sess: tensorflow session to be loaded to.
    :param checkpoint_path: tensorflow checkpoint path. Could be local, hdfs, s3 filesystems.
    :param saver: tensorflow saver to load checkpoint
    """
    if is_local_path(checkpoint_path):
        if saver is None:
            saver = tf.train.Saver()
        saver.restore(sess, checkpoint_path)
    else:
        ckpt_name = basename(checkpoint_path)
        checkpoint_dir = dirname(checkpoint_path)
        # get remote file lists
        file_list = get_file_list(checkpoint_dir)
        # filter checkpoint files
        checkpoint_files = [file for file in file_list if basename(file).startswith(ckpt_name)]
        # get checkpoint files to local
        temp = tempfile.mkdtemp()
        [get_remote_file_to_local(file, join(temp, basename(file))) for file in checkpoint_files]
        if saver is None:
            saver = tf.train.Saver()
        try:
            saver.restore(sess, join(temp, ckpt_name))
        except Exception as e:
            raise e
        finally:
            shutil.rmtree(temp)
def find_latest_checkpoint(model_dir):
    import os
    import re
    import datetime
    ckpt_path = None
    latest_version = None
    for (root, dirs, files) in os.walk(model_dir, topdown=True):
        temp_versions = []
        timestamps = []
        for dir_name in dirs:
            if re.match(r'(\d{4})-(\d{2})-(\d{2})_(\d{2})-(\d{2})-(\d{2})$', dir_name) is not None:
                try:
                    # check if the dir name is a valid date time
                    datetime.datetime.strptime(dir_name, '%Y-%m-%d_%H-%M-%S')
                    timestamps.append(dir_name)
                except ValueError:
                    continue
        if timestamps:
            start_dir = os.path.join(root, max(timestamps))
            return find_latest_checkpoint(start_dir)
        for file_name in files:
            if re.match(r"^optimMethod-TFParkTraining\.[0-9]+$", file_name) is not None:
                version = int(file_name.split(".")[1])
                temp_versions.append(version)
        if temp_versions:
            ckpt_path = root
            latest_version = max(temp_versions)
            break
    return ckpt_path, latest_version
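
A quick usage sketch for the new helper. The model_dir path below is hypothetical; the layout assumed is the TFPark one the function scans for, i.e. timestamped YYYY-MM-DD_HH-MM-SS run subdirectories holding optimMethod-TFParkTraining.<iteration> files:

from zoo.orca.learn.tf.utils import find_latest_checkpoint

# Walks model_dir, descends into the newest timestamped subdirectory,
# and returns that directory plus the highest snapshot iteration in it.
ckpt_path, latest_version = find_latest_checkpoint("/tmp/model_dir")
if ckpt_path is None:
    print("no checkpoint found")
else:
    print("resume from {} at iteration {}".format(ckpt_path, latest_version))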
144 changes: 144 additions & 0 deletions pyzoo/zoo/util/tf.py
@@ -18,6 +18,12 @@
import json
import copy

import tempfile
import re
import shutil
from zoo.common.utils import put_local_file_to_remote, get_remote_file_to_local, get_file_list,\
    is_local_path


def process_grad(grad):
    from tensorflow.python.framework import ops
@@ -309,3 +315,141 @@ def _append_port(input_name):
return input_name + ":0"
else:
return input_name


def change_path_in_tf_checkpoint(checkpoint_path, ckpt_name):
    # change checkpoint file
    with open(checkpoint_path) as f:
        new_lines = []
        lines = f.readlines()
        # replace model_checkpoint_path and all_model_checkpoint_paths with the
        # checkpoint name instead of the absolute checkpoint path
        for line in lines:
            if re.compile("^model_checkpoint_path: \"(.*)\"$").match(line):
                new_lines.append("model_checkpoint_path: \"{}\"\n".format(ckpt_name))
            elif re.compile("^all_model_checkpoint_paths: \"(.*)\"$").match(line):
                new_lines.append("all_model_checkpoint_paths: \"{}\"\n".format(ckpt_name))
            else:
                new_lines.append(line)
    with open(checkpoint_path, 'w') as f:
        f.writelines(new_lines)
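
For context, here is a representative (made-up) 'checkpoint' index file as tf.train.Saver writes it, and what the rewrite above turns it into, so the files can be uploaded and later resolved relative to the remote directory:

# before: absolute local paths from saver.save()
model_checkpoint_path: "/tmp/tmp1a2b3c/model.ckpt"
all_model_checkpoint_paths: "/tmp/tmp1a2b3c/model.ckpt"

# after change_path_in_tf_checkpoint(path, "model.ckpt"): bare names only
model_checkpoint_path: "model.ckpt"
all_model_checkpoint_paths: "model.ckpt"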


def save_tf_checkpoint(sess, checkpoint_path, saver=None):
    """
    Save tf checkpoint without using native tensorflow remote access method.
    :param sess: tf session to be saved.
    :param checkpoint_path: checkpoint path. Could be local, hdfs, s3 filesystems.
    :param saver: tf saver to save checkpoint
    """
    import tensorflow as tf
    if is_local_path(checkpoint_path):
        if saver is None:
            saver = tf.train.Saver()
        saver.save(sess, checkpoint_path)
    else:
        ckpt_name = os.path.basename(checkpoint_path)
        remote_dir = os.path.dirname(checkpoint_path)
        # save to local checkpoint
        temp = tempfile.mkdtemp()
        if saver is None:
            saver = tf.train.Saver()
        saver.save(sess, os.path.join(temp, ckpt_name))
        change_path_in_tf_checkpoint(os.path.join(temp, "checkpoint"), ckpt_name)
        # move to remote
        [put_local_file_to_remote(os.path.join(temp, file), os.path.join(remote_dir, file),
                                  over_write=True)
         for file in os.listdir(temp)]
        shutil.rmtree(temp)
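
A minimal save-side sketch, assuming TF 1.x graph mode (matching the Saver-based code above). The hdfs:// URI is illustrative only; a local path would take the plain saver.save() branch instead:

import tensorflow as tf
from zoo.util.tf import save_tf_checkpoint

w = tf.Variable(tf.zeros([2, 2]), name="w")  # something for the Saver to save
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    save_tf_checkpoint(sess, "hdfs://namenode:9000/models/model.ckpt")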


def get_checkpoint_state(checkpoint_dir):
    """
    Get tf checkpoint state from checkpoint directory without using native tensorflow accessing
    remote method.
    :param checkpoint_dir: tensorflow checkpoint directory. Could be local, hdfs, s3 filesystems.
    :return: tf checkpoint protobuf
    """
    import tensorflow as tf
    if is_local_path(checkpoint_dir):
        return tf.train.get_checkpoint_state(checkpoint_dir)
    else:
        # check if checkpoint file exists
        file_list = get_file_list(checkpoint_dir)
        has_checkpoint = False
        for file in file_list:
            if os.path.basename(file) == 'checkpoint':
                has_checkpoint = True
                break
        if not has_checkpoint:
            return None
        # get checkpoint file
        temp = tempfile.mkdtemp()
        get_remote_file_to_local(os.path.join(checkpoint_dir, "checkpoint"),
                                 os.path.join(temp, "checkpoint"))
        ckpt_name = None
        with open(os.path.join(temp, "checkpoint")) as f:
            lines = f.readlines()
            # get checkpoint name from 'checkpoint' file
            for line in lines:
                m = re.compile("^model_checkpoint_path: \"(.*)\"$").match(line)
                if m:
                    ckpt_name = m.group(1)
                    break
        if ckpt_name is None:
            shutil.rmtree(temp)
            return None
        # filter checkpoint files
        checkpoint_files = [file for file in file_list
                            if os.path.basename(file).startswith(ckpt_name)]
        if not checkpoint_files:
            shutil.rmtree(temp)
            return None
        # get checkpoint files to local
        [get_remote_file_to_local(file, os.path.join(temp, os.path.basename(file)))
         for file in checkpoint_files]
        # get checkpoint state
        ckpt = tf.train.get_checkpoint_state(temp)
        if not ckpt:
            shutil.rmtree(temp)
            return None
        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir, ckpt_name)
        ckpt.all_model_checkpoint_paths[:] = [os.path.join(checkpoint_dir, ckpt_name)]
        shutil.rmtree(temp)
        return ckpt
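
A sketch of consuming the returned state (same illustrative URI). Note the helper patches the paths in the returned protobuf back to the remote directory, so callers see remote locations rather than the temp dir:

from zoo.util.tf import get_checkpoint_state

ckpt = get_checkpoint_state("hdfs://namenode:9000/models")
if ckpt is not None:
    print(ckpt.model_checkpoint_path)  # e.g. hdfs://namenode:9000/models/model.ckpt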


def load_tf_checkpoint(sess, checkpoint_path, saver=None):
    """
    Load tensorflow checkpoint from checkpoint path without using native tensorflow accessing
    remote method.
    :param sess: tensorflow session to be loaded to.
    :param checkpoint_path: tensorflow checkpoint path. Could be local, hdfs, s3 filesystems.
    :param saver: tensorflow saver to load checkpoint
    """
    import tensorflow as tf
    if is_local_path(checkpoint_path):
        if saver is None:
            saver = tf.train.Saver()
        saver.restore(sess, checkpoint_path)
    else:
        ckpt_name = os.path.basename(checkpoint_path)
        checkpoint_dir = os.path.dirname(checkpoint_path)
        # get remote file lists
        file_list = get_file_list(checkpoint_dir)
        # filter checkpoint files
        checkpoint_files = [file for file in file_list
                            if os.path.basename(file).startswith(ckpt_name)]
        # get checkpoint files to local
        temp = tempfile.mkdtemp()
        [get_remote_file_to_local(file, os.path.join(temp, os.path.basename(file)))
         for file in checkpoint_files]
        if saver is None:
            saver = tf.train.Saver()
        try:
            saver.restore(sess, os.path.join(temp, ckpt_name))
        finally:
            shutil.rmtree(temp)
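
And the matching restore-side sketch under the same assumptions (TF 1.x session, illustrative HDFS URI); the graph must already define the variables stored in the checkpoint:

import tensorflow as tf
from zoo.util.tf import load_tf_checkpoint

w = tf.Variable(tf.zeros([2, 2]), name="w")
with tf.Session() as sess:
    load_tf_checkpoint(sess, "hdfs://namenode:9000/models/model.ckpt")
    print(sess.run(w))  # values restored from the remote checkpoint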
@@ -1808,7 +1808,13 @@ object InternalDistriOptimizer {
      iter.next().localModels.head.asInstanceOf[TFTrainingHelperV2].moveWeightsOutOfTF()
      Iterator.single(1)
    }).reduce(_ + _)
    val extraState = models.map(_.localModels.head.getExtraParameter()).first()

    val extraParamLength = models.map(_.localModels.head.getExtraParameter().length).first()
    val extraState = new Array[Tensor[T]](extraParamLength)
    (0 until extraParamLength).foreach(i =>
      extraState(i) = models.map(_.localModels.head.getExtraParameter()(i)).first()
    )
    // val extraState = models.map(_.localModels.head.getExtraParameter()).first()
    trainingModel.setExtraParameter(extraState)

// make sure gradient is as the same length as weight
@@ -1820,7 +1826,6 @@
    val (parameter, gradientParameter) =
      InternalOptimizerUtil.getParametersFromModel(trainingModel)


    val (weights, gradients) = models.mapPartitions(iter => {
      val cached = iter.next()
      val curPartitionId = TaskContext.getPartitionId()
@@ -62,6 +62,9 @@ class TFNet(private val graphDef: TFGraphHolder,
  implicit val ev = TensorNumeric.NumericFloat
  implicit val tag: ClassTag[Float] = ClassTag.Float

  System.setProperty("bigdl.ModelBroadcastFactory",
    "com.intel.analytics.zoo.tfpark.TFModelBroadcastFactory")

  @transient
  private lazy val tensorManager = new TFResourceManager()

@@ -49,6 +49,9 @@ private[zoo] class TFNetForInference(graphRunner: GraphRunner,
  implicit val ev = TensorNumeric.NumericFloat
  implicit val tag: ClassTag[Float] = ClassTag.Float

  System.setProperty("bigdl.ModelBroadcastFactory",
    "com.intel.analytics.zoo.tfpark.TFModelBroadcastFactory")

  override def parameters(): (Array[Tensor[Float]], Array[Tensor[Float]]) = {
    (weights, gradWeights)
  }