Commit 6b4b1e8
Merge pull request #1167 from bghira/bugfix/text-encoder-quantisation
quantise text encoders upon request correctly
bghira authored Nov 17, 2024
2 parents 1e0b8d5 + 9d338ec commit 6b4b1e8
Showing 2 changed files with 149 additions and 61 deletions.
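In brief: quantise_model() previously chose a single backend (quanto or torchao) from args.base_model_precision and applied it to every model it was handed, so the backend choice never honoured the per-encoder flags (text_encoder_1_precision, text_encoder_2_precision, text_encoder_3_precision), and a base precision outside those two families left quant_fn unbound. The rewrite resolves a backend per model through the new get_quant_fn() helper, keyed on each model's own precision string, and skips any model whose backend resolves to None. trainer.py gains a matching init_precision(preprocessing_models_only=...) switch so the text encoders can be quantised on their own, before the U-Net/transformer is touched.
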
147 changes: 109 additions & 38 deletions helpers/training/quantisation/__init__.py
@@ -174,56 +174,127 @@ def _torchao_model(
         )

     else:
-        raise ValueError(f"Invalid quantisation level: {base_model_precision}")
+        raise ValueError(
+            f"Invalid quantisation level. model_precision={model_precision}, base_model_precision={base_model_precision}"
+        )

     return model


+def get_quant_fn(base_model_precision):
+    """
+    Determine the quantization function based on the base model precision.
+    Args:
+        base_model_precision (str): The precision specification for the base model.
+    Returns:
+        function: The corresponding quantization function.
+    Raises:
+        ValueError: If the precision specification is unsupported.
+    """
+    precision = base_model_precision.lower()
+    if precision == "no_change":
+        return None
+    if "quanto" in precision:
+        return _quanto_model
+    elif "torchao" in precision:
+        return _torchao_model
+    else:
+        return None


 def quantise_model(
     unet, transformer, text_encoder_1, text_encoder_2, text_encoder_3, controlnet, args
 ):
-    if "quanto" in args.base_model_precision.lower():
-        logger.info("Loading Quanto. This may take a few minutes.")
-        quant_fn = _quanto_model
-    elif "torchao" in args.base_model_precision.lower():
-        logger.info("Loading TorchAO. This may take a few minutes.")
-        quant_fn = _torchao_model
-    if transformer is not None:
-        transformer = quant_fn(
-            transformer,
-            model_precision=args.base_model_precision,
-            quantize_activations=args.quantize_activations,
-        )
-    if unet is not None:
-        unet = quant_fn(
-            unet,
-            model_precision=args.base_model_precision,
-            quantize_activations=args.quantize_activations,
-        )
-    if controlnet is not None:
-        controlnet = quant_fn(
-            controlnet,
-            model_precision=args.base_model_precision,
-            quantize_activations=args.quantize_activations,
-        )
-    if text_encoder_1 is not None:
-        text_encoder_1 = quant_fn(
-            text_encoder_1,
-            model_precision=args.text_encoder_1_precision,
-            base_model_precision=args.base_model_precision,
-        )
-    if text_encoder_2 is not None:
-        text_encoder_2 = quant_fn(
-            text_encoder_2,
-            model_precision=args.text_encoder_2_precision,
-            base_model_precision=args.base_model_precision,
-        )
-    if text_encoder_3 is not None:
-        text_encoder_3 = quant_fn(
-            text_encoder_3,
-            model_precision=args.text_encoder_3_precision,
-            base_model_precision=args.base_model_precision,
-        )
+    """
+    Quantizes the provided models using the specified precision settings.
+    Args:
+        unet: The UNet model to quantize.
+        transformer: The Transformer model to quantize.
+        text_encoder_1: The first text encoder to quantize.
+        text_encoder_2: The second text encoder to quantize.
+        text_encoder_3: The third text encoder to quantize.
+        controlnet: The ControlNet model to quantize.
+        args: An object containing precision settings and other arguments.
+    Returns:
+        tuple: A tuple containing the quantized models in the order:
+            (unet, transformer, text_encoder_1, text_encoder_2, text_encoder_3, controlnet)
+    """
+    models = [
+        (
+            transformer,
+            {
+                "quant_fn": get_quant_fn(args.base_model_precision),
+                "model_precision": args.base_model_precision,
+                "quantize_activations": args.quantize_activations,
+            },
+        ),
+        (
+            unet,
+            {
+                "quant_fn": get_quant_fn(args.base_model_precision),
+                "model_precision": args.base_model_precision,
+                "quantize_activations": args.quantize_activations,
+            },
+        ),
+        (
+            controlnet,
+            {
+                "quant_fn": get_quant_fn(args.base_model_precision),
+                "model_precision": args.base_model_precision,
+                "quantize_activations": args.quantize_activations,
+            },
+        ),
+        (
+            text_encoder_1,
+            {
+                "quant_fn": get_quant_fn(args.text_encoder_1_precision),
+                "model_precision": args.text_encoder_1_precision,
+                "base_model_precision": args.base_model_precision,
+            },
+        ),
+        (
+            text_encoder_2,
+            {
+                "quant_fn": get_quant_fn(args.text_encoder_2_precision),
+                "model_precision": args.text_encoder_2_precision,
+                "base_model_precision": args.base_model_precision,
+            },
+        ),
+        (
+            text_encoder_3,
+            {
+                "quant_fn": get_quant_fn(args.text_encoder_3_precision),
+                "model_precision": args.text_encoder_3_precision,
+                "base_model_precision": args.base_model_precision,
+            },
+        ),
+    ]
+
+    # Iterate over the models and apply quantization if the model is not None
+    for i, (model, quant_args) in enumerate(models):
+        quant_fn = quant_args["quant_fn"]
+        if quant_fn is None:
+            continue
+        if model is not None:
+            quant_args_combined = {
+                "model_precision": quant_args["model_precision"],
+                "base_model_precision": quant_args.get(
+                    "base_model_precision", args.base_model_precision
+                ),
+                "quantize_activations": quant_args.get(
+                    "quantize_activations", args.quantize_activations
+                ),
+            }
+            models[i] = (quant_fn(model, **quant_args_combined), quant_args)
+
+    # Unpack the quantized models
+    transformer, unet, controlnet, text_encoder_1, text_encoder_2, text_encoder_3 = [
+        model for model, _ in models
+    ]

     return unet, transformer, text_encoder_1, text_encoder_2, text_encoder_3, controlnet
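The dispatch rule is easiest to see in isolation. Below is a minimal sketch of get_quant_fn() resolution; the import path assumes the SimpleTuner repository root is on PYTHONPATH, and the precision strings are merely illustrative of the quanto/torchao naming matched above:

# Sketch only: requires the repo on PYTHONPATH; precision strings are examples.
from helpers.training.quantisation import (
    get_quant_fn,
    _quanto_model,
    _torchao_model,
)

assert get_quant_fn("int8-quanto") is _quanto_model    # "quanto" substring selects the quanto backend
assert get_quant_fn("int8-torchao") is _torchao_model  # "torchao" substring selects the torchao backend
assert get_quant_fn("no_change") is None               # explicit opt-out: the model is left untouched
assert get_quant_fn("bf16") is None                    # unrecognised strings also fall through to None

Because quantise_model() now calls get_quant_fn() with each model's own precision string, a text encoder whose precision resolves to None is simply skipped by the loop, which is the behaviour the commit title promises.
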
63 changes: 40 additions & 23 deletions helpers/training/trainer.py
@@ -469,8 +469,14 @@ def init_vae(self, move_to_accelerator: bool = True):
             )
             self.config.vae_kwargs["subfolder"] = None
             self.vae = AutoencoderKL.from_pretrained(**self.config.vae_kwargs)
-        if self.vae is not None and self.config.vae_enable_tiling and hasattr(self.vae, 'enable_tiling'):
-            logger.warning("Enabling VAE tiling for greatly reduced memory consumption due to --vae_enable_tiling which may result in VAE tiling artifacts in encoded latents.")
+        if (
+            self.vae is not None
+            and self.config.vae_enable_tiling
+            and hasattr(self.vae, "enable_tiling")
+        ):
+            logger.warning(
+                "Enabling VAE tiling for greatly reduced memory consumption due to --vae_enable_tiling which may result in VAE tiling artifacts in encoded latents."
+            )
             self.vae.enable_tiling()
         if not move_to_accelerator:
             logger.debug("Not moving VAE to accelerator.")
@@ -750,7 +756,7 @@ def init_unload_text_encoder(self):
             " The real memories were the friends we trained a model on along the way."
         )

-    def init_precision(self):
+    def init_precision(self, preprocessing_models_only: bool = False):
         self.config.enable_adamw_bf16 = (
             True if self.config.weight_dtype == torch.bfloat16 else False
         )
@@ -769,24 +775,29 @@ def init_precision(self):
         elif self.config.base_model_default_dtype == "bf16":
             self.config.base_weight_dtype = torch.bfloat16
             self.config.enable_adamw_bf16 = True
-        if self.unet is not None:
-            logger.info(
-                f"Moving U-net to dtype={self.config.base_weight_dtype}, device={quantization_device}"
-            )
-            self.unet.to(quantization_device, dtype=self.config.base_weight_dtype)
-        elif self.transformer is not None:
-            logger.info(
-                f"Moving transformer to dtype={self.config.base_weight_dtype}, device={quantization_device}"
-            )
-            self.transformer.to(
-                quantization_device, dtype=self.config.base_weight_dtype
-            )
+        if not preprocessing_models_only:
+            if self.unet is not None:
+                logger.info(
+                    f"Moving U-net to dtype={self.config.base_weight_dtype}, device={quantization_device}"
+                )
+                self.unet.to(
+                    quantization_device, dtype=self.config.base_weight_dtype
+                )
+            elif self.transformer is not None:
+                logger.info(
+                    f"Moving transformer to dtype={self.config.base_weight_dtype}, device={quantization_device}"
+                )
+                self.transformer.to(
+                    quantization_device, dtype=self.config.base_weight_dtype
+                )

         if self.config.is_quanto:
             with self.accelerator.local_main_process_first():
                 self.quantise_model(
-                    unet=self.unet,
-                    transformer=self.transformer,
+                    unet=self.unet if not preprocessing_models_only else None,
+                    transformer=(
+                        self.transformer if not preprocessing_models_only else None
+                    ),
                     text_encoder_1=self.text_encoder_1,
                     text_encoder_2=self.text_encoder_2,
                     text_encoder_3=self.text_encoder_3,
Expand All @@ -803,8 +814,10 @@ def init_precision(self):
self.text_encoder_3,
self.controlnet,
) = self.quantise_model(
unet=self.unet,
transformer=self.transformer,
unet=self.unet if not preprocessing_models_only else None,
transformer=(
self.transformer if not preprocessing_models_only else None
),
text_encoder_1=self.text_encoder_1,
text_encoder_2=self.text_encoder_2,
text_encoder_3=self.text_encoder_3,
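With both call sites updated, preprocessing_models_only=True quantises only the models needed for preprocessing; the U-Net and transformer are passed as None. A hypothetical call sequence, assuming an already-initialised Trainer instance named trainer:

# Hypothetical usage of the new keyword; `trainer` is assumed to be a
# constructed Trainer with its text encoders already loaded.
trainer.init_precision(preprocessing_models_only=True)  # quantise the text encoders only
# ... cache text embeddings, optionally unload the encoders ...
trainer.init_precision()  # full pass: also moves/quantises the unet or transformer
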
@@ -1376,7 +1389,7 @@ def init_validations(self):
             ema_model=self.ema_model,
             vae=self.vae,
             controlnet=self.controlnet if self.config.controlnet else None,
-            model_evaluator=model_evaluator
+            model_evaluator=model_evaluator,
         )
         if not self.config.train_text_encoder and self.validation is not None:
             self.validation.clear_text_encoders()
@@ -2592,13 +2605,15 @@ def train(self):
                     self.guidance_values_list = []
                 if grad_norm is not None:
                     wandb_logs["grad_norm"] = grad_norm
-                if self.validation is not None and hasattr(self.validation, 'evaluation_result'):
+                if self.validation is not None and hasattr(
+                    self.validation, "evaluation_result"
+                ):
                     eval_result = self.validation.get_eval_result()
                     if eval_result is not None and type(eval_result) == dict:
                         # add the dict to wandb_logs
                         self.validation.clear_eval_result()
                         wandb_logs.update(eval_result)

                 progress_bar.update(1)
                 self.state["global_step"] += 1
                 current_epoch_step += 1
@@ -2717,7 +2732,9 @@ def train(self):
                                     self.config.output_dir, removing_checkpoint
                                 )
                                 try:
-                                    shutil.rmtree(removing_checkpoint, ignore_errors=True)
+                                    shutil.rmtree(
+                                        removing_checkpoint, ignore_errors=True
+                                    )
                                 except Exception as e:
                                     logger.error(
                                         f"Failed to remove directory: {removing_checkpoint}"
