diff --git a/src/dfcx_scrapi/tools/evaluations.py b/src/dfcx_scrapi/tools/evaluations.py index 7864a784..0176eee6 100644 --- a/src/dfcx_scrapi/tools/evaluations.py +++ b/src/dfcx_scrapi/tools/evaluations.py @@ -83,8 +83,9 @@ def __init__( self.generation_model = self.model_setup(generation_model) self.embedding_model = self.model_setup(embedding_model) + self.user_input_metrics = metrics self.metrics = build_metrics( - metrics=metrics, + metrics=self.user_input_metrics, generation_model=self.generation_model, embedding_model=self.embedding_model ) @@ -189,7 +190,7 @@ def add_response_columns(self, df: pd.DataFrame) -> pd.DataFrame: df.loc[:, "session_id"] = pd.Series(dtype="str") df.loc[:, "res_playbook_name"] = pd.Series(dtype="str") - if "tool_call_quality" in self.metrics: + if "tool_call_quality" in self.user_input_metrics: df.loc[:, "res_tool_name"] = pd.Series(dtype="str") df.loc[:, "res_tool_action"] = pd.Series(dtype="str") df.loc[:, "res_input_params"] = pd.Series(dtype="str") @@ -243,7 +244,7 @@ def run_detect_intent_queries(self, df: pd.DataFrame) -> pd.DataFrame: ) # Handle Tool Invocations - if "tool_call_quality" in self.metrics: + if "tool_call_quality" in self.user_input_metrics: tool_responses = self.s.collect_tool_responses(res) if len(tool_responses) > 0: df = self.process_tool_invocations( diff --git a/tests/dfcx_scrapi/core/test_agents.py b/tests/dfcx_scrapi/core/test_agents.py index 2eec46a0..c1d832dc 100644 --- a/tests/dfcx_scrapi/core/test_agents.py +++ b/tests/dfcx_scrapi/core/test_agents.py @@ -1,7 +1,5 @@ """Test Class for Agent Methods in SCRAPI.""" -# pylint: disable=redefined-outer-name - # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License");