diff --git a/python/dllib/src/bigdl/dllib/utils/file_utils.py b/python/dllib/src/bigdl/dllib/utils/file_utils.py
index 6e2701aa58c..ebd8aeaf904 100644
--- a/python/dllib/src/bigdl/dllib/utils/file_utils.py
+++ b/python/dllib/src/bigdl/dllib/utils/file_utils.py
@@ -70,6 +70,10 @@ def mkdirs(path):
 
 
 def is_local_path(path):
+    # Paths rooted at "/dbfs" are never reported as local. Match the root itself or
+    # anything beneath it, so unrelated names such as "/dbfs_backup" stay local.
+    if path == "/dbfs" or path.startswith("/dbfs/"):
+        return False
     parse_result = urlparse(path)
     return len(parse_result.scheme.lower()) == 0 or parse_result.scheme.lower() == "file"
 
diff --git a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
index 34dda1ca363..a567aaf9e5d 100644
--- a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
+++ b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
@@ -47,6 +47,16 @@
 logger = logging.getLogger(__name__)
 
 
+def parse_model_dir(model_dir):
+    """Rewrite a ``dbfs:/`` model directory into its ``/dbfs/`` path form.
+
+    Any other value, including ``None`` or an empty string, is returned unchanged.
+    """
+    if model_dir and model_dir.startswith("dbfs:/"):
+        model_dir = "/dbfs/" + model_dir[len("dbfs:/"):]
+    return model_dir
+
+
 class SparkTFEstimator():
     def __init__(self,
                  model_creator,
@@ -83,7 +93,7 @@ def __init__(self,
             invalidInputError(False,
                               "Please do not specify batch_size in config. Input batch_size in the"
                               " fit/evaluate function of the estimator instead.")
-        self.model_dir = model_dir
+        self.model_dir = parse_model_dir(model_dir)
         master = sc.getConf().get("spark.master")
         if not master.startswith("local"):
             logger.info("For cluster mode, make sure to use shared filesystem path "