From fe0992e50597481c2de71fa55e927033d9c58c9f Mon Sep 17 00:00:00 2001 From: Neil Movva Date: Sun, 8 Sep 2024 19:01:48 -0700 Subject: [PATCH] Add Torch CUDA sync to fix timing code in cli.py --- src/flux/cli.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/flux/cli.py b/src/flux/cli.py index 9e18f6e9..b961b848 100644 --- a/src/flux/cli.py +++ b/src/flux/cli.py @@ -215,6 +215,9 @@ def main( x = unpack(x.float(), opts.height, opts.width) with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): x = ae.decode(x) + + if torch.cuda.is_available(): + torch.cuda.synchronize() t1 = time.perf_counter() fn = output_name.format(idx=idx)