Skip to content

Commit

Permalink
Add big TF model support (intel-analytics#2974)
Browse files Browse the repository at this point in the history
* test

* add print

* add graphdef print

* update

* update

* test

* update

* add print

* update broadcast

* fix spark file location

* update broadcast

* fix extra init

* update broadcast

* update property

* update broadcast

* restore broadcast

* restore clone

* fix clone

* fix get extra

* update collect weights

* update

* update

* update property

* update get extra param

* update

* update

* restore

* remove unused import

* fix style

* add methods

* fix style
  • Loading branch information
jenniew authored and Wang, Yang committed Sep 26, 2021
1 parent 524a399 commit cf5bc41
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 123 deletions.
1 change: 1 addition & 0 deletions python/orca/src/bigdl/orca/learn/tf/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from zoo.tfpark.tf_optimizer import StatelessMetric
from zoo.tfpark.utils import evaluate_metrics
from zoo.util import nest
from zoo.util.tf import save_tf_checkpoint


class Estimator(object):
Expand Down
151 changes: 29 additions & 122 deletions python/orca/src/bigdl/orca/learn/tf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,125 +105,32 @@ def transform_predict(iter):
return SparkXShards(prediction_rdd.mapPartitions(transform_predict))


def save_tf_checkpoint(sess, checkpoint_path, saver=None):
    """
    Save tf checkpoint without using native tensorflow remote access method.
    :param sess: tf session to be saved.
    :param checkpoint_path: checkpoint path. Could be local, hdfs, s3 filesystems.
    :param saver: tf saver to save checkpoint
    """
    if saver is None:
        saver = tf.train.Saver()
    if is_local_path(checkpoint_path):
        saver.save(sess, checkpoint_path)
    else:
        ckpt_name = basename(checkpoint_path)
        remote_dir = dirname(checkpoint_path)
        # save to a local temp dir first, then upload everything to the remote fs
        temp = tempfile.mkdtemp()
        try:
            saver.save(sess, join(temp, ckpt_name))
            # Rewrite the 'checkpoint' index file so that model_checkpoint_path /
            # all_model_checkpoint_paths refer to the checkpoint by name instead
            # of by the (temporary, local) absolute path.
            model_path_re = re.compile(r'^model_checkpoint_path: "(.*)"$')
            all_paths_re = re.compile(r'^all_model_checkpoint_paths: "(.*)"$')
            with open(join(temp, "checkpoint")) as f:
                lines = f.readlines()
            new_lines = []
            for line in lines:
                if model_path_re.match(line):
                    new_lines.append("model_checkpoint_path: \"{}\"\n".format(ckpt_name))
                elif all_paths_re.match(line):
                    new_lines.append("all_model_checkpoint_paths: \"{}\"\n".format(ckpt_name))
                else:
                    new_lines.append(line)
            with open(join(temp, "checkpoint"), 'w') as f:
                f.writelines(new_lines)
            # move every produced file (data, index, meta, checkpoint) to remote
            for file in os.listdir(temp):
                put_local_file_to_remote(join(temp, file), join(remote_dir, file),
                                         over_write=True)
        finally:
            # always clean up the temp dir, even if saving or upload failed
            shutil.rmtree(temp)


def get_checkpoint_state(checkpoint_dir):
    """
    Get tf checkpoint state from checkpoint directory without using native tensorflow
    accessing remote method.
    :param checkpoint_dir: tensorflow checkpoint directory. Could be local, hdfs, s3
           filesystems.
    :return: tf checkpoint protobuf, or None if no usable checkpoint is found.
    """
    if is_local_path(checkpoint_dir):
        return tf.train.get_checkpoint_state(checkpoint_dir)
    # remote path: the directory must contain a 'checkpoint' index file
    file_list = get_file_list(checkpoint_dir)
    if not any(basename(file) == 'checkpoint' for file in file_list):
        return None
    temp = tempfile.mkdtemp()
    try:
        # fetch the index file and read the active checkpoint name from it
        get_remote_file_to_local(join(checkpoint_dir, "checkpoint"),
                                 join(temp, "checkpoint"))
        model_path_re = re.compile(r'^model_checkpoint_path: "(.*)"$')
        ckpt_name = None
        with open(join(temp, "checkpoint")) as f:
            for line in f:
                m = model_path_re.match(line)
                if m:
                    ckpt_name = m.group(1)
                    break
        if ckpt_name is None:
            return None
        # download only the files belonging to that checkpoint
        checkpoint_files = [file for file in file_list
                            if basename(file).startswith(ckpt_name)]
        if not checkpoint_files:
            return None
        for file in checkpoint_files:
            get_remote_file_to_local(file, join(temp, basename(file)))
        ckpt = tf.train.get_checkpoint_state(temp)
        if not ckpt:
            return None
        # point the state back at the remote location instead of the temp dir
        ckpt.model_checkpoint_path = join(checkpoint_dir, ckpt_name)
        ckpt.all_model_checkpoint_paths[:] = [join(checkpoint_dir, ckpt_name)]
        return ckpt
    finally:
        # single cleanup point replaces the rmtree duplicated on every early return
        shutil.rmtree(temp)


def load_tf_checkpoint(sess, checkpoint_path, saver=None):
    """
    Load tensorflow checkpoint from checkpoint path without using native tensorflow
    accessing remote method.
    :param sess: tensorflow session to be loaded to.
    :param checkpoint_path: tensorflow checkpoint path. Could be local, hdfs, s3
           filesystems.
    :param saver: tensorflow saver to load checkpoint
    """
    if saver is None:
        saver = tf.train.Saver()
    if is_local_path(checkpoint_path):
        saver.restore(sess, checkpoint_path)
    else:
        ckpt_name = basename(checkpoint_path)
        checkpoint_dir = dirname(checkpoint_path)
        # list the remote dir and keep only the files of this checkpoint
        file_list = get_file_list(checkpoint_dir)
        checkpoint_files = [file for file in file_list
                            if basename(file).startswith(ckpt_name)]
        # download them to a local temp dir and restore from there
        temp = tempfile.mkdtemp()
        try:
            for file in checkpoint_files:
                get_remote_file_to_local(file, join(temp, basename(file)))
            saver.restore(sess, join(temp, ckpt_name))
        finally:
            # the no-op `except Exception as e: raise e` is gone; finally is enough
            shutil.rmtree(temp)
def find_latest_checkpoint(model_dir):
    """
    Locate the most recent checkpoint directory under ``model_dir``.

    Walks ``model_dir`` top-down. When a directory level contains sub-directories
    named like timestamps (``%Y-%m-%d_%H-%M-%S``), the search recurses into the
    latest one. Otherwise it looks for files named
    ``optimMethod-TFParkTraining.<N>`` and reports the first directory containing
    them together with the highest ``N``.

    :param model_dir: root directory to search.
    :return: tuple ``(checkpoint_dir, latest_version)``; ``(None, None)`` when no
             checkpoint file is found.
    """
    import os
    import re
    import datetime
    # compile once instead of on every directory/file (and use raw strings)
    ts_dir_re = re.compile(r'(\d{4})-(\d{2})-(\d{2})_(\d{2})-(\d{2})-(\d{2})$')
    ckpt_file_re = re.compile(r"^optimMethod-TFParkTraining\.[0-9]+$")
    ckpt_path = None
    latest_version = None
    for root, dirnames, filenames in os.walk(model_dir, topdown=True):
        timestamps = []
        for dirname in dirnames:  # avoid shadowing the builtin `dir`
            if ts_dir_re.match(dirname) is None:
                continue
            try:
                # confirm the name really parses as a datetime
                datetime.datetime.strptime(dirname, '%Y-%m-%d_%H-%M-%S')
            except ValueError:  # narrow: only the expected parse failure
                continue
            timestamps.append(dirname)
        if timestamps:
            # descend into the most recent timestamped run directory
            return find_latest_checkpoint(os.path.join(root, max(timestamps)))
        versions = [int(name.split(".")[1])
                    for name in filenames if ckpt_file_re.match(name)]
        if versions:
            ckpt_path = root
            latest_version = max(versions)
            break
    return ckpt_path, latest_version
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from zoo import init_nncontext
from zoo.orca.data.tf.data import Dataset
from zoo.orca.learn.tf.estimator import Estimator
from zoo.orca.learn.tf.utils import save_tf_checkpoint, load_tf_checkpoint, get_checkpoint_state
from zoo.util.tf import save_tf_checkpoint, load_tf_checkpoint, get_checkpoint_state

resource_path = os.path.join(os.path.split(__file__)[0], "../../../resources")

Expand Down

0 comments on commit cf5bc41

Please sign in to comment.