Skip to content

Commit

Permalink
fix databricks dbfs file path (#4674)
Browse files Browse the repository at this point in the history
  • Loading branch information
Le-Zheng authored May 31, 2022
1 parent 6ae9938 commit 4182f05
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/orca/src/bigdl/orca/learn/tf/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,12 @@ def to_dataset(data, batch_size, batch_per_thread, validation_data,
return dataset


def save_model_dir(model_dir):
if model_dir.startswith("dbfs:/"):
model_dir = "/dbfs/" + model_dir[len("dbfs:/"):]
return model_dir


class TensorFlowEstimator(Estimator):
def __init__(self, *, inputs, outputs, labels, loss,
optimizer, clip_norm, clip_value,
Expand Down Expand Up @@ -775,6 +781,8 @@ def shutdown(self):

class KerasEstimator(Estimator):
def __init__(self, keras_model, metrics, model_dir, optimizer):
if model_dir and model_dir.startswith("dbfs:/"):
model_dir = save_model_dir(model_dir)
self.model = KerasModel(keras_model, model_dir)
self.load_checkpoint = False
self.metrics = metrics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from bigdl.orca.learn.tf.estimator import Estimator
from bigdl.dllib.nncontext import *
from bigdl.orca.learn.utils import convert_predict_rdd_to_dataframe
from bigdl.orca.learn.tf.estimator import save_model_dir


class TestEstimatorForKeras(TestCase):
Expand Down Expand Up @@ -681,6 +682,11 @@ def test_estimator_keras_get_model(self):
validation_data=df)
assert est.get_model() is model

def test_model_path_dbfs_from_keras(self):
model_dir = "dbfs:/FileStore/shared_uploads/models"
processed_model_dir = save_model_dir(model_dir)
assert processed_model_dir == "/dbfs/FileStore/shared_uploads/models"


if __name__ == "__main__":
import pytest
Expand Down

0 comments on commit 4182f05

Please sign in to comment.