[WIP] DeepFloyd: stage1 can use 8bit text encoder
bghira committed Jul 22, 2023
1 parent 1e8d7ed commit 5d295ac
Showing 1 changed file with 19 additions and 4 deletions.
discord_tron_client/classes/image_manipulation/diffusion.py
@@ -21,7 +21,7 @@
     use_upscaler,
 )
 from PIL import Image
-import torch, gc, logging, diffusers
+import torch, gc, logging, diffusers, transformers
 
 torch.backends.cudnn.deterministic = False
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -88,7 +88,7 @@ def clear_pipeline(self, model_id: str) -> None:
         except Exception as e:
             logging.error(f"Error when deleting pipe: {e}")
 
-    def create_pipeline(self, model_id: str, pipe_type: str, use_safetensors: bool = True, custom_text_encoder: int = None, safety_modules: dict = None) -> Pipeline:
+    def create_pipeline(self, model_id: str, pipe_type: str, use_safetensors: bool = True, custom_text_encoder = None, safety_modules: dict = None) -> Pipeline:
         pipeline_class = self.PIPELINE_CLASSES[pipe_type]
         extra_args = {
             'feature_extractor': None,
@@ -98,6 +98,9 @@ def create_pipeline(self, model_id: str, pipe_type: str, use_safetensors: bool =
         if custom_text_encoder is not None and custom_text_encoder == -1:
             # Disable text encoder.
             extra_args["text_encoder"] = None
+        elif custom_text_encoder is not None:
+            # Use a custom text encoder.
+            extra_args["text_encoder"] = custom_text_encoder
         if safety_modules is not None:
             for key in safety_modules:
                 extra_args[key] = safety_modules[key]
@@ -177,7 +180,7 @@ def get_pipe(
         prompt_variation: bool = False,
         promptless_variation: bool = False,
         upscaler: bool = False,
-        custom_text_encoder: int = None,
+        custom_text_encoder = None,
         safety_modules: dict = None
     ) -> Pipeline:
         self.delete_pipes(keep_model=model_id)
@@ -218,7 +221,13 @@ def get_pipe(
             self.clear_pipeline(model_id)
 
         if model_id not in self.pipelines:
-            logging.debug(f"Creating pipeline type {pipe_type} for model {model_id} with custom_text_encoder {custom_text_encoder}")
+            if "DeepFloyd/IF-I-" in model_id:
+                # DeepFloyd stage 1 can use a more efficient text encoder config.
+                custom_text_encoder = transformers.T5EncoderModel.from_pretrained(
+                    model_id, subfolder="text_encoder", device_map="auto", load_in_8bit=True, variant="8bit"
+                )
+
+            logging.debug(f"Creating pipeline type {pipe_type} for model {model_id} with custom_text_encoder {type(custom_text_encoder)}")
             self.pipelines[model_id] = self.create_pipeline(model_id, pipe_type, use_safetensors=use_safetensors, custom_text_encoder=custom_text_encoder, safety_modules=safety_modules)
             if pipe_type in ["upscaler", "prompt_variation", "text2img", "kandinsky-2.2"]:
                 pass
@@ -261,6 +270,12 @@ def get_pipe(
                     mode="reduce-overhead",
                     fullgraph=True,
                 )
+            if config.enable_compile() and hasattr(self.pipelines[model_id], 'text_encoder'):
+                self.pipelines[model_id].text_encoder = torch.compile(
+                    self.pipelines[model_id].text_encoder,
+                    mode="reduce-overhead",
+                    fullgraph=True,
+                )
         else:
             logging.info(f"Keeping existing pipeline. Not creating any new ones.")
             self.pipelines[model_id].to(self.device)
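
For reference, the 8-bit loading pattern wired into get_pipe() above can also be exercised on its own. A minimal sketch, assuming bitsandbytes and accelerate are installed and using DeepFloyd/IF-I-XL-v1.0 as an example stage-1 checkpoint (the "8bit" weight variant published in the DeepFloyd repositories):

    import transformers
    from diffusers import DiffusionPipeline

    # Load only the T5 text encoder, quantized to 8-bit and placed automatically across devices.
    text_encoder = transformers.T5EncoderModel.from_pretrained(
        "DeepFloyd/IF-I-XL-v1.0",
        subfolder="text_encoder",
        device_map="auto",
        load_in_8bit=True,
        variant="8bit",
    )

    # Hand the quantized encoder to the stage-1 pipeline so a second fp16 copy is never loaded.
    pipe = DiffusionPipeline.from_pretrained(
        "DeepFloyd/IF-I-XL-v1.0",
        text_encoder=text_encoder,
        unet=None,  # optional: drop the UNet when the pipeline is only used to encode prompts
    )

This mirrors the create_pipeline() path above: when custom_text_encoder is not None (and not -1), it is passed straight through as the pipeline's text_encoder.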