diff --git a/ludwig/data/postprocessing.py b/ludwig/data/postprocessing.py index 63c7f8606d6..3592795daab 100644 --- a/ludwig/data/postprocessing.py +++ b/ludwig/data/postprocessing.py @@ -24,6 +24,7 @@ from ludwig.data.utils import convert_to_dict from ludwig.utils.data_utils import DATAFRAME_FORMATS, DICT_FORMATS from ludwig.utils.dataframe_utils import to_numpy_dataset +from ludwig.utils.fs_utils import has_remote_protocol, open_file from ludwig.utils.misc_utils import get_from_registry from ludwig.utils.strings_utils import make_safe_filename @@ -66,7 +67,11 @@ def _save_as_numpy(predictions, output_directory, saved_keys, backend): for k, v in numpy_predictions.items(): k = k.replace("<", "[").replace(">", "]") # Replace and with [UNK], [PAD] if k not in saved_keys: - np.save(npy_filename.format(make_safe_filename(k)), v) + if has_remote_protocol(output_directory): + with open_file(npy_filename.format(make_safe_filename(k)), mode="wb") as f: + np.save(f, v) + else: + np.save(npy_filename.format(make_safe_filename(k)), v) saved_keys.add(k)