diff --git a/fp8/float8_quantize.py b/fp8/float8_quantize.py
index 3e48e91..400edc0 100644
--- a/fp8/float8_quantize.py
+++ b/fp8/float8_quantize.py
@@ -275,16 +275,38 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         prev_dims = x.shape[:-1]
         x = x.view(-1, self.in_features)
 
-        # float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
-        out = torch._scaled_mm(  # noqa
-            x,
-            self.float8_data.T,
-            scale_a=self.input_scale_reciprocal,
-            scale_b=self.scale_reciprocal,
-            bias=self.bias,
-            out_dtype=self.weight.dtype,
-            use_fast_accum=True,
-        )
+        device = x.device
+        if x.device.type != 'cpu' and torch.cuda.get_device_capability(x.device) >= (8, 9):
+            # float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
+            out = torch._scaled_mm(  # noqa
+                x,
+                self.float8_data.T,
+                scale_a=self.input_scale_reciprocal,
+                scale_b=self.scale_reciprocal,
+                bias=self.bias,
+                out_dtype=self.weight.dtype,
+                use_fast_accum=True,
+            )
+        else:
+            # Plain matrix multiplication for non-ADA devices
+            # Assuming x is in float8 and self.float8_data is in float8 as well
+            # Convert to float32, perform the multiplication, and then apply scaling and bias if necessary
+
+            # Convert float8 to float32 for the multiplication
+            x_float32 = x.to(torch.float32)
+            float8_data_float32 = self.float8_data.T.to(torch.float32)
+
+            # Regular matrix multiplication
+            out = torch.matmul(x_float32, float8_data_float32)
+
+            # Scale the output accordingly
+            out = out * (self.input_scale_reciprocal * self.scale_reciprocal)
+
+            # Add bias if it exists
+            if self.bias is not None:
+                out += self.bias
+            out = out.to(self.weight.dtype)
+
         if IS_TORCH_2_4:
             out = out[0]
         return out.view(*prev_dims, self.out_features)
diff --git a/predict.py b/predict.py
index 4132dba..d0d8088 100644
--- a/predict.py
+++ b/predict.py
@@ -165,7 +165,11 @@ def base_setup(
         self.falcon_processor = ViTImageProcessor.from_pretrained(FALCON_MODEL_NAME)
 
         # need > 48 GB of ram to store all models in VRAM
-        self.offload = "A40" in gpu_name
+        total_mem = torch.cuda.get_device_properties(0).total_memory
+        self.offload = total_mem < 48 * 1024**3
+        if self.offload:
+            print("GPU memory is:", total_mem / 1024**3, ", offloading models")
+
         compile_fp8 = False
         device = "cuda"
         max_length = 256 if self.flow_model_name == "flux-schnell" else 512
@@ -187,13 +191,32 @@ def base_setup(
             flow=None, ae=self.ae, clip=self.clip, t5=self.t5, config=None
         )
 
-        # fp8 only works w/compute capability >= 8.9
-        self.disable_fp8 = disable_fp8 or torch.cuda.get_device_capability() < (8, 9)
+        self.disable_fp8 = disable_fp8
 
         if not self.disable_fp8:
+            if compile_fp8:
+                extra_args = {
+                    "compile_whole_model": True,
+                    "compile_extras": True,
+                    "compile_blocks": True,
+                }
+            else:
+                extra_args = {
+                    "compile_whole_model": False,
+                    "compile_extras": False,
+                    "compile_blocks": False,
+                }
+
+            if self.offload:
+                extra_args |= {
+                    "offload_text_encoder": True,
+                    "offload_vae": True,
+                    "offload_flow": True,
+                }
             self.fp8_pipe = FluxPipeline.load_pipeline_from_config_path(
                 f"fp8/configs/config-1-{flow_model_name}-h100.json",
                 shared_models=shared_models,
+                **extra_args,
             )
 
             if compile_fp8:
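
For context, the gist of the new non-Ada path is sketched below as standalone helpers. This is a minimal illustration, not code from the repo: the names supports_scaled_mm and fp8_linear_fallback are hypothetical, and the sketch assumes the activations and weights are already quantized to float8 with per-tensor reciprocal scales, as in the forward() shown in the diff.

import torch
from typing import Optional

def supports_scaled_mm(device: torch.device) -> bool:
    # torch._scaled_mm needs a CUDA device with compute capability >= 8.9 (Ada / Hopper and newer).
    return device.type == "cuda" and torch.cuda.get_device_capability(device) >= (8, 9)

def fp8_linear_fallback(
    x_fp8: torch.Tensor,                 # activations, already quantized to float8
    weight_fp8_t: torch.Tensor,          # transposed float8 weight
    x_scale_reciprocal: torch.Tensor,    # per-tensor reciprocal scale for activations
    w_scale_reciprocal: torch.Tensor,    # per-tensor reciprocal scale for weights
    bias: Optional[torch.Tensor],
    out_dtype: torch.dtype,
) -> torch.Tensor:
    # Upcast both float8 operands to float32, do a plain matmul, then apply the
    # per-tensor scales and the bias: the unfused equivalent of torch._scaled_mm.
    out = torch.matmul(x_fp8.to(torch.float32), weight_fp8_t.to(torch.float32))
    out = out * (x_scale_reciprocal * w_scale_reciprocal)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)

The fallback trades the fused FP8 kernel for a float32 matmul, so it is slower and needs extra memory for the dequantized copies, but it lets the pipeline run on pre-Ada GPUs; together with the VRAM-based offload check in predict.py, it also covers cards with less than 48 GB of memory.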