diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/testing.py b/userbenchmark/dynamo/dynamobench/_dynamo/testing.py index ee7fb48d2..fc4cc3603 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/testing.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/testing.py @@ -95,9 +95,7 @@ def collect_results( results.append(buffers) for example in example_inputs: if isinstance(example, (tuple, list)): - for inp in example: - if isinstance(inp, torch.Tensor): - results.append(inp.grad) + results.extend(inp.grad for inp in example if isinstance(inp, torch.Tensor)) else: if isinstance(example, torch.Tensor): results.append(example.grad) diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index 345abcd76..9f2c7b0c9 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -1535,9 +1535,10 @@ def checkpoint_params(gm): rng_state = torch.clone(torch.random.get_rng_state()) if torch.cuda.is_available(): cuda_rng_state = torch.clone(torch.cuda.get_rng_state()) - saved_state = [] - for param in itertools.chain(gm.parameters(), gm.buffers()): - saved_state.append((param, param._version, torch.clone(param))) + saved_state = [ + (param, param._version, torch.clone(param)) + for param in itertools.chain(gm.parameters(), gm.buffers()) + ] def restore(): with torch.no_grad():