diff --git a/src/flux/cli.py b/src/flux/cli.py index 9e18f6e9..bad807e7 100644 --- a/src/flux/cli.py +++ b/src/flux/cli.py @@ -215,6 +215,7 @@ def main( x = unpack(x.float(), opts.height, opts.width) with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): x = ae.decode(x) + torch.cuda.synchronize() t1 = time.perf_counter() fn = output_name.format(idx=idx)