WIP: Support table logging for mlflow, too
Create a `LogPredictionCallback` for each of "wandb" and "mlflow" when the
corresponding logger is enabled.

In `log_prediction_callback_factory`, build a generic table and make it
logger-specific only when the newly added `logger` argument is set to "wandb"
or "mlflow".

See #1505
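
To make the intent concrete, here is a minimal sketch of the pattern the message describes: the callback builds a plain, logger-agnostic dict-of-lists table and only converts it to a tracker-specific format at logging time. The structure and helper names below are illustrative only, not the actual axolotl implementation.

```python
# Minimal sketch of the logger-agnostic table pattern (illustrative, not the
# real axolotl code). Assumes wandb and mlflow (>= 2.3) are installed when the
# corresponding branch is taken.
from typing import Dict, List


def log_prediction_callback_factory(trainer, tokenizer, logger: str):
    class LogPredictionCallback:
        """Logs an eval predictions table to the configured experiment tracker."""

        def log_table(self, name: str, table_data: Dict[str, List]) -> None:
            # table_data is built generically, e.g.
            # {"Prompt": [...], "Correct Completion": [...]}
            if logger == "wandb":
                import pandas as pd
                import wandb

                # wandb.Table accepts a pandas DataFrame via the dataframe kwarg.
                wandb.log({name: wandb.Table(dataframe=pd.DataFrame(table_data))})
            elif logger == "mlflow":
                import mlflow

                # mlflow.log_table takes a dict of columns (or a DataFrame) directly.
                mlflow.log_table(data=table_data, artifact_file=f"{name}.json")

    return LogPredictionCallback
```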
Dave Farago committed Apr 9, 2024
1 parent 4313b1a commit 2b87bde
Showing 2 changed files with 26 additions and 20 deletions.
7 changes: 6 additions & 1 deletion src/axolotl/core/trainer_builder.py
@@ -943,7 +943,12 @@ def get_post_trainer_create_callbacks(self, trainer):
         callbacks = []
         if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
             LogPredictionCallback = log_prediction_callback_factory(
-                trainer, self.tokenizer
+                trainer, self.tokenizer, "wandb"
             )
             callbacks.append(LogPredictionCallback(self.cfg))
+        if self.cfg.use_mlflow and is_mlflow_available() and self.cfg.eval_table_size > 0:
+            LogPredictionCallback = log_prediction_callback_factory(
+                trainer, self.tokenizer, "mlflow"
+            )
+            callbacks.append(LogPredictionCallback(self.cfg))

39 changes: 20 additions & 19 deletions src/axolotl/utils/callbacks/__init__.py
@@ -9,6 +9,7 @@
 from typing import TYPE_CHECKING, Dict, List

 import evaluate
+import mlflow
 import numpy as np
 import pandas as pd
 import torch
@@ -28,6 +29,7 @@
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy

 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 from axolotl.utils.distributed import (
     barrier,
     broadcast_dict,
@@ -540,7 +542,7 @@ def predict_with_generate():
     return CausalLMBenchEvalCallback


-def log_prediction_callback_factory(trainer: Trainer, tokenizer):
+def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
     class LogPredictionCallback(TrainerCallback):
         """Callback to log prediction values during each evaluation"""

@@ -597,15 +599,13 @@ def find_ranges(lst):
                 return ranges

             def log_table_from_dataloader(name: str, table_dataloader):
-                table = wandb.Table(  # type: ignore[attr-defined]
-                    columns=[
-                        "id",
-                        "Prompt",
-                        "Correct Completion",
-                        "Predicted Completion (model.generate)",
-                        "Predicted Completion (trainer.prediction_step)",
-                    ]
-                )
+                table_data = {
+                    "id": [],
+                    "Prompt": [],
+                    "Correct Completion": [],
+                    "Predicted Completion (model.generate)": [],
+                    "Predicted Completion (trainer.prediction_step)": [],
+                }
                 row_index = 0

                 for batch in tqdm(table_dataloader):
@@ -709,16 +709,17 @@ def log_table_from_dataloader(name: str, table_dataloader):
                     ) in zip(
                         prompt_texts, completion_texts, predicted_texts, pred_step_texts
                     ):
-                        table.add_data(
-                            row_index,
-                            prompt_text,
-                            completion_text,
-                            prediction_text,
-                            pred_step_text,
-                        )
+                        table_data["id"].append(row_index)
+                        table_data["Prompt"].append(prompt_text)
+                        table_data["Correct Completion"].append(completion_text)
+                        table_data["Predicted Completion (model.generate)"].append(prediction_text)
+                        table_data["Predicted Completion (trainer.prediction_step)"].append(pred_step_text)
                         row_index += 1

-                wandb.run.log({f"{name} - Predictions vs Ground Truth": table})  # type: ignore[attr-defined]
+                if logger == "wandb":
+                    wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)})  # type: ignore[attr-defined]
+                elif logger == "mlflow":
+                    tracking_uri = AxolotlInputConfig(**self.cfg.to_dict()).mlflow_tracking_uri
+                    mlflow.log_table(data=table_data, artifact_file="PredictionsVsGroundTruth.json", tracking_uri=tracking_uri)

             if is_main_process():
                 log_table_from_dataloader("Eval", eval_dataloader)
