diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index 3417ee8fcaa17b..bc5baeaf35b6bd 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -878,6 +878,17 @@ def on_train_end(self, args, state, control, **kwargs): if self._auto_end_run and self._ml_flow.active_run(): self._ml_flow.end_run() + def on_save(self, args, state, control, **kwargs): + if self._initialized and state.is_world_process_zero and self._log_artifacts: + ckpt_dir = f"checkpoint-{state.global_step}" + artifact_path = os.path.join(args.output_dir, ckpt_dir) + + self._ml_flow.pyfunc.log_model( + ckpt_dir, + artifacts={"model_path": artifact_path}, + python_model=self._ml_flow.pyfunc.PythonModel(), + ) + def __del__(self): # if the previous run is not terminated correctly, the fluent API will # not let you start a new run before the previous one is killed