
Commit

add more test
Signed-off-by: xq-yin <[email protected]>
xq-yin committed Jun 26, 2024
1 parent 4abe77f commit 52737cb
Showing 2 changed files with 33 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mlflow/models/evaluation/default_evaluator.py
@@ -1948,7 +1948,7 @@ def _evaluate(
                     self.extra_metrics.remove(extra_metric)
                 # When the field is present, the metric is created from either make_genai_metric
                 # or make_genai_metric_from_prompt. We will log the metric definition.
-                if hasattr(extra_metric, "custom_metric_config"):
+                if extra_metric.custom_metric_config is not None:
                     genai_custom_metrics.append(extra_metric.custom_metric_config)
             self._generate_model_predictions(compute_latency=compute_latency)
             self._handle_builtin_metrics_by_model_type()
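
Note (illustrative, not part of the commit): the guard change matters if custom_metric_config is now a plain attribute that defaults to None on every EvaluationMetric. In that case hasattr is true for built-in metrics as well, so their empty configs would be collected and logged; checking "is not None" keeps only metrics built via make_genai_metric or make_genai_metric_from_prompt. The sketch below uses a hypothetical stand-in class, not MLflow's real one:

from dataclasses import dataclass
from typing import Optional

@dataclass
class SimpleMetric:  # hypothetical stand-in for mlflow.models.EvaluationMetric
    name: str
    custom_metric_config: Optional[dict] = None  # assumed: present on every metric, None for built-ins

metrics = [
    SimpleMetric("toxicity"),                              # built-in metric, no config
    SimpleMetric("answer_similarity", {"version": "v1"}),  # GenAI metric carrying a config
]

# Old guard: hasattr() is True for both metrics, so the built-in metric's None
# config would also be collected and later logged as a GenAI custom metric.
assert all(hasattr(m, "custom_metric_config") for m in metrics)

# New guard: only metrics that actually carry a config are collected.
genai_custom_metrics = [
    m.custom_metric_config for m in metrics if m.custom_metric_config is not None
]
assert genai_custom_metrics == [{"version": "v1"}]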
33 changes: 32 additions & 1 deletion tests/evaluate/test_default_evaluator.py
@@ -27,7 +27,9 @@
 from mlflow.exceptions import MlflowException
 from mlflow.metrics import (
     MetricValue,
+    flesch_kincaid_grade_level,
     make_metric,
+    toxicity,
 )
 from mlflow.metrics.genai import model_utils
 from mlflow.metrics.genai.base import EvaluationExample
@@ -4134,6 +4136,34 @@ def word_count_eval(predictions):
 )
 
 
+def test_do_not_log_built_in_metrics_as_artifacts():
+    with mlflow.start_run() as run:
+        model_info = mlflow.pyfunc.log_model(
+            artifact_path="model", python_model=language_model, input_example=["a"]
+        )
+        data = pd.DataFrame(
+            {
+                "inputs": ["words random", "This is a sentence."],
+                "ground_truth": ["words random", "This is a sentence."],
+            }
+        )
+        evaluate(
+            model_info.model_uri,
+            data,
+            targets="ground_truth",
+            predictions="answer",
+            model_type="question-answering",
+            evaluators="default",
+            extra_metrics=[
+                toxicity(),
+                flesch_kincaid_grade_level(),
+            ],
+        )
+        client = mlflow.MlflowClient()
+        artifacts = [a.path for a in client.list_artifacts(run.info.run_id)]
+        assert _GENAI_CUSTOM_METRICS_FILE_NAME not in artifacts
+
+
 def test_log_llm_custom_metrics_as_artifacts():
     with mlflow.start_run() as run:
         model_info = mlflow.pyfunc.log_model(
@@ -4180,7 +4210,8 @@ def test_log_llm_custom_metrics_as_artifacts():
         table = result.tables[_GENAI_CUSTOM_METRICS_FILE_NAME.split(".", 1)[0]]
         assert table.loc[0, "name"] == "answer_similarity"
         assert table.loc[0, "version"] == "v1"
-        assert table.loc[0, "metric_config"] is not None
         assert table.loc[1, "name"] == "custom llm judge"
         assert table.loc[1, "version"] is None
+        # TODO(xq-yin) ML-41356: Validate metric_config value once we implement deser function
+        assert table.loc[0, "metric_config"] is not None
         assert table.loc[1, "metric_config"] is not None
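
Usage note (illustrative, not from the commit): taken together, the assertions above imply that the logged table, keyed by the artifact file name without its extension, holds one row per GenAI metric with name, version, and metric_config columns. A small sketch of that expected shape, with placeholder config values rather than whatever make_genai_metric / make_genai_metric_from_prompt actually serialize:

import pandas as pd

# Placeholder values only: they mirror the assertions in the test above.
expected_shape = pd.DataFrame(
    {
        "name": ["answer_similarity", "custom llm judge"],
        "version": ["v1", None],
        "metric_config": ["<serialized config>", "<serialized config>"],
    }
)

assert expected_shape.loc[0, "name"] == "answer_similarity"
assert expected_shape.loc[0, "version"] == "v1"
assert expected_shape.loc[1, "version"] is None
assert expected_shape.loc[1, "metric_config"] is not None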
