Skip to content

Commit

Permalink
Support saving numpy predictions to remote FS (#2245)
Browse files Browse the repository at this point in the history
  • Loading branch information
hungcs authored Jul 9, 2022
1 parent ed7967f commit 5c9cffb
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion ludwig/data/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ludwig.data.utils import convert_to_dict
from ludwig.utils.data_utils import DATAFRAME_FORMATS, DICT_FORMATS
from ludwig.utils.dataframe_utils import to_numpy_dataset
from ludwig.utils.fs_utils import has_remote_protocol, open_file
from ludwig.utils.misc_utils import get_from_registry
from ludwig.utils.strings_utils import make_safe_filename

Expand Down Expand Up @@ -66,7 +67,11 @@ def _save_as_numpy(predictions, output_directory, saved_keys, backend):
for k, v in numpy_predictions.items():
k = k.replace("<", "[").replace(">", "]") # Replace <UNK> and <PAD> with [UNK], [PAD]
if k not in saved_keys:
np.save(npy_filename.format(make_safe_filename(k)), v)
if has_remote_protocol(output_directory):
with open_file(npy_filename.format(make_safe_filename(k)), mode="wb") as f:
np.save(f, v)
else:
np.save(npy_filename.format(make_safe_filename(k)), v)
saved_keys.add(k)


Expand Down

0 comments on commit 5c9cffb

Please sign in to comment.