From d31aabb48bca1c0670c70c153b68a0a3cc12687a Mon Sep 17 00:00:00 2001 From: Jian Zhou <41574757+PatrickkZ@users.noreply.github.com> Date: Thu, 8 Sep 2022 14:50:40 +0800 Subject: [PATCH] fix issue 4642, fix DBFS file path problem on Dataricks (#5679) * fix issue 4642 * parse model_dir Co-authored-by: Zhou --- python/dllib/src/bigdl/dllib/utils/file_utils.py | 2 ++ python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py | 8 +++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/dllib/src/bigdl/dllib/utils/file_utils.py b/python/dllib/src/bigdl/dllib/utils/file_utils.py index 6e2701aa58c..ebd8aeaf904 100644 --- a/python/dllib/src/bigdl/dllib/utils/file_utils.py +++ b/python/dllib/src/bigdl/dllib/utils/file_utils.py @@ -70,6 +70,8 @@ def mkdirs(path): def is_local_path(path): + if path.startswith("/dbfs"): + return False parse_result = urlparse(path) return len(parse_result.scheme.lower()) == 0 or parse_result.scheme.lower() == "file" diff --git a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py index 34dda1ca363..a567aaf9e5d 100644 --- a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py +++ b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py @@ -47,6 +47,12 @@ logger = logging.getLogger(__name__) +def parse_model_dir(model_dir): + if model_dir and model_dir.startswith("dbfs:/"): + model_dir = "/dbfs/" + model_dir[len("dbfs:/"):] + return model_dir + + class SparkTFEstimator(): def __init__(self, model_creator, @@ -83,7 +89,7 @@ def __init__(self, invalidInputError(False, "Please do not specify batch_size in config. Input batch_size in the" " fit/evaluate function of the estimator instead.") - self.model_dir = model_dir + self.model_dir = parse_model_dir(model_dir) master = sc.getConf().get("spark.master") if not master.startswith("local"): logger.info("For cluster mode, make sure to use shared filesystem path "