diff --git a/axlearn/open_api/metrics/tool_use_execution_utils.py b/axlearn/open_api/metrics/tool_use_execution_utils.py index a37d88238..3e8cd5130 100644 --- a/axlearn/open_api/metrics/tool_use_execution_utils.py +++ b/axlearn/open_api/metrics/tool_use_execution_utils.py @@ -298,7 +298,7 @@ def _is_arg_value_equal( ) return pred_lenient == target_lenient - return False + return pred_arg == target_arg def check_arguments( @@ -317,6 +317,9 @@ def check_arguments( Returns: True if the predicted and targets arguments are matching according to the flags. """ + if not isinstance(pred_args, dict) or not isinstance(target_args, dict): + return False + # Check names are not duplicated. target_args_copy = dict(target_args.items()) @@ -330,7 +333,6 @@ def check_arguments( target_args_copy.pop(pred_arg_name) else: return False - # If there are still elements in to_kwargs, to_kwargs contains more entries than from_kwargs # and the arguments are not matching. return len(target_args_copy) == 0 diff --git a/axlearn/open_api/metrics/tool_use_execution_utils_test.py b/axlearn/open_api/metrics/tool_use_execution_utils_test.py index b13160405..1bed4a8a4 100644 --- a/axlearn/open_api/metrics/tool_use_execution_utils_test.py +++ b/axlearn/open_api/metrics/tool_use_execution_utils_test.py @@ -55,6 +55,22 @@ def test_all_positive_matches(self, pred, target): lenient=False, strict=False, ), + # non-string argument values. + dict( + pred={"soundType": "nature", "intensity": "medium", "duration": 45}, + target={"soundType": "nature", "intensity": "medium", "duration": 45}, + lenient_bow=True, + lenient=True, + strict=True, + ), + # non-dict arguments. + dict( + pred=[{"soundType": "nature"}, {"intensity": "medium", "duration": 45}], + target={"soundType": "nature", "intensity": "medium", "duration": 45}, + lenient_bow=False, + lenient=False, + strict=False, + ), ) def test_all_matches(self, pred, target, lenient_bow, lenient, strict): self.assertEqual(