Save huggingface checkpoint as artifact in mlflow callback #17686

Merged (7 commits) on Jun 17, 2022
16 changes: 12 additions & 4 deletions src/transformers/integrations.py
@@ -787,7 +787,7 @@ def setup(self, args, state, model):
         Environment:
             HF_MLFLOW_LOG_ARTIFACTS (`str`, *optional*):
                 Whether to use MLflow .log_artifact() facility to log artifacts. This only makes sense if logging to a
-                remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy whatever is in
+                remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy each saved checkpoint on each save in
                 [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it without a remote
                 storage will just copy the files to your artifact location.
             MLFLOW_EXPERIMENT_NAME (`str`, *optional*):
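For orientation, a minimal sketch of enabling this behavior from a training script; the output directory and save interval below are illustrative placeholders, not taken from this PR:

```python
import os

# MLflowCallback reads this env variable in setup(), which runs at the start
# of training, so it must be set before trainer.train() is called.
os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "1"

from transformers import TrainingArguments

# Every checkpoint-<step> folder written to output_dir (here every 500 steps)
# will also be uploaded to the active MLflow run when it is saved.
args = TrainingArguments(output_dir="outputs", save_steps=500)
# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()
```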
@@ -872,12 +872,20 @@ def on_log(self, args, state, control, logs, model=None, **kwargs):

     def on_train_end(self, args, state, control, **kwargs):
         if self._initialized and state.is_world_process_zero:
-            if self._log_artifacts:
-                logger.info("Logging artifacts. This may take time.")
-                self._ml_flow.log_artifacts(args.output_dir)
             if self._auto_end_run and self._ml_flow.active_run():
                 self._ml_flow.end_run()
 
+    def on_save(self, args, state, control, **kwargs):
+        if self._initialized and state.is_world_process_zero and self._log_artifacts:
Collaborator (review comment on the `on_save` hunk): I think this new behavior (uploading the saved model at each save) should be documented above with the HF_MLFLOW_LOG_ARTIFACTS env variable.

Contributor Author: fixed the documentation - thanks for your review!

+            ckpt_dir = f"checkpoint-{state.global_step}"
+            artifact_path = os.path.join(args.output_dir, ckpt_dir)
+            logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. This may take time.")
+            self._ml_flow.pyfunc.log_model(
+                ckpt_dir,
+                artifacts={"model_path": artifact_path},
+                python_model=self._ml_flow.pyfunc.PythonModel(),
+            )
+
     def __del__(self):
         # if the previous run is not terminated correctly, the fluent API will
         # not let you start a new run before the previous one is killed
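With this change, each checkpoint is logged through `mlflow.pyfunc.log_model` at save time instead of a single `log_artifacts` call at the end of training, so every `checkpoint-<step>` appears as its own pyfunc model inside the run. A hedged sketch of retrieving one afterwards; the run id is a placeholder, and the exact `artifacts/` nesting inside the logged model can vary with the MLflow version, so inspect the downloaded directory if needed:

```python
from mlflow.tracking import MlflowClient
from transformers import AutoModel

client = MlflowClient()
run_id = "<run-id-from-the-mlflow-ui>"  # placeholder, not from the PR

# Download the pyfunc model that on_save logged for a given step.
local_dir = client.download_artifacts(run_id, "checkpoint-500")

# The Hugging Face checkpoint files were attached under the "model_path"
# artifact key; with a local output_dir they typically land in an
# "artifacts/checkpoint-500" subfolder of the downloaded model directory.
model = AutoModel.from_pretrained(f"{local_dir}/artifacts/checkpoint-500")
```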