Skip to content

Commit

Permalink
[model cards] Keep evaluation order in training logs if there's mul…
Browse files Browse the repository at this point in the history
…tiple evaluators (#2963)

Also rename "loss" to Validation Loss
  • Loading branch information
tomaarsen authored Sep 30, 2024
1 parent 4f43de6 commit a7cc68f
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions sentence_transformers/model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def on_evaluate(
**kwargs,
) -> None:
loss_dict = {" ".join(key.split("_")[1:]): metrics[key] for key in metrics if key.endswith("_loss")}
if len(loss_dict) == 1 and "loss" in loss_dict:
loss_dict = {"Validation Loss": loss_dict["loss"]}
if (
model.model_card_data.training_logs
and model.model_card_data.training_logs[-1]["Step"] == state.global_step
Expand Down Expand Up @@ -830,19 +832,25 @@ def try_to_pure_python(value: Any) -> Any:

def format_training_logs(self):
# Get the keys from all evaluation lines
eval_lines_keys = {key for lines in self.training_logs for key in lines.keys()}
eval_lines_keys = []
for lines in self.training_logs:
for key in lines.keys():
if key not in eval_lines_keys:
eval_lines_keys.append(key)

# Sort the metric columns: Epoch, Step, Training Loss, Validation Loss, Evaluator results
def sort_metrics(key: str) -> str:
if key == "Epoch":
return "0"
return 0
if key == "Step":
return "1"
return 1
if key == "Training Loss":
return "2"
return 2
if key == "Validation Loss":
return 3
if key.endswith("loss"):
return "3"
return key
return 4
return eval_lines_keys.index(key) + 5

sorted_eval_lines_keys = sorted(eval_lines_keys, key=sort_metrics)
training_logs = [
Expand Down

0 comments on commit a7cc68f

Please sign in to comment.