Skip to content

Commit

Permalink
Add an explicit call to merge_wtih_defaults() when loading a config f…
Browse files Browse the repository at this point in the history
…rom a model directory. (#2226)

* Add call to merge_wtih_defaults in api.py.

* Add a test that loads an old model (from S3) and uses it to run prediction.

* Use zipped file + wget to load old Ludwig model from a zipped file, hosted on S3.
  • Loading branch information
justinxzhao authored Jul 6, 2022
1 parent f61ad2d commit 5d6a970
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 0 deletions.
3 changes: 3 additions & 0 deletions ludwig/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,9 @@ def load(

config = backend.broadcast_return(lambda: load_json(os.path.join(model_dir, MODEL_HYPERPARAMETERS_FILE_NAME)))

# Upgrades deprecated fields and adds new required fields, in case the config loaded from disk is old.
config = merge_with_defaults(config)

if backend_param is None and "backend" in config:
# Reset backend from config
backend = initialize_backend(config.get("backend"))
Expand Down
1 change: 1 addition & 0 deletions requirements_test.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pytest
pytest-timeout
wget
six>=1.13.0
aim
wandb<0.12.11
Expand Down
34 changes: 34 additions & 0 deletions tests/regression_tests/model/test_old_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
import zipfile

import pandas as pd
import wget

from ludwig.api import LudwigModel


def test_model_loaded_from_old_config_prediction_works(tmpdir):
# Titanic model based on 0.5.3.
old_model_url = "https://predibase-public-us-west-2.s3.us-west-2.amazonaws.com/ludwig_unit_tests/old_model.zip"
old_model_filename = wget.download(old_model_url, tmpdir)
with zipfile.ZipFile(old_model_filename, "r") as zip_ref:
zip_ref.extractall(tmpdir)
example_data = {
"PassengerId": 892,
"Pclass": 3,
"Name": "Kelly, Mr. James",
"Sex": "male",
"Age": 34.5,
"SibSp": 0,
"Parch": 0,
"Ticket": "330911",
"Fare": 7.8292,
"Cabin": None,
"Embarked": "Q",
}
test_set = pd.DataFrame(example_data, index=[0])

ludwig_model = LudwigModel.load(os.path.join(tmpdir, "old_model/model"))
predictions, _ = ludwig_model.predict(dataset=test_set)

assert predictions.to_dict()["Survived_predictions"] == {0: False}

0 comments on commit 5d6a970

Please sign in to comment.