diff --git a/predict.py b/predict.py index de94eb5..840855d 100644 --- a/predict.py +++ b/predict.py @@ -145,7 +145,10 @@ def base_setup( self.falcon_processor = ViTImageProcessor.from_pretrained(FALCON_MODEL_NAME) # need > 48 GB of ram to store all models in VRAM - self.offload = "A40" in gpu_name + total_mem = torch.cuda.get_device_properties(0).total_memory + self.offload = total_mem < 48 * 1024**3 + if self.offload: + print("GPU memory is:", total_mem / 1024 ** 3, ", offloading models") device = "cuda" max_length = 256 if self.flow_model_name == "flux-schnell" else 512