From 87f6fff727a377ea1c378af692afb41ae84cbe04 Mon Sep 17 00:00:00 2001 From: Neil Movva Date: Fri, 13 Sep 2024 02:47:16 -0700 Subject: [PATCH] Add Torch CUDA sync to fix timing code in cli.py (#147) Co-authored-by: Neil Movva --- src/flux/cli.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/flux/cli.py b/src/flux/cli.py index 9e18f6e9..b961b848 100644 --- a/src/flux/cli.py +++ b/src/flux/cli.py @@ -215,6 +215,9 @@ def main( x = unpack(x.float(), opts.height, opts.width) with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): x = ae.decode(x) + + if torch.cuda.is_available(): + torch.cuda.synchronize() t1 = time.perf_counter() fn = output_name.format(idx=idx)