diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 8e5184f2a4475..6eeb9ab03f602 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -817,19 +817,37 @@ def update_parametrize_target_arg( ): args = [arg.strip() for arg in argnames.split(",") if arg.strip()] if "target" in args: - if len(args) == 1: - targets = argvalues - param_sets = [(target,) for target in targets] - else: - target_i = args.index("target") - targets = [param_set[target_i] for param_set in argvalues] - param_sets = argvalues + target_i = args.index("target") + + new_argvalues = [] + for argvalue in argvalues: + + if isinstance(argvalue, _pytest.mark.structures.ParameterSet): + # The parametrized value is already a + # pytest.param, so track any marks already + # defined. + param_set = argvalue.values + target = param_set[target_i] + additional_marks = argvalue.marks + elif len(args) == 1: + # Single value parametrization, argvalue is a list of values. + target = argvalue + param_set = (target,) + additional_marks = [] + else: + # Multiple correlated parameters, argvalue is a list of tuple of values. + param_set = argvalue + target = param_set[target_i] + additional_marks = [] + + new_argvalues.append( + pytest.param( + *param_set, marks=_target_to_requirement(target) + additional_marks + ) + ) try: - argvalues[:] = [ - pytest.param(*param_set, marks=_target_to_requirement(target)) - for target, param_set in zip(targets, param_sets) - ] + argvalues[:] = new_argvalues except TypeError as e: pyfunc = metafunc.definition.function filename = pyfunc.__code__.co_filename